From 0bb733ba230051301b3fb3fa49d1d6662744b395 Mon Sep 17 00:00:00 2001 From: Andrey Talman Date: Mon, 14 Jul 2025 13:01:16 +0000 Subject: [PATCH 001/457] Add cuda 12.4 build in CI (#157958) Fixes to https://github.com/pytorch/pytorch/issues/156747 Pull Request resolved: https://github.com/pytorch/pytorch/pull/157958 Approved by: https://github.com/malfet, https://github.com/Skylion007 --- .ci/docker/build.sh | 11 +++++ .ci/docker/common/install_cuda.sh | 49 ++++++++++++++++++++++ .ci/docker/common/install_cudnn.sh | 2 + .ci/docker/common/install_cusparselt.sh | 8 ++++ .github/workflows/docker-builds.yml | 1 + .github/workflows/periodic.yml | 31 ++++++++++++++ test/inductor/test_aot_inductor.py | 9 +++- test/inductor/test_aot_inductor_package.py | 13 ++++++ 8 files changed, 123 insertions(+), 1 deletion(-) diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh index 97e6bce3e59d0..075b5e80209fd 100755 --- a/.ci/docker/build.sh +++ b/.ci/docker/build.sh @@ -91,6 +91,17 @@ tag=$(echo $image | awk -F':' '{print $2}') # configuration, so we hardcode everything here rather than do it # from scratch case "$tag" in + pytorch-linux-jammy-cuda12.4-cudnn9-py3-gcc11) + CUDA_VERSION=12.4 + CUDNN_VERSION=9 + ANACONDA_PYTHON_VERSION=3.10 + GCC_VERSION=11 + VISION=yes + KATEX=yes + UCX_COMMIT=${_UCX_COMMIT} + UCC_COMMIT=${_UCC_COMMIT} + TRITON=yes + ;; pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11) CUDA_VERSION=12.8.1 CUDNN_VERSION=9 diff --git a/.ci/docker/common/install_cuda.sh b/.ci/docker/common/install_cuda.sh index cd9701e7590b5..c8a780f65c8e5 100644 --- a/.ci/docker/common/install_cuda.sh +++ b/.ci/docker/common/install_cuda.sh @@ -78,6 +78,19 @@ function install_nvshmem { echo "nvSHMEM ${nvshmem_version} for CUDA ${cuda_major_version} (${arch_path}) installed." } +function install_124 { + CUDNN_VERSION=9.1.0.70 + echo "Installing CUDA 12.4.1 and cuDNN ${CUDNN_VERSION} and NCCL and cuSparseLt-0.6.2" + install_cuda 12.4.1 cuda_12.4.1_550.54.15_linux + + install_cudnn 12 $CUDNN_VERSION + + CUDA_VERSION=12.4 bash install_nccl.sh + + CUDA_VERSION=12.4 bash install_cusparselt.sh + + ldconfig +} function install_126 { CUDNN_VERSION=9.10.2.21 @@ -113,6 +126,40 @@ function install_129 { ldconfig } +function prune_124 { + echo "Pruning CUDA 12.4" + ##################################################################################### + # CUDA 12.4 prune static libs + ##################################################################################### + export NVPRUNE="/usr/local/cuda-12.4/bin/nvprune" + export CUDA_LIB_DIR="/usr/local/cuda-12.4/lib64" + + export GENCODE="-gencode arch=compute_50,code=sm_50 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75 -gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86 -gencode arch=compute_90,code=sm_90" + export GENCODE_CUDNN="-gencode arch=compute_50,code=sm_50 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75 -gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86 -gencode arch=compute_90,code=sm_90" + + if [[ -n "$OVERRIDE_GENCODE" ]]; then + export GENCODE=$OVERRIDE_GENCODE + fi + if [[ -n "$OVERRIDE_GENCODE_CUDNN" ]]; then + export GENCODE_CUDNN=$OVERRIDE_GENCODE_CUDNN + fi + + # all CUDA libs except CuDNN and CuBLAS + ls $CUDA_LIB_DIR/ | grep "\.a" | grep -v "culibos" | grep -v "cudart" | grep -v "cudnn" | grep -v "cublas" | grep -v "metis" \ + | xargs -I {} bash -c \ + "echo {} && $NVPRUNE $GENCODE $CUDA_LIB_DIR/{} -o $CUDA_LIB_DIR/{}" + + # prune CuDNN and CuBLAS + $NVPRUNE $GENCODE_CUDNN $CUDA_LIB_DIR/libcublas_static.a -o $CUDA_LIB_DIR/libcublas_static.a + $NVPRUNE $GENCODE_CUDNN $CUDA_LIB_DIR/libcublasLt_static.a -o $CUDA_LIB_DIR/libcublasLt_static.a + + ##################################################################################### + # CUDA 12.4 prune visual tools + ##################################################################################### + export CUDA_BASE="/usr/local/cuda-12.4/" + rm -rf $CUDA_BASE/libnvvp $CUDA_BASE/nsightee_plugins $CUDA_BASE/nsight-compute-2024.1.0 $CUDA_BASE/nsight-systems-2023.4.4/ +} + function prune_126 { echo "Pruning CUDA 12.6" ##################################################################################### @@ -169,6 +216,8 @@ function install_128 { while test $# -gt 0 do case "$1" in + 12.4) install_124; prune_124 + ;; 12.6|12.6.*) install_126; prune_126 ;; 12.8|12.8.*) install_128; diff --git a/.ci/docker/common/install_cudnn.sh b/.ci/docker/common/install_cudnn.sh index 7ee5e73226cb6..fecdb448589e1 100644 --- a/.ci/docker/common/install_cudnn.sh +++ b/.ci/docker/common/install_cudnn.sh @@ -8,6 +8,8 @@ if [[ -n "${CUDNN_VERSION}" ]]; then CUDNN_NAME="cudnn-linux-x86_64-9.10.2.21_cuda12-archive" elif [[ ${CUDA_VERSION:0:4} == "12.6" ]]; then CUDNN_NAME="cudnn-linux-x86_64-9.10.2.21_cuda12-archive" + elif [[ ${CUDA_VERSION:0:4} == "12.4" ]]; then + CUDNN_NAME="cudnn-linux-x86_64-9.10.2.21_cuda12-archive" elif [[ ${CUDA_VERSION:0:2} == "11" ]]; then CUDNN_NAME="cudnn-linux-x86_64-9.1.0.70_cuda11-archive" else diff --git a/.ci/docker/common/install_cusparselt.sh b/.ci/docker/common/install_cusparselt.sh index ca29a94e58fc9..feacb49f39eb5 100644 --- a/.ci/docker/common/install_cusparselt.sh +++ b/.ci/docker/common/install_cusparselt.sh @@ -13,6 +13,14 @@ if [[ ${CUDA_VERSION:0:4} =~ ^12\.[5-9]$ ]]; then fi CUSPARSELT_NAME="libcusparse_lt-linux-${arch_path}-0.7.1.0-archive" curl --retry 3 -OLs https://developer.download.nvidia.com/compute/cusparselt/redist/libcusparse_lt/linux-${arch_path}/${CUSPARSELT_NAME}.tar.xz +elif [[ ${CUDA_VERSION:0:4} == "12.4" ]]; then + arch_path='sbsa' + export TARGETARCH=${TARGETARCH:-$(uname -m)} + if [ ${TARGETARCH} = 'amd64' ] || [ "${TARGETARCH}" = 'x86_64' ]; then + arch_path='x86_64' + fi + CUSPARSELT_NAME="libcusparse_lt-linux-${arch_path}-0.6.2.3-archive" + curl --retry 3 -OLs https://developer.download.nvidia.com/compute/cusparselt/redist/libcusparse_lt/linux-${arch_path}/${CUSPARSELT_NAME}.tar.xz else echo "Not sure which libcusparselt version to install for this ${CUDA_VERSION}" fi diff --git a/.github/workflows/docker-builds.yml b/.github/workflows/docker-builds.yml index c7f9b92889374..43843751eb8fd 100644 --- a/.github/workflows/docker-builds.yml +++ b/.github/workflows/docker-builds.yml @@ -57,6 +57,7 @@ jobs: pytorch-linux-jammy-cuda12.8-cudnn9-py3.12-gcc9-inductor-benchmarks, pytorch-linux-jammy-cuda12.8-cudnn9-py3.13-gcc9-inductor-benchmarks, pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9, + pytorch-linux-jammy-cuda12.4-cudnn9-py3-gcc11, pytorch-linux-jammy-py3.9-clang12, pytorch-linux-jammy-py3.11-clang12, pytorch-linux-jammy-py3.12-clang12, diff --git a/.github/workflows/periodic.yml b/.github/workflows/periodic.yml index 0882019d51151..7e70f4e21d0db 100644 --- a/.github/workflows/periodic.yml +++ b/.github/workflows/periodic.yml @@ -51,6 +51,37 @@ jobs: curr_branch: ${{ github.head_ref || github.ref_name }} curr_ref_type: ${{ github.ref_type }} + linux-jammy-cuda12_4-py3_10-gcc11-sm89-build: + name: linux-jammy-cuda12.4-py3.10-gcc11-sm89 + uses: ./.github/workflows/_linux-build.yml + needs: get-label-type + with: + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build-environment: linux-jammy-cuda12.4-py3.10-gcc11-sm89 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.4-cudnn9-py3-gcc11 + cuda-arch-list: 8.9 + test-matrix: | + { include: [ + { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, + { config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, + { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, + { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, + { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, + ]} + secrets: inherit + + linux-jammy-cuda12_4-py3_10-gcc11-sm89-test: + name: linux-jammy-cuda12.4-py3.10-gcc11-sm89 + uses: ./.github/workflows/_linux-test.yml + needs: + - linux-jammy-cuda12_4-py3_10-gcc11-sm89-build + - target-determination + with: + build-environment: linux-jammy-cuda12.4-py3.10-gcc11-sm89 + docker-image: ${{ needs.linux-jammy-cuda12_4-py3_10-gcc11-sm89-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cuda12_4-py3_10-gcc11-sm89-build.outputs.test-matrix }} + secrets: inherit + linux-jammy-cuda12_8-py3_10-gcc11-build: name: linux-jammy-cuda12.8-py3.10-gcc11 uses: ./.github/workflows/_linux-build.yml diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 7950c3672cf4f..08799fd6db708 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -35,7 +35,11 @@ from torch.export.pt2_archive._package import load_pt2 from torch.testing import FileCheck from torch.testing._internal import common_utils -from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8, SM80OrLater +from torch.testing._internal.common_cuda import ( + _get_torch_cuda_version, + PLATFORM_SUPPORTS_FP8, + SM80OrLater, +) from torch.testing._internal.common_device_type import ( _has_sufficient_memory, skipCUDAIf, @@ -188,6 +192,9 @@ def forward(self, x, y): # Skip embed_kernel_binary == True for now as it shows random # failure on CI @common_utils.parametrize("embed_kernel_binary", [False]) + @unittest.skipIf( + _get_torch_cuda_version() < (12, 6), "Test is only supported on CUDA 12.6+" + ) def test_simple_multi_arch(self, embed_kernel_binary): if self.device != GPU_TYPE: raise unittest.SkipTest("requires GPU_TYPE") diff --git a/test/inductor/test_aot_inductor_package.py b/test/inductor/test_aot_inductor_package.py index 2f2b92168c6ec..a607c4f33e7d3 100644 --- a/test/inductor/test_aot_inductor_package.py +++ b/test/inductor/test_aot_inductor_package.py @@ -21,6 +21,7 @@ from torch._inductor.utils import fresh_cache from torch.export import Dim from torch.export.pt2_archive._package import load_pt2, load_weights_to_pt2_contents +from torch.testing._internal.common_cuda import _get_torch_cuda_version from torch.testing._internal.common_utils import ( IS_FBCODE, skipIfRocm, @@ -249,6 +250,9 @@ def forward(self, x, y): self.check_model(Model(), example_inputs) @unittest.skipIf(IS_FBCODE, "cmake won't work in fbcode") + @unittest.skipIf( + _get_torch_cuda_version() < (12, 6), "Test is only supported on CUDA 12.6+" + ) @skipIfXpu # build system may be different def test_compile_after_package(self): self.check_package_cpp_only() @@ -294,6 +298,9 @@ def forward(self, x, y): actual = optimized(*example_inputs) self.assertTrue(torch.allclose(actual, expected)) + @unittest.skipIf( + _get_torch_cuda_version() < (12, 6), "Test is only supported on CUDA 12.6+" + ) @unittest.skipIf(IS_FBCODE, "cmake won't work in fbcode") @skipIfRocm # doesn't support multi-arch binary @skipIfXpu # doesn't support multi-arch binary @@ -338,6 +345,9 @@ def forward(self, x, y): actual = optimized(*example_inputs) self.assertTrue(torch.allclose(actual, expected)) + @unittest.skipIf( + _get_torch_cuda_version() < (12, 6), "Test is only supported on CUDA 12.6+" + ) @unittest.skipIf(IS_FBCODE, "cmake won't work in fbcode") @skipIfXpu # build system may be different def test_compile_after_package_static(self): @@ -396,6 +406,9 @@ def forward(self, x, y): with self.assertRaisesRegex(Exception, "Invalid AOTI model name"): self.cmake_compile(model, example_inputs, options, "") + @unittest.skipIf( + _get_torch_cuda_version() < (12, 6), "Test is only supported on CUDA 12.6+" + ) @unittest.skipIf(IS_FBCODE, "cmake won't work in fbcode") @skipIfRocm # doesn't support multi-arch binary @skipIfXpu # doesn't support multi-arch binary From 86d8af6a6cc648134289de89d393d0dce5b3a5f4 Mon Sep 17 00:00:00 2001 From: Ting Lu Date: Mon, 14 Jul 2025 13:11:10 +0000 Subject: [PATCH 002/457] Add sm_70 to windows 12.9 build (#158126) Please see: https://github.com/pytorch/pytorch/issues/157517 Volta architectures will be kept for 12.8/12.9 builds for release 2.8 (12.8 win build does not need change since already including sm70) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158126 Approved by: https://github.com/Skylion007, https://github.com/atalman --- .ci/pytorch/windows/cuda129.bat | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.ci/pytorch/windows/cuda129.bat b/.ci/pytorch/windows/cuda129.bat index 77ef14921aa63..b17e6113c63e2 100644 --- a/.ci/pytorch/windows/cuda129.bat +++ b/.ci/pytorch/windows/cuda129.bat @@ -37,10 +37,10 @@ IF "%CUDA_PATH_V129%"=="" ( ) IF "%BUILD_VISION%" == "" ( - set TORCH_CUDA_ARCH_LIST=7.5;8.0;8.6;9.0;10.0;12.0 + set TORCH_CUDA_ARCH_LIST=7.0;7.5;8.0;8.6;9.0;10.0;12.0 set TORCH_NVCC_FLAGS=-Xfatbin -compress-all ) ELSE ( - set NVCC_FLAGS=-D__CUDA_NO_HALF_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_80,code=compute_80 -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_90,code=compute_90 -gencode=arch=compute_100,code=compute_100 -gencode=arch=compute_120,code=compute_120 + set NVCC_FLAGS=-D__CUDA_NO_HALF_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_80,code=compute_80 -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_90,code=compute_90 -gencode=arch=compute_100,code=compute_100 -gencode=arch=compute_120,code=compute_120 ) set "CUDA_PATH=%CUDA_PATH_V129%" From 826f12b829070e3d5bfd050f001b61aaf78e5447 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Tue, 8 Jul 2025 12:24:24 -0700 Subject: [PATCH 003/457] [SymmMem] Avoid library mismatch in CMake search (#157836) Before, if NVSHMEM is installed at *BOTH* system location (e.g. `/usr/local`) and conda location (e.g. `/path/to/conda/lib/python3.10/site-packages/nvidia/nvshmem`, there can be a mismatch in where host lib and device lib are found: ``` -- NVSHMEM_HOME set to: '' -- NVSHMEM wheel installed at: '.conda/envs/pytorch-3.10/lib/python3.10/site-packages/nvidia/nvshmem' -- NVSHMEM_HOST_LIB: '/usr/local/lib/libnvshmem_host.so' -- NVSHMEM_DEVICE_LIB: '.conda/envs/pytorch-3.10/lib/python3.10/site-packages/nvidia/nvshmem/lib/libnvshmem_device.a' -- NVSHMEM_INCLUDE_DIR: '.conda/envs/pytorch-3.10/lib/python3.10/site-packages/nvidia/nvshmem/include' ``` The reason is that CMake prioritize name search over dir search. In the script below, CMake will search all locations for `libnvshmem_host.so` first, before it searches for `.so.3`. ``` find_library(NVSHMEM_HOST_LIB # In pip install case, the lib suffix is `.so.3` instead of `.so` NAMES nvshmem_host nvshmem_host.so.3 HINTS $ENV{NVSHMEM_HOME} ${NVSHMEM_PY_DIR} PATH_SUFFIXES lib lib64 cuda/lib cuda/lib64 lib/x64) ``` This PR adds the `NAMES_PER_DIR` flag, according to CMake's doc: > The NAMES_PER_DIR option tells this command to consider one directory at a time and search for all names in it. After this PR: ``` -- NVSHMEM_HOME set to: '' -- NVSHMEM wheel installed at: '.conda/envs/pytorch-3.10/lib/python3.10/site-packages/nvidia/nvshmem' -- NVSHMEM_HOST_LIB: '.conda/envs/pytorch-3.10/lib/python3.10/site-packages/nvidia/nvshmem/lib/libnvshmem_host.so.3' -- NVSHMEM_DEVICE_LIB: '.conda/envs/pytorch-3.10/lib/python3.10/site-packages/nvidia/nvshmem/lib/libnvshmem_device.a' -- NVSHMEM_INCLUDE_DIR: '.conda/envs/pytorch-3.10/lib/python3.10/site-packages/nvidia/nvshmem/include' ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/157836 Approved by: https://github.com/fegin, https://github.com/fduwjj ghstack dependencies: #157513, #157695 --- caffe2/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 6f6feac4b1ce3..7d0a98fbd33be 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -1001,7 +1001,7 @@ elseif(USE_CUDA) # 3. Let CMake find it in the default system paths, e.g. /usr/local. find_library(NVSHMEM_HOST_LIB # In pip install case, the lib suffix is `.so.3` instead of `.so` - NAMES nvshmem_host nvshmem_host.so.3 + NAMES nvshmem_host libnvshmem_host.so.3 NAMES_PER_DIR HINTS $ENV{NVSHMEM_HOME} ${NVSHMEM_PY_DIR} PATH_SUFFIXES lib lib64 cuda/lib cuda/lib64 lib/x64 DOC "The location of NVSHMEM host library.") From 59c3cac4547aafd2f718b7c64053098cc5886878 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Thu, 10 Jul 2025 14:00:00 -0300 Subject: [PATCH 004/457] Tag CPython test files with the commit or tag they were copied from. (#158038) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158038 Approved by: https://github.com/XuehaiPan, https://github.com/zou3519 ghstack dependencies: #157799, #157800, #157801, #157802, #156981 --- test/dynamo/cpython/3_13/list_tests.diff | 11 +- test/dynamo/cpython/3_13/list_tests.py | 3 + test/dynamo/cpython/3_13/mapping_tests.diff | 7 +- test/dynamo/cpython/3_13/mapping_tests.py | 3 + test/dynamo/cpython/3_13/seq_tests.diff | 9 +- test/dynamo/cpython/3_13/seq_tests.py | 3 + .../cpython/3_13/test_baseexception.diff | 13 +- .../dynamo/cpython/3_13/test_baseexception.py | 3 + test/dynamo/cpython/3_13/test_cmath.diff | 13 +- test/dynamo/cpython/3_13/test_cmath.py | 3 + test/dynamo/cpython/3_13/test_complex.diff | 13 +- test/dynamo/cpython/3_13/test_complex.py | 3 + test/dynamo/cpython/3_13/test_contextlib.diff | 35 ++-- test/dynamo/cpython/3_13/test_contextlib.py | 3 + test/dynamo/cpython/3_13/test_dict.diff | 52 +++++- test/dynamo/cpython/3_13/test_dict.py | 3 + .../3_13/test_exception_variations.diff | 13 +- .../cpython/3_13/test_exception_variations.py | 3 + test/dynamo/cpython/3_13/test_exceptions.diff | 152 ++++++++++++++++++ test/dynamo/cpython/3_13/test_exceptions.py | 3 + test/dynamo/cpython/3_13/test_float.diff | 36 +++-- test/dynamo/cpython/3_13/test_float.py | 3 + .../cpython/3_13/test_generator_stop.diff | 9 +- .../cpython/3_13/test_generator_stop.py | 3 + test/dynamo/cpython/3_13/test_generators.diff | 27 ++-- test/dynamo/cpython/3_13/test_generators.py | 3 + test/dynamo/cpython/3_13/test_int.diff | 17 +- test/dynamo/cpython/3_13/test_int.py | 3 + .../dynamo/cpython/3_13/test_int_literal.diff | 11 +- test/dynamo/cpython/3_13/test_int_literal.py | 3 + test/dynamo/cpython/3_13/test_iter.diff | 13 +- test/dynamo/cpython/3_13/test_iter.py | 3 + test/dynamo/cpython/3_13/test_list.diff | 13 +- test/dynamo/cpython/3_13/test_list.py | 3 + test/dynamo/cpython/3_13/test_math.diff | 35 ++-- test/dynamo/cpython/3_13/test_math.py | 3 + .../cpython/3_13/test_ordered_dict.diff | 29 ++-- test/dynamo/cpython/3_13/test_ordered_dict.py | 3 + test/dynamo/cpython/3_13/test_raise.diff | 21 +-- test/dynamo/cpython/3_13/test_raise.py | 3 + test/dynamo/cpython/3_13/test_set.diff | 81 +++++----- test/dynamo/cpython/3_13/test_set.py | 3 + test/dynamo/cpython/3_13/test_sort.diff | 17 +- test/dynamo/cpython/3_13/test_sort.py | 3 + test/dynamo/cpython/3_13/test_sys.diff | 52 +++--- test/dynamo/cpython/3_13/test_sys.py | 3 + test/dynamo/cpython/3_13/test_tuple.diff | 9 +- test/dynamo/cpython/3_13/test_tuple.py | 3 + test/dynamo/cpython/3_13/test_userdict.diff | 9 +- test/dynamo/cpython/3_13/test_userdict.py | 3 + test/dynamo/cpython/3_13/test_userlist.diff | 9 +- test/dynamo/cpython/3_13/test_userlist.py | 3 + 52 files changed, 568 insertions(+), 216 deletions(-) diff --git a/test/dynamo/cpython/3_13/list_tests.diff b/test/dynamo/cpython/3_13/list_tests.diff index 903895b384b53..7889011f375dd 100644 --- a/test/dynamo/cpython/3_13/list_tests.diff +++ b/test/dynamo/cpython/3_13/list_tests.diff @@ -1,14 +1,17 @@ diff --git a/test/dynamo/cpython/3_13/list_tests.py b/test/dynamo/cpython/3_13/list_tests.py -index dbc5ef4f9f2..239b75f74cc 100644 +index dbc5ef4f9f2..70e24036f74 100644 --- a/test/dynamo/cpython/3_13/list_tests.py +++ b/test/dynamo/cpython/3_13/list_tests.py -@@ -1,3 +1,53 @@ +@@ -1,3 +1,56 @@ +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + ++# Test copied from ++# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/list_tests.py ++ +import sys +import torch +import torch._dynamo.test_case @@ -56,7 +59,7 @@ index dbc5ef4f9f2..239b75f74cc 100644 """ Tests common to list and UserList.UserList """ -@@ -5,7 +55,7 @@ Tests common to list and UserList.UserList +@@ -5,7 +58,7 @@ Tests common to list and UserList.UserList import sys from functools import cmp_to_key @@ -65,7 +68,7 @@ index dbc5ef4f9f2..239b75f74cc 100644 from test.support import ALWAYS_EQ, NEVER_EQ, get_c_recursion_limit -@@ -119,10 +169,6 @@ class CommonTest(seq_tests.CommonTest): +@@ -119,10 +172,6 @@ class CommonTest(seq_tests.CommonTest): a[-1] = 9 self.assertEqual(a, self.type2test([5,6,7,8,9])) diff --git a/test/dynamo/cpython/3_13/list_tests.py b/test/dynamo/cpython/3_13/list_tests.py index 239b75f74cc44..70e24036f74db 100644 --- a/test/dynamo/cpython/3_13/list_tests.py +++ b/test/dynamo/cpython/3_13/list_tests.py @@ -4,6 +4,9 @@ # ruff: noqa # flake8: noqa +# Test copied from +# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/list_tests.py + import sys import torch import torch._dynamo.test_case diff --git a/test/dynamo/cpython/3_13/mapping_tests.diff b/test/dynamo/cpython/3_13/mapping_tests.diff index 03ae75513d664..009b53f31b55d 100644 --- a/test/dynamo/cpython/3_13/mapping_tests.diff +++ b/test/dynamo/cpython/3_13/mapping_tests.diff @@ -1,14 +1,17 @@ diff --git a/test/dynamo/cpython/3_13/mapping_tests.py b/test/dynamo/cpython/3_13/mapping_tests.py -index ed89a81a6ea..eed59a68e94 100644 +index ed89a81a6ea..10fc6e7e467 100644 --- a/test/dynamo/cpython/3_13/mapping_tests.py +++ b/test/dynamo/cpython/3_13/mapping_tests.py -@@ -1,10 +1,61 @@ +@@ -1,10 +1,64 @@ +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + ++# Test copied from ++# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/mapping_tests.py ++ +import sys +import torch +import torch._dynamo.test_case diff --git a/test/dynamo/cpython/3_13/mapping_tests.py b/test/dynamo/cpython/3_13/mapping_tests.py index eed59a68e9443..10fc6e7e46722 100644 --- a/test/dynamo/cpython/3_13/mapping_tests.py +++ b/test/dynamo/cpython/3_13/mapping_tests.py @@ -4,6 +4,9 @@ # ruff: noqa # flake8: noqa +# Test copied from +# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/mapping_tests.py + import sys import torch import torch._dynamo.test_case diff --git a/test/dynamo/cpython/3_13/seq_tests.diff b/test/dynamo/cpython/3_13/seq_tests.diff index 03c7021e4f96a..b87c26ece27cb 100644 --- a/test/dynamo/cpython/3_13/seq_tests.diff +++ b/test/dynamo/cpython/3_13/seq_tests.diff @@ -1,14 +1,17 @@ diff --git a/test/dynamo/cpython/3_13/seq_tests.py b/test/dynamo/cpython/3_13/seq_tests.py -index 719c9434a16..4325892276d 100644 +index 719c9434a16..2c502cda4f6 100644 --- a/test/dynamo/cpython/3_13/seq_tests.py +++ b/test/dynamo/cpython/3_13/seq_tests.py -@@ -1,3 +1,54 @@ +@@ -1,3 +1,57 @@ +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + ++# Test copied from ++# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/seq_tests.py ++ +import sys +import torch +import torch._dynamo.test_case @@ -57,7 +60,7 @@ index 719c9434a16..4325892276d 100644 """ Tests common to tuple, list and UserList.UserList """ -@@ -95,7 +146,7 @@ class LyingList(list): +@@ -95,7 +149,7 @@ class LyingList(list): def __iter__(self): yield 1 diff --git a/test/dynamo/cpython/3_13/seq_tests.py b/test/dynamo/cpython/3_13/seq_tests.py index 4325892276d4c..2c502cda4f617 100644 --- a/test/dynamo/cpython/3_13/seq_tests.py +++ b/test/dynamo/cpython/3_13/seq_tests.py @@ -4,6 +4,9 @@ # ruff: noqa # flake8: noqa +# Test copied from +# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/seq_tests.py + import sys import torch import torch._dynamo.test_case diff --git a/test/dynamo/cpython/3_13/test_baseexception.diff b/test/dynamo/cpython/3_13/test_baseexception.diff index b25d72d0f65dd..240e4e554d6ad 100644 --- a/test/dynamo/cpython/3_13/test_baseexception.diff +++ b/test/dynamo/cpython/3_13/test_baseexception.diff @@ -1,14 +1,17 @@ diff --git a/test/dynamo/cpython/3_13/test_baseexception.py b/test/dynamo/cpython/3_13/test_baseexception.py -index e599b02c17d..3dc102e3b8a 100644 +index e599b02c17d..750d7a84fb4 100644 --- a/test/dynamo/cpython/3_13/test_baseexception.py +++ b/test/dynamo/cpython/3_13/test_baseexception.py -@@ -1,10 +1,61 @@ +@@ -1,10 +1,64 @@ +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + ++# Test copied from ++# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_baseexception.py ++ +import sys +import torch +import torch._dynamo.test_case @@ -65,7 +68,7 @@ index e599b02c17d..3dc102e3b8a 100644 """Tests for anything relating to exception objects themselves (e.g., inheritance hierarchy)""" -@@ -78,9 +129,6 @@ class ExceptionClassTests(unittest.TestCase): +@@ -78,9 +132,6 @@ class ExceptionClassTests(unittest.TestCase): last_depth = depth finally: inheritance_tree.close() @@ -75,7 +78,7 @@ index e599b02c17d..3dc102e3b8a 100644 self.assertEqual(len(exc_set), 0, "%s not accounted for" % exc_set) interface_tests = ("length", "args", "str", "repr") -@@ -142,7 +190,7 @@ class ExceptionClassTests(unittest.TestCase): +@@ -142,7 +193,7 @@ class ExceptionClassTests(unittest.TestCase): gc.collect() @@ -84,7 +87,7 @@ index e599b02c17d..3dc102e3b8a 100644 """Test usage of exceptions""" -@@ -208,5 +256,5 @@ class UsageTests(unittest.TestCase): +@@ -208,5 +259,5 @@ class UsageTests(unittest.TestCase): self.catch_fails("spam") diff --git a/test/dynamo/cpython/3_13/test_baseexception.py b/test/dynamo/cpython/3_13/test_baseexception.py index 3dc102e3b8a2e..750d7a84fb450 100644 --- a/test/dynamo/cpython/3_13/test_baseexception.py +++ b/test/dynamo/cpython/3_13/test_baseexception.py @@ -4,6 +4,9 @@ # ruff: noqa # flake8: noqa +# Test copied from +# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_baseexception.py + import sys import torch import torch._dynamo.test_case diff --git a/test/dynamo/cpython/3_13/test_cmath.diff b/test/dynamo/cpython/3_13/test_cmath.diff index 7157e8c0498f6..c229add529029 100644 --- a/test/dynamo/cpython/3_13/test_cmath.diff +++ b/test/dynamo/cpython/3_13/test_cmath.diff @@ -1,14 +1,17 @@ diff --git a/test/dynamo/cpython/3_13/test_cmath.py b/test/dynamo/cpython/3_13/test_cmath.py -index a96a5780b31..883e87a0733 100644 +index a96a5780b31..37fb665d97d 100644 --- a/test/dynamo/cpython/3_13/test_cmath.py +++ b/test/dynamo/cpython/3_13/test_cmath.py -@@ -1,5 +1,55 @@ +@@ -1,5 +1,58 @@ +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + ++# Test copied from ++# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_cmath.py ++ +import sys +import torch +import torch._dynamo.test_case @@ -59,7 +62,7 @@ index a96a5780b31..883e87a0733 100644 from test.test_math import parse_testfile, test_file import test.test_math as test_math import unittest -@@ -50,7 +100,7 @@ complex_nans = [complex(x, y) for x, y in [ +@@ -50,7 +103,7 @@ complex_nans = [complex(x, y) for x, y in [ (INF, NAN) ]] @@ -68,7 +71,7 @@ index a96a5780b31..883e87a0733 100644 # list of all functions in cmath test_functions = [getattr(cmath, fname) for fname in [ 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atanh', -@@ -66,6 +116,39 @@ class CMathTests(ComplexesAreIdenticalMixin, unittest.TestCase): +@@ -66,6 +119,39 @@ class CMathTests(ComplexesAreIdenticalMixin, unittest.TestCase): def tearDown(self): self.test_values.close() @@ -108,7 +111,7 @@ index a96a5780b31..883e87a0733 100644 def rAssertAlmostEqual(self, a, b, rel_err = 2e-15, abs_err = 5e-323, msg=None): """Fail if the two floating-point numbers are not almost equal. -@@ -590,4 +673,4 @@ class IsCloseTests(test_math.IsCloseTests): +@@ -590,4 +676,4 @@ class IsCloseTests(test_math.IsCloseTests): if __name__ == "__main__": diff --git a/test/dynamo/cpython/3_13/test_cmath.py b/test/dynamo/cpython/3_13/test_cmath.py index 883e87a07337a..37fb665d97d26 100644 --- a/test/dynamo/cpython/3_13/test_cmath.py +++ b/test/dynamo/cpython/3_13/test_cmath.py @@ -4,6 +4,9 @@ # ruff: noqa # flake8: noqa +# Test copied from +# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_cmath.py + import sys import torch import torch._dynamo.test_case diff --git a/test/dynamo/cpython/3_13/test_complex.diff b/test/dynamo/cpython/3_13/test_complex.diff index a7867e47f2274..57a2d4315f21a 100644 --- a/test/dynamo/cpython/3_13/test_complex.diff +++ b/test/dynamo/cpython/3_13/test_complex.diff @@ -1,14 +1,17 @@ diff --git a/test/dynamo/cpython/3_13/test_complex.py b/test/dynamo/cpython/3_13/test_complex.py -index 6ff1a8ab29d..ab5bd3dab62 100644 +index 6ff1a8ab29d..cda348d2f37 100644 --- a/test/dynamo/cpython/3_13/test_complex.py +++ b/test/dynamo/cpython/3_13/test_complex.py -@@ -1,16 +1,143 @@ +@@ -1,16 +1,146 @@ +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + ++# Test copied from ++# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_complex.py ++ +import sys +import torch +import torch._dynamo.test_case @@ -151,7 +154,7 @@ index 6ff1a8ab29d..ab5bd3dab62 100644 INF = float("inf") NAN = float("nan") DBL_MAX = sys.float_info.max -@@ -45,7 +172,40 @@ class WithComplex: +@@ -45,7 +175,40 @@ class WithComplex: def __complex__(self): return self.value @@ -193,7 +196,7 @@ index 6ff1a8ab29d..ab5bd3dab62 100644 def assertAlmostEqual(self, a, b): if isinstance(a, complex): -@@ -74,6 +234,29 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase): +@@ -74,6 +237,29 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase): # check that relative difference < eps self.assertTrue(abs((x-y)/y) < eps) @@ -223,7 +226,7 @@ index 6ff1a8ab29d..ab5bd3dab62 100644 def assertClose(self, x, y, eps=1e-9): """Return true iff complexes x and y "are close".""" self.assertCloseAbs(x.real, y.real, eps) -@@ -855,4 +1038,4 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase): +@@ -855,4 +1041,4 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase): if __name__ == "__main__": diff --git a/test/dynamo/cpython/3_13/test_complex.py b/test/dynamo/cpython/3_13/test_complex.py index ab5bd3dab62b2..cda348d2f3776 100644 --- a/test/dynamo/cpython/3_13/test_complex.py +++ b/test/dynamo/cpython/3_13/test_complex.py @@ -4,6 +4,9 @@ # ruff: noqa # flake8: noqa +# Test copied from +# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_complex.py + import sys import torch import torch._dynamo.test_case diff --git a/test/dynamo/cpython/3_13/test_contextlib.diff b/test/dynamo/cpython/3_13/test_contextlib.diff index f3314f590c105..3850f66966817 100644 --- a/test/dynamo/cpython/3_13/test_contextlib.diff +++ b/test/dynamo/cpython/3_13/test_contextlib.diff @@ -1,14 +1,17 @@ diff --git a/test/dynamo/cpython/3_13/test_contextlib.py b/test/dynamo/cpython/3_13/test_contextlib.py -index cf651959803..6a17bc719eb 100644 +index cf651959803..51fd083b112 100644 --- a/test/dynamo/cpython/3_13/test_contextlib.py +++ b/test/dynamo/cpython/3_13/test_contextlib.py -@@ -1,3 +1,54 @@ +@@ -1,3 +1,57 @@ +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + ++# Test copied from ++# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_contextlib.py ++ +import sys +import torch +import torch._dynamo.test_case @@ -57,7 +60,7 @@ index cf651959803..6a17bc719eb 100644 """Unit tests for contextlib.py, and other context managers.""" import io -@@ -14,7 +65,7 @@ from test.support.testcase import ExceptionIsLikeMixin +@@ -14,7 +68,7 @@ from test.support.testcase import ExceptionIsLikeMixin import weakref @@ -66,7 +69,7 @@ index cf651959803..6a17bc719eb 100644 def test_enter(self): class DefaultEnter(AbstractContextManager): -@@ -67,7 +118,7 @@ class TestAbstractContextManager(unittest.TestCase): +@@ -67,7 +121,7 @@ class TestAbstractContextManager(unittest.TestCase): self.assertFalse(issubclass(NoExit, AbstractContextManager)) @@ -75,7 +78,7 @@ index cf651959803..6a17bc719eb 100644 def test_contextmanager_plain(self): state = [] -@@ -396,7 +447,7 @@ def woohoo(): +@@ -396,7 +450,7 @@ def woohoo(): self.assertEqual(depth, 0) @@ -84,7 +87,7 @@ index cf651959803..6a17bc719eb 100644 @support.requires_docstrings def test_instance_docs(self): -@@ -430,7 +481,7 @@ class ClosingTestCase(unittest.TestCase): +@@ -430,7 +484,7 @@ class ClosingTestCase(unittest.TestCase): self.assertEqual(state, [1]) @@ -93,7 +96,7 @@ index cf651959803..6a17bc719eb 100644 def test_nullcontext(self): class C: pass -@@ -439,7 +490,7 @@ class NullcontextTestCase(unittest.TestCase): +@@ -439,7 +493,7 @@ class NullcontextTestCase(unittest.TestCase): self.assertIs(c_in, c) @@ -102,7 +105,7 @@ index cf651959803..6a17bc719eb 100644 def testWithOpen(self): tfn = tempfile.mktemp() -@@ -457,7 +508,7 @@ class FileContextTestCase(unittest.TestCase): +@@ -457,7 +511,7 @@ class FileContextTestCase(unittest.TestCase): finally: os_helper.unlink(tfn) @@ -111,7 +114,7 @@ index cf651959803..6a17bc719eb 100644 def boilerPlate(self, lock, locked): self.assertFalse(locked()) -@@ -520,7 +571,7 @@ class mycontext(ContextDecorator): +@@ -520,7 +574,7 @@ class mycontext(ContextDecorator): return self.catch @@ -120,7 +123,7 @@ index cf651959803..6a17bc719eb 100644 @support.requires_docstrings def test_instance_docs(self): -@@ -680,7 +731,7 @@ class TestContextDecorator(unittest.TestCase): +@@ -680,7 +734,7 @@ class TestContextDecorator(unittest.TestCase): self.assertEqual(state, [1, 'something else', 999]) @@ -129,7 +132,7 @@ index cf651959803..6a17bc719eb 100644 exit_stack = None @support.requires_docstrings -@@ -1141,7 +1192,7 @@ class TestBaseExitStack: +@@ -1141,7 +1195,7 @@ class TestBaseExitStack: self.assertIs(exc.__cause__, exc.__context__) @@ -138,7 +141,7 @@ index cf651959803..6a17bc719eb 100644 exit_stack = ExitStack callback_error_internal_frames = [ ('__exit__', 'raise exc'), -@@ -1149,7 +1200,7 @@ class TestExitStack(TestBaseExitStack, unittest.TestCase): +@@ -1149,7 +1203,7 @@ class TestExitStack(TestBaseExitStack, unittest.TestCase): ] @@ -147,7 +150,7 @@ index cf651959803..6a17bc719eb 100644 redirect_stream = None orig_stream = None -@@ -1206,19 +1257,19 @@ class TestRedirectStream: +@@ -1206,19 +1260,19 @@ class TestRedirectStream: self.assertEqual(s, "Hello World!\n") @@ -170,7 +173,7 @@ index cf651959803..6a17bc719eb 100644 @support.requires_docstrings def test_instance_docs(self): -@@ -1315,7 +1366,7 @@ class TestSuppress(ExceptionIsLikeMixin, unittest.TestCase): +@@ -1315,7 +1369,7 @@ class TestSuppress(ExceptionIsLikeMixin, unittest.TestCase): ) @@ -179,7 +182,7 @@ index cf651959803..6a17bc719eb 100644 def make_relative_path(self, *parts): return os.path.join( os.path.dirname(os.path.realpath(__file__)), -@@ -1331,6 +1382,7 @@ class TestChdir(unittest.TestCase): +@@ -1331,6 +1385,7 @@ class TestChdir(unittest.TestCase): self.assertEqual(os.getcwd(), target) self.assertEqual(os.getcwd(), old_cwd) @@ -187,7 +190,7 @@ index cf651959803..6a17bc719eb 100644 def test_reentrant(self): old_cwd = os.getcwd() target1 = self.make_relative_path('data') -@@ -1363,4 +1415,4 @@ class TestChdir(unittest.TestCase): +@@ -1363,4 +1418,4 @@ class TestChdir(unittest.TestCase): if __name__ == "__main__": diff --git a/test/dynamo/cpython/3_13/test_contextlib.py b/test/dynamo/cpython/3_13/test_contextlib.py index 6a17bc719eb94..51fd083b11294 100644 --- a/test/dynamo/cpython/3_13/test_contextlib.py +++ b/test/dynamo/cpython/3_13/test_contextlib.py @@ -4,6 +4,9 @@ # ruff: noqa # flake8: noqa +# Test copied from +# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_contextlib.py + import sys import torch import torch._dynamo.test_case diff --git a/test/dynamo/cpython/3_13/test_dict.diff b/test/dynamo/cpython/3_13/test_dict.diff index 9589bcf797bd9..0c6beec66dad2 100644 --- a/test/dynamo/cpython/3_13/test_dict.diff +++ b/test/dynamo/cpython/3_13/test_dict.diff @@ -1,14 +1,17 @@ diff --git a/test/dynamo/cpython/3_13/test_dict.py b/test/dynamo/cpython/3_13/test_dict.py -index 4729132c5a5..14f829c1715 100644 +index 4c095464cbb..fcda6484ea6 100644 --- a/test/dynamo/cpython/3_13/test_dict.py +++ b/test/dynamo/cpython/3_13/test_dict.py -@@ -1,3 +1,57 @@ +@@ -1,3 +1,60 @@ +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + ++# Test copied from ++# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_dict.py ++ +import sys +import torch +import torch._dynamo.test_case @@ -60,7 +63,7 @@ index 4729132c5a5..14f829c1715 100644 import collections import collections.abc import gc -@@ -11,7 +65,7 @@ from test import support +@@ -11,7 +68,7 @@ from test import support from test.support import import_helper, get_c_recursion_limit @@ -69,15 +72,48 @@ index 4729132c5a5..14f829c1715 100644 def test_invalid_keyword_arguments(self): class Custom(dict): -@@ -265,6 +319,7 @@ class DictTest(unittest.TestCase): +@@ -265,39 +322,7 @@ class DictTest(unittest.TestCase): self.assertRaises(ValueError, {}.update, [(1, 2, 3)]) +- def test_update_shared_keys(self): +- class MyClass: pass +- +- # Subclass str to enable us to create an object during the +- # dict.update() call. +- class MyStr(str): +- def __hash__(self): +- return super().__hash__() +- +- def __eq__(self, other): +- # Create an object that shares the same PyDictKeysObject as +- # obj.__dict__. +- obj2 = MyClass() +- obj2.a = "a" +- obj2.b = "b" +- obj2.c = "c" +- return super().__eq__(other) +- +- obj = MyClass() +- obj.a = "a" +- obj.b = "b" +- +- x = {} +- x[MyStr("a")] = MyStr("a") +- +- # gh-132617: this previously raised "dict mutated during update" error +- x.update(obj.__dict__) +- +- self.assertEqual(x, { +- MyStr("a"): "a", +- "b": "b", +- }) +- + @unittest.skip("test hangs") def test_fromkeys(self): self.assertEqual(dict.fromkeys('abc'), {'a':None, 'b':None, 'c':None}) d = {} -@@ -477,7 +532,7 @@ class DictTest(unittest.TestCase): +@@ -510,7 +535,7 @@ class DictTest(unittest.TestCase): for copymode in -1, +1: # -1: b has same structure as a # +1: b is a.copy() @@ -86,7 +122,7 @@ index 4729132c5a5..14f829c1715 100644 size = 2**log2size a = {} b = {} -@@ -1006,18 +1061,6 @@ class DictTest(unittest.TestCase): +@@ -1039,18 +1064,6 @@ class DictTest(unittest.TestCase): pass self._tracked(MyDict()) @@ -105,7 +141,7 @@ index 4729132c5a5..14f829c1715 100644 def make_shared_key_dict(self, n): class C: pass -@@ -1622,7 +1665,7 @@ class DictTest(unittest.TestCase): +@@ -1655,7 +1668,7 @@ class DictTest(unittest.TestCase): self.assertGreaterEqual(eq_count, 1) @@ -114,7 +150,7 @@ index 4729132c5a5..14f829c1715 100644 # Test _PyDict_GetItem_KnownHash() @support.cpython_only -@@ -1666,4 +1709,4 @@ class SubclassMappingTests(mapping_tests.BasicTestMappingProtocol): +@@ -1699,4 +1712,4 @@ class SubclassMappingTests(mapping_tests.BasicTestMappingProtocol): if __name__ == "__main__": diff --git a/test/dynamo/cpython/3_13/test_dict.py b/test/dynamo/cpython/3_13/test_dict.py index 14f829c1715c1..fcda6484ea607 100644 --- a/test/dynamo/cpython/3_13/test_dict.py +++ b/test/dynamo/cpython/3_13/test_dict.py @@ -4,6 +4,9 @@ # ruff: noqa # flake8: noqa +# Test copied from +# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_dict.py + import sys import torch import torch._dynamo.test_case diff --git a/test/dynamo/cpython/3_13/test_exception_variations.diff b/test/dynamo/cpython/3_13/test_exception_variations.diff index 45424e087b5a1..52ae731d94934 100644 --- a/test/dynamo/cpython/3_13/test_exception_variations.diff +++ b/test/dynamo/cpython/3_13/test_exception_variations.diff @@ -1,14 +1,17 @@ diff --git a/test/dynamo/cpython/3_13/test_exception_variations.py b/test/dynamo/cpython/3_13/test_exception_variations.py -index a83a41d2975..be432089e3a 100644 +index a83a41d2975..c2d6eb3a41a 100644 --- a/test/dynamo/cpython/3_13/test_exception_variations.py +++ b/test/dynamo/cpython/3_13/test_exception_variations.py -@@ -1,7 +1,59 @@ +@@ -1,7 +1,62 @@ +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] +# ruff: noqa +# flake8: noqa + ++# Test copied from ++# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_exception_variations.py ++ +import sys +import torch +import torch._dynamo.test_case @@ -53,17 +56,17 @@ index a83a41d2975..be432089e3a 100644 +# Add the custom finder to sys.meta_path +sys.meta_path.insert(0, RedirectImportFinder()) + -+ -+# ======= END DYNAMO PATCH ======= -class ExceptTestCases(unittest.TestCase): ++# ======= END DYNAMO PATCH ======= ++ +import unittest + +class ExceptTestCases(__TestCase): def test_try_except_else_finally(self): hit_except = False hit_else = False -@@ -294,282 +346,5 @@ class ExceptTestCases(unittest.TestCase): +@@ -294,282 +349,5 @@ class ExceptTestCases(unittest.TestCase): self.assertTrue(hit_except) diff --git a/test/dynamo/cpython/3_13/test_exception_variations.py b/test/dynamo/cpython/3_13/test_exception_variations.py index be432089e3a33..c2d6eb3a41afc 100644 --- a/test/dynamo/cpython/3_13/test_exception_variations.py +++ b/test/dynamo/cpython/3_13/test_exception_variations.py @@ -4,6 +4,9 @@ # ruff: noqa # flake8: noqa +# Test copied from +# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_exception_variations.py + import sys import torch import torch._dynamo.test_case diff --git a/test/dynamo/cpython/3_13/test_exceptions.diff b/test/dynamo/cpython/3_13/test_exceptions.diff index e69de29bb2d1d..6dcc9c858a9f8 100644 --- a/test/dynamo/cpython/3_13/test_exceptions.diff +++ b/test/dynamo/cpython/3_13/test_exceptions.diff @@ -0,0 +1,152 @@ +diff --git a/test/dynamo/cpython/3_13/test_exceptions.py b/test/dynamo/cpython/3_13/test_exceptions.py +index c91f6662948..0ded70db3c7 100644 +--- a/test/dynamo/cpython/3_13/test_exceptions.py ++++ b/test/dynamo/cpython/3_13/test_exceptions.py +@@ -1,3 +1,59 @@ ++# ======= BEGIN Dynamo patch ======= ++# Owner(s): ["module: dynamo"] ++ ++# ruff: noqa ++# flake8: noqa ++ ++# Test copied from ++# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_exceptions.py ++ ++import sys ++import torch ++import torch._dynamo.test_case ++import unittest ++from torch._dynamo.test_case import CPythonTestCase ++from torch.testing._internal.common_utils import ( ++ run_tests, ++ xfailIfTorchDynamo, ++) ++ ++__TestCase = CPythonTestCase ++ ++# redirect import statements ++import sys ++import importlib.abc ++ ++redirect_imports = ( ++ "test.mapping_tests", ++ "test.typinganndata", ++ "test.test_grammar", ++ "test.test_math", ++ "test.test_iter", ++ "test.typinganndata.ann_module", ++) ++ ++class RedirectImportFinder(importlib.abc.MetaPathFinder): ++ def find_spec(self, fullname, path, target=None): ++ # Check if the import is the problematic one ++ if fullname in redirect_imports: ++ try: ++ # Attempt to import the standalone module ++ name = fullname.removeprefix("test.") ++ r = importlib.import_module(name) ++ # Redirect the module in sys.modules ++ sys.modules[fullname] = r ++ # Return a module spec from the found module ++ return importlib.util.find_spec(name) ++ except ImportError: ++ return None ++ return None ++ ++# Add the custom finder to sys.meta_path ++sys.meta_path.insert(0, RedirectImportFinder()) ++ ++ ++# ======= END DYNAMO PATCH ======= ++ + # Python test set -- part 5, built-in exceptions + + import copy +@@ -45,7 +101,7 @@ class BrokenStrException(Exception): + # XXX This is not really enough, each *operation* should be tested! + + +-class ExceptionTests(unittest.TestCase): ++class ExceptionTests(__TestCase): + + def raise_catch(self, exc, excname): + with self.subTest(exc=exc, excname=excname): +@@ -1844,7 +1900,7 @@ class ExceptionTests(unittest.TestCase): + self.assertIn(b'MemoryError', err) + + +-class NameErrorTests(unittest.TestCase): ++class NameErrorTests(__TestCase): + def test_name_error_has_name(self): + try: + bluch +@@ -1894,7 +1950,7 @@ class NameErrorTests(unittest.TestCase): + # Note: name suggestion tests live in `test_traceback`. + + +-class AttributeErrorTests(unittest.TestCase): ++class AttributeErrorTests(__TestCase): + def test_attributes(self): + # Setting 'attr' should not be a problem. + exc = AttributeError('Ouch!') +@@ -1937,7 +1993,7 @@ class AttributeErrorTests(unittest.TestCase): + # Note: name suggestion tests live in `test_traceback`. + + +-class ImportErrorTests(unittest.TestCase): ++class ImportErrorTests(__TestCase): + + def test_attributes(self): + # Setting 'name' and 'path' should not be a problem. +@@ -2024,7 +2080,7 @@ def run_script(source): + _rc, _out, err = script_helper.assert_python_failure('-Wd', '-X', 'utf8', TESTFN) + return err.decode('utf-8').splitlines() + +-class AssertionErrorTests(unittest.TestCase): ++class AssertionErrorTests(__TestCase): + def tearDown(self): + unlink(TESTFN) + +@@ -2159,7 +2215,7 @@ class AssertionErrorTests(unittest.TestCase): + + + @support.force_not_colorized_test_class +-class SyntaxErrorTests(unittest.TestCase): ++class SyntaxErrorTests(__TestCase): + maxDiff = None + + @force_not_colorized +@@ -2290,6 +2346,7 @@ class SyntaxErrorTests(unittest.TestCase): + err = run_script(b"\x89") + self.assertIn("SyntaxError: Non-UTF-8 code starting with '\\x89' in file", err[-1]) + ++ + def test_string_source(self): + def try_compile(source): + with self.assertRaises(SyntaxError) as cm: +@@ -2405,7 +2462,7 @@ class SyntaxErrorTests(unittest.TestCase): + self.assertRaises(TypeError, SyntaxError, "bad bad", args) + + +-class TestInvalidExceptionMatcher(unittest.TestCase): ++class TestInvalidExceptionMatcher(__TestCase): + def test_except_star_invalid_exception_type(self): + with self.assertRaises(TypeError): + try: +@@ -2420,7 +2477,7 @@ class TestInvalidExceptionMatcher(unittest.TestCase): + pass + + +-class PEP626Tests(unittest.TestCase): ++class PEP626Tests(__TestCase): + + def lineno_after_raise(self, f, *expected): + try: +@@ -2529,5 +2586,5 @@ class PEP626Tests(unittest.TestCase): + 1/0 + self.lineno_after_raise(after_with, 1, 1) + +-if __name__ == '__main__': +- unittest.main() ++if __name__ == "__main__": ++ run_tests() diff --git a/test/dynamo/cpython/3_13/test_exceptions.py b/test/dynamo/cpython/3_13/test_exceptions.py index e6a9a2676bc00..0ded70db3c781 100644 --- a/test/dynamo/cpython/3_13/test_exceptions.py +++ b/test/dynamo/cpython/3_13/test_exceptions.py @@ -4,6 +4,9 @@ # ruff: noqa # flake8: noqa +# Test copied from +# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_exceptions.py + import sys import torch import torch._dynamo.test_case diff --git a/test/dynamo/cpython/3_13/test_float.diff b/test/dynamo/cpython/3_13/test_float.diff index 6b8586b1c6639..73cd65364fbc9 100644 --- a/test/dynamo/cpython/3_13/test_float.diff +++ b/test/dynamo/cpython/3_13/test_float.diff @@ -1,14 +1,17 @@ diff --git a/test/dynamo/cpython/3_13/test_float.py b/test/dynamo/cpython/3_13/test_float.py -index 97f951f1299..ce2c46777e0 100644 +index 87af79eb446..9313a1a63d7 100644 --- a/test/dynamo/cpython/3_13/test_float.py +++ b/test/dynamo/cpython/3_13/test_float.py -@@ -1,3 +1,54 @@ +@@ -1,3 +1,57 @@ +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + ++# Test copied from ++# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_float.py ++ +import sys +import torch +import torch._dynamo.test_case @@ -57,7 +60,7 @@ index 97f951f1299..ce2c46777e0 100644 import fractions import operator import os -@@ -8,11 +59,84 @@ import time +@@ -8,11 +62,84 @@ import time import unittest from test import support @@ -147,7 +150,7 @@ index 97f951f1299..ce2c46777e0 100644 from math import isinf, isnan, copysign, ldexp import math -@@ -35,7 +159,7 @@ class FloatSubclass(float): +@@ -35,7 +162,7 @@ class FloatSubclass(float): class OtherFloatSubclass(float): pass @@ -156,7 +159,7 @@ index 97f951f1299..ce2c46777e0 100644 def test_float(self): self.assertEqual(float(3.14), 3.14) -@@ -620,7 +744,7 @@ class GeneralFloatCases(unittest.TestCase): +@@ -620,7 +747,7 @@ class GeneralFloatCases(unittest.TestCase): @unittest.skipUnless(hasattr(float, "__getformat__"), "requires __getformat__") @@ -165,7 +168,7 @@ index 97f951f1299..ce2c46777e0 100644 def test_getformat(self): self.assertIn(float.__getformat__('double'), ['unknown', 'IEEE, big-endian', 'IEEE, little-endian']) -@@ -645,7 +769,7 @@ LE_FLOAT_NAN = bytes(reversed(BE_FLOAT_NAN)) +@@ -645,7 +772,7 @@ LE_FLOAT_NAN = bytes(reversed(BE_FLOAT_NAN)) # is accident (today). # let's also try to guarantee that -0.0 and 0.0 don't get confused. @@ -174,7 +177,7 @@ index 97f951f1299..ce2c46777e0 100644 @support.requires_IEEE_754 def test_double_specials_do_unpack(self): -@@ -670,7 +794,7 @@ class IEEEFormatTestCase(unittest.TestCase): +@@ -670,7 +797,7 @@ class IEEEFormatTestCase(unittest.TestCase): self.assertEqual(struct.pack("=": "issuperset", -@@ -1334,22 +1402,22 @@ class TestSubsets: +@@ -1334,22 +1405,22 @@ class TestSubsets: result = eval("x" + case + "y", locals()) self.assertEqual(result, expected) # Test the "friendly" method-name spelling, if one exists. @@ -321,7 +324,7 @@ index d9102eb98a5..0b8e99a04c4 100644 left = set() right = set() name = "both empty" -@@ -1357,7 +1425,7 @@ class TestSubsetEqualEmpty(TestSubsets, unittest.TestCase): +@@ -1357,7 +1428,7 @@ class TestSubsetEqualEmpty(TestSubsets, unittest.TestCase): #------------------------------------------------------------------------------ @@ -330,7 +333,7 @@ index d9102eb98a5..0b8e99a04c4 100644 left = set([1, 2]) right = set([1, 2]) name = "equal pair" -@@ -1365,7 +1433,7 @@ class TestSubsetEqualNonEmpty(TestSubsets, unittest.TestCase): +@@ -1365,7 +1436,7 @@ class TestSubsetEqualNonEmpty(TestSubsets, unittest.TestCase): #------------------------------------------------------------------------------ @@ -339,7 +342,7 @@ index d9102eb98a5..0b8e99a04c4 100644 left = set() right = set([1, 2]) name = "one empty, one non-empty" -@@ -1373,7 +1441,7 @@ class TestSubsetEmptyNonEmpty(TestSubsets, unittest.TestCase): +@@ -1373,7 +1444,7 @@ class TestSubsetEmptyNonEmpty(TestSubsets, unittest.TestCase): #------------------------------------------------------------------------------ @@ -348,7 +351,7 @@ index d9102eb98a5..0b8e99a04c4 100644 left = set([1]) right = set([1, 2]) name = "one a non-empty proper subset of other" -@@ -1381,7 +1449,7 @@ class TestSubsetPartial(TestSubsets, unittest.TestCase): +@@ -1381,7 +1452,7 @@ class TestSubsetPartial(TestSubsets, unittest.TestCase): #------------------------------------------------------------------------------ @@ -357,7 +360,7 @@ index d9102eb98a5..0b8e99a04c4 100644 left = set([1]) right = set([2]) name = "neither empty, neither contains" -@@ -1389,7 +1457,7 @@ class TestSubsetNonOverlap(TestSubsets, unittest.TestCase): +@@ -1389,7 +1460,7 @@ class TestSubsetNonOverlap(TestSubsets, unittest.TestCase): #============================================================================== @@ -366,7 +369,7 @@ index d9102eb98a5..0b8e99a04c4 100644 def test_eq_ne(self): # Unlike the others, this is testing that == and != *are* allowed. -@@ -1505,47 +1573,52 @@ class TestOnlySetsInBinaryOps: +@@ -1505,47 +1576,52 @@ class TestOnlySetsInBinaryOps: #------------------------------------------------------------------------------ @@ -425,7 +428,7 @@ index d9102eb98a5..0b8e99a04c4 100644 def setUp(self): def gen(): for i in range(0, 10, 2): -@@ -1553,10 +1626,11 @@ class TestOnlySetsGenerator(TestOnlySetsInBinaryOps, unittest.TestCase): +@@ -1553,10 +1629,11 @@ class TestOnlySetsGenerator(TestOnlySetsInBinaryOps, unittest.TestCase): self.set = set((1, 2, 3)) self.other = gen() self.otherIsIterable = True @@ -438,7 +441,7 @@ index d9102eb98a5..0b8e99a04c4 100644 def test_copy(self): dup = self.set.copy() -@@ -1577,40 +1651,46 @@ class TestCopying: +@@ -1577,40 +1654,46 @@ class TestCopying: #------------------------------------------------------------------------------ @@ -491,7 +494,7 @@ index d9102eb98a5..0b8e99a04c4 100644 def test_binopsVsSubsets(self): a, b = self.a, self.b -@@ -1727,7 +1807,7 @@ def L(seqn): +@@ -1727,7 +1810,7 @@ def L(seqn): 'Test multiple tiers of iterators' return chain(map(lambda x:x, R(Ig(G(seqn))))) @@ -500,7 +503,7 @@ index d9102eb98a5..0b8e99a04c4 100644 def test_constructor(self): for cons in (set, frozenset): -@@ -1785,7 +1865,7 @@ class bad_dict_clear: +@@ -1785,7 +1868,7 @@ class bad_dict_clear: def __hash__(self): return 0 @@ -509,7 +512,7 @@ index d9102eb98a5..0b8e99a04c4 100644 def test_8420_set_merge(self): # This used to segfault global be_bad, set2, dict2 -@@ -1826,7 +1906,7 @@ class TestWeirdBugs(unittest.TestCase): +@@ -1826,7 +1909,7 @@ class TestWeirdBugs(unittest.TestCase): s.update(other) @@ -518,7 +521,7 @@ index d9102eb98a5..0b8e99a04c4 100644 """Regression test for bpo-46615""" constructor1 = None -@@ -1862,7 +1942,7 @@ class TestOperationsMutating: +@@ -1862,7 +1945,7 @@ class TestOperationsMutating: self.assertIn("changed size during iteration", str(e)) @@ -527,7 +530,7 @@ index d9102eb98a5..0b8e99a04c4 100644 def test_eq_with_mutation(self): self.check_set_op_does_not_crash(lambda a, b: a == b) -@@ -1933,24 +2013,24 @@ class TestBinaryOpsMutating(TestOperationsMutating): +@@ -1933,24 +2016,24 @@ class TestBinaryOpsMutating(TestOperationsMutating): self.check_set_op_does_not_crash(f3) @@ -557,7 +560,7 @@ index d9102eb98a5..0b8e99a04c4 100644 def test_issubset_with_mutation(self): self.check_set_op_does_not_crash(set.issubset) -@@ -1986,27 +2066,27 @@ class TestMethodsMutating(TestOperationsMutating): +@@ -1986,27 +2069,27 @@ class TestMethodsMutating(TestOperationsMutating): self.check_set_op_does_not_crash(set.update) @@ -591,7 +594,7 @@ index d9102eb98a5..0b8e99a04c4 100644 constructor1 = set constructor2 = list -@@ -2068,7 +2148,7 @@ def faces(G): +@@ -2068,7 +2151,7 @@ def faces(G): return f @@ -600,7 +603,7 @@ index d9102eb98a5..0b8e99a04c4 100644 def test_cube(self): -@@ -2118,4 +2198,4 @@ class TestGraphs(unittest.TestCase): +@@ -2118,4 +2201,4 @@ class TestGraphs(unittest.TestCase): #============================================================================== if __name__ == "__main__": diff --git a/test/dynamo/cpython/3_13/test_set.py b/test/dynamo/cpython/3_13/test_set.py index 0b8e99a04c452..3543d60751e3c 100644 --- a/test/dynamo/cpython/3_13/test_set.py +++ b/test/dynamo/cpython/3_13/test_set.py @@ -4,6 +4,9 @@ # ruff: noqa # flake8: noqa +# Test copied from +# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_set.py + import sys import torch import torch._dynamo.test_case diff --git a/test/dynamo/cpython/3_13/test_sort.diff b/test/dynamo/cpython/3_13/test_sort.diff index 78fde5ef19a1c..9049f28532518 100644 --- a/test/dynamo/cpython/3_13/test_sort.diff +++ b/test/dynamo/cpython/3_13/test_sort.diff @@ -1,14 +1,17 @@ diff --git a/test/dynamo/cpython/3_13/test_sort.py b/test/dynamo/cpython/3_13/test_sort.py -index 2a7cfb7affa..d661ae544b9 100644 +index 2a7cfb7affa..58b9b796362 100644 --- a/test/dynamo/cpython/3_13/test_sort.py +++ b/test/dynamo/cpython/3_13/test_sort.py -@@ -1,3 +1,54 @@ +@@ -1,3 +1,57 @@ +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + ++# Test copied from ++# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_sort.py ++ +import sys +import torch +import torch._dynamo.test_case @@ -57,7 +60,7 @@ index 2a7cfb7affa..d661ae544b9 100644 from test import support import random import unittest -@@ -39,7 +90,7 @@ def check(tag, expected, raw, compare=None): +@@ -39,7 +93,7 @@ def check(tag, expected, raw, compare=None): nerrors += 1 return @@ -66,7 +69,7 @@ index 2a7cfb7affa..d661ae544b9 100644 def testStressfully(self): # Try a variety of sizes at and around powers of 2, and at powers of 10. sizes = [0] -@@ -151,7 +202,7 @@ class TestBase(unittest.TestCase): +@@ -151,7 +205,7 @@ class TestBase(unittest.TestCase): self.assertEqual(forced, native) #============================================================================== @@ -75,7 +78,7 @@ index 2a7cfb7affa..d661ae544b9 100644 def test_bug453523(self): # bug 453523 -- list.sort() crasher. -@@ -188,7 +239,7 @@ class TestBugs(unittest.TestCase): +@@ -188,7 +242,7 @@ class TestBugs(unittest.TestCase): #============================================================================== @@ -84,7 +87,7 @@ index 2a7cfb7affa..d661ae544b9 100644 def test_decorated(self): data = 'The quick Brown fox Jumped over The lazy Dog'.split() -@@ -309,7 +360,7 @@ def check_against_PyObject_RichCompareBool(self, L): +@@ -309,7 +363,7 @@ def check_against_PyObject_RichCompareBool(self, L): self.assertIs(opt, ref) #note: not assertEqual! We want to ensure *identical* behavior. @@ -93,7 +96,7 @@ index 2a7cfb7affa..d661ae544b9 100644 def test_safe_object_compare(self): heterogeneous_lists = [[0, 'foo'], [0.0, 'foo'], -@@ -408,4 +459,4 @@ class TestOptimizedCompares(unittest.TestCase): +@@ -408,4 +462,4 @@ class TestOptimizedCompares(unittest.TestCase): #============================================================================== if __name__ == "__main__": diff --git a/test/dynamo/cpython/3_13/test_sort.py b/test/dynamo/cpython/3_13/test_sort.py index d661ae544b992..58b9b79636227 100644 --- a/test/dynamo/cpython/3_13/test_sort.py +++ b/test/dynamo/cpython/3_13/test_sort.py @@ -4,6 +4,9 @@ # ruff: noqa # flake8: noqa +# Test copied from +# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_sort.py + import sys import torch import torch._dynamo.test_case diff --git a/test/dynamo/cpython/3_13/test_sys.diff b/test/dynamo/cpython/3_13/test_sys.diff index 7fd0241560565..1c0cc65b36637 100644 --- a/test/dynamo/cpython/3_13/test_sys.diff +++ b/test/dynamo/cpython/3_13/test_sys.diff @@ -1,14 +1,17 @@ diff --git a/test/dynamo/cpython/3_13/test_sys.py b/test/dynamo/cpython/3_13/test_sys.py -index 72d51361e0b..0b4c6882e62 100644 +index 6b37094ed5f..c5e96a6a3dd 100644 --- a/test/dynamo/cpython/3_13/test_sys.py +++ b/test/dynamo/cpython/3_13/test_sys.py -@@ -1,3 +1,55 @@ +@@ -1,3 +1,58 @@ +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + ++# Test copied from ++# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_sys.py ++ +import sys +import torch +import torch._dynamo.test_case @@ -58,7 +61,7 @@ index 72d51361e0b..0b4c6882e62 100644 import builtins import codecs import _datetime -@@ -35,7 +87,7 @@ def requires_subinterpreters(meth): +@@ -35,7 +90,7 @@ def requires_subinterpreters(meth): DICT_KEY_STRUCT_FORMAT = 'n2BI2n' @@ -67,7 +70,7 @@ index 72d51361e0b..0b4c6882e62 100644 def test_original_displayhook(self): dh = sys.__displayhook__ -@@ -81,19 +133,8 @@ class DisplayHookTest(unittest.TestCase): +@@ -81,19 +136,8 @@ class DisplayHookTest(unittest.TestCase): code = compile("42", "", "single") self.assertRaises(ValueError, eval, code) @@ -77,18 +80,18 @@ index 72d51361e0b..0b4c6882e62 100644 - sys.stdout = io.StringIO() - support.gc_collect() - return 'foo' -- + - with support.swap_attr(sys, 'stdout', None): - sys.stdout = io.StringIO() # the only reference - sys.displayhook(X()) # should not crash - +- - -class ActiveExceptionTests(unittest.TestCase): +class ActiveExceptionTests(__TestCase): def test_exc_info_no_exception(self): self.assertEqual(sys.exc_info(), (None, None, None)) -@@ -157,7 +198,7 @@ class ActiveExceptionTests(unittest.TestCase): +@@ -157,7 +201,7 @@ class ActiveExceptionTests(unittest.TestCase): self.assertIs(exc, e) @@ -97,7 +100,7 @@ index 72d51361e0b..0b4c6882e62 100644 @force_not_colorized def test_original_excepthook(self): -@@ -200,7 +241,7 @@ class ExceptHookTest(unittest.TestCase): +@@ -200,7 +244,7 @@ class ExceptHookTest(unittest.TestCase): # Python/pythonrun.c::PyErr_PrintEx() is tricky. @@ -106,7 +109,7 @@ index 72d51361e0b..0b4c6882e62 100644 def tearDown(self): test.support.reap_children() -@@ -500,6 +541,7 @@ class SysModuleTest(unittest.TestCase): +@@ -500,6 +544,7 @@ class SysModuleTest(unittest.TestCase): is sys._getframe().f_code ) @@ -114,16 +117,21 @@ index 72d51361e0b..0b4c6882e62 100644 def test_getframemodulename(self): # Default depth gets ourselves self.assertEqual(__name__, sys._getframemodulename()) -@@ -808,7 +850,7 @@ class SysModuleTest(unittest.TestCase): - self.assertRaises(TypeError, sys.intern, S("abc")) - if has_is_interned: - self.assertIs(sys._is_interned(S("abc")), False) -- -+ - @support.cpython_only - @requires_subinterpreters - def test_subinterp_intern_dynamically_allocated(self): -@@ -1359,7 +1401,7 @@ class SysModuleTest(unittest.TestCase): +@@ -894,7 +939,12 @@ class SysModuleTest(unittest.TestCase): + def assert_raise_on_new_sys_type(self, sys_attr): + # Users are intentionally prevented from creating new instances of + # sys.flags, sys.version_info, and sys.getwindowsversion. +- support.check_disallow_instantiation(self, type(sys_attr), sys_attr) ++ arg = sys_attr ++ attr_type = type(sys_attr) ++ with self.assertRaises(TypeError): ++ attr_type(arg) ++ with self.assertRaises(TypeError): ++ attr_type.__new__(attr_type, arg) + + def test_sys_flags_no_instantiation(self): + self.assert_raise_on_new_sys_type(sys.flags) +@@ -1354,7 +1404,7 @@ class SysModuleTest(unittest.TestCase): @test.support.cpython_only @@ -132,7 +140,7 @@ index 72d51361e0b..0b4c6882e62 100644 def test_original_unraisablehook(self): _testcapi = import_helper.import_module('_testcapi') from _testcapi import err_writeunraisable, err_formatunraisable -@@ -1516,7 +1558,7 @@ class UnraisableHookTest(unittest.TestCase): +@@ -1511,7 +1561,7 @@ class UnraisableHookTest(unittest.TestCase): @test.support.cpython_only @@ -141,7 +149,7 @@ index 72d51361e0b..0b4c6882e62 100644 def setUp(self): self.P = struct.calcsize('P') -@@ -1524,6 +1566,7 @@ class SizeofTest(unittest.TestCase): +@@ -1519,6 +1569,7 @@ class SizeofTest(unittest.TestCase): _testinternalcapi = import_helper.import_module("_testinternalcapi") self.gc_headsize = _testinternalcapi.SIZEOF_PYGC_HEAD self.managed_pre_header_size = _testinternalcapi.SIZEOF_MANAGED_PRE_HEADER @@ -149,7 +157,7 @@ index 72d51361e0b..0b4c6882e62 100644 check_sizeof = test.support.check_sizeof -@@ -1960,4 +2003,4 @@ class SizeofTest(unittest.TestCase): +@@ -1955,4 +2006,4 @@ class SizeofTest(unittest.TestCase): self.assertEqual(err, b"") if __name__ == "__main__": diff --git a/test/dynamo/cpython/3_13/test_sys.py b/test/dynamo/cpython/3_13/test_sys.py index f2d782127a485..c5e96a6a3dddf 100644 --- a/test/dynamo/cpython/3_13/test_sys.py +++ b/test/dynamo/cpython/3_13/test_sys.py @@ -4,6 +4,9 @@ # ruff: noqa # flake8: noqa +# Test copied from +# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_sys.py + import sys import torch import torch._dynamo.test_case diff --git a/test/dynamo/cpython/3_13/test_tuple.diff b/test/dynamo/cpython/3_13/test_tuple.diff index 46d4bb32d9efd..6e792b6c5450f 100644 --- a/test/dynamo/cpython/3_13/test_tuple.diff +++ b/test/dynamo/cpython/3_13/test_tuple.diff @@ -1,8 +1,8 @@ diff --git a/test/dynamo/cpython/3_13/test_tuple.py b/test/dynamo/cpython/3_13/test_tuple.py -index 9ce80c5e8ea..e52c0cbc140 100644 +index 9ce80c5e8ea..c6eab3ff1e9 100644 --- a/test/dynamo/cpython/3_13/test_tuple.py +++ b/test/dynamo/cpython/3_13/test_tuple.py -@@ -1,4 +1,55 @@ +@@ -1,4 +1,58 @@ -from test import support, seq_tests +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] @@ -10,6 +10,9 @@ index 9ce80c5e8ea..e52c0cbc140 100644 +# ruff: noqa +# flake8: noqa + ++# Test copied from ++# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_tuple.py ++ +import sys +import torch +import torch._dynamo.test_case @@ -59,7 +62,7 @@ index 9ce80c5e8ea..e52c0cbc140 100644 import unittest import gc -@@ -510,4 +561,4 @@ class TupleTest(seq_tests.CommonTest): +@@ -510,4 +564,4 @@ class TupleTest(seq_tests.CommonTest): # pileup 262,143 mean 8.0 coll 262,143 z +92683.6 if __name__ == "__main__": diff --git a/test/dynamo/cpython/3_13/test_tuple.py b/test/dynamo/cpython/3_13/test_tuple.py index e52c0cbc14030..c6eab3ff1e92c 100644 --- a/test/dynamo/cpython/3_13/test_tuple.py +++ b/test/dynamo/cpython/3_13/test_tuple.py @@ -4,6 +4,9 @@ # ruff: noqa # flake8: noqa +# Test copied from +# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_tuple.py + import sys import torch import torch._dynamo.test_case diff --git a/test/dynamo/cpython/3_13/test_userdict.diff b/test/dynamo/cpython/3_13/test_userdict.diff index 1c01574892067..8b8101ae9091d 100644 --- a/test/dynamo/cpython/3_13/test_userdict.diff +++ b/test/dynamo/cpython/3_13/test_userdict.diff @@ -1,14 +1,17 @@ diff --git a/test/dynamo/cpython/3_13/test_userdict.py b/test/dynamo/cpython/3_13/test_userdict.py -index 61e79f553e8..c953390355e 100644 +index 61e79f553e8..75b789633ed 100644 --- a/test/dynamo/cpython/3_13/test_userdict.py +++ b/test/dynamo/cpython/3_13/test_userdict.py -@@ -1,3 +1,54 @@ +@@ -1,3 +1,57 @@ +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + ++# Test copied from ++# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_userdict.py ++ +import sys +import torch +import torch._dynamo.test_case @@ -57,7 +60,7 @@ index 61e79f553e8..c953390355e 100644 # Check every path through every method of UserDict from test import mapping_tests, support -@@ -215,10 +266,10 @@ class UserDictTest(mapping_tests.TestHashMappingProtocol): +@@ -215,10 +269,10 @@ class UserDictTest(mapping_tests.TestHashMappingProtocol): # Decorate existing test with recursion limit, because # the test is for C structure, but `UserDict` is a Python structure. diff --git a/test/dynamo/cpython/3_13/test_userdict.py b/test/dynamo/cpython/3_13/test_userdict.py index c953390355e67..75b789633edf0 100644 --- a/test/dynamo/cpython/3_13/test_userdict.py +++ b/test/dynamo/cpython/3_13/test_userdict.py @@ -4,6 +4,9 @@ # ruff: noqa # flake8: noqa +# Test copied from +# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_userdict.py + import sys import torch import torch._dynamo.test_case diff --git a/test/dynamo/cpython/3_13/test_userlist.diff b/test/dynamo/cpython/3_13/test_userlist.diff index 299a8abeb99ac..20999ba6bca0f 100644 --- a/test/dynamo/cpython/3_13/test_userlist.diff +++ b/test/dynamo/cpython/3_13/test_userlist.diff @@ -1,14 +1,17 @@ diff --git a/test/dynamo/cpython/3_13/test_userlist.py b/test/dynamo/cpython/3_13/test_userlist.py -index 312702c8e39..a4532922f5d 100644 +index 312702c8e39..5ede0c3b7f1 100644 --- a/test/dynamo/cpython/3_13/test_userlist.py +++ b/test/dynamo/cpython/3_13/test_userlist.py -@@ -1,7 +1,58 @@ +@@ -1,7 +1,61 @@ +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + ++# Test copied from ++# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_userlist.py ++ +import sys +import torch +import torch._dynamo.test_case @@ -62,7 +65,7 @@ index 312702c8e39..a4532922f5d 100644 import unittest from test import support -@@ -69,9 +120,9 @@ class UserListTest(list_tests.CommonTest): +@@ -69,9 +123,9 @@ class UserListTest(list_tests.CommonTest): # Decorate existing test with recursion limit, because # the test is for C structure, but `UserList` is a Python structure. diff --git a/test/dynamo/cpython/3_13/test_userlist.py b/test/dynamo/cpython/3_13/test_userlist.py index a4532922f5d42..5ede0c3b7f1a0 100644 --- a/test/dynamo/cpython/3_13/test_userlist.py +++ b/test/dynamo/cpython/3_13/test_userlist.py @@ -4,6 +4,9 @@ # ruff: noqa # flake8: noqa +# Test copied from +# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_userlist.py + import sys import torch import torch._dynamo.test_case From e8cca7bac7553af0efe208d40c1cbaab72797ad9 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 14 Jul 2025 16:33:48 +0000 Subject: [PATCH 005/457] Revert "Deprecate overleap functions in CUDAAllocatorConfig, use AcceleratorAllocatorConfig instead (#156165)" This reverts commit 85857181ebca86e9c709e9922a9d9ef41a9c4ef9. Reverted https://github.com/pytorch/pytorch/pull/156165 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it is failing to build PyTorch internally ([comment](https://github.com/pytorch/pytorch/pull/150312#issuecomment-3070218901)) --- aten/src/ATen/cuda/CachingHostAllocator.cpp | 2 +- c10/cuda/CUDAAllocatorConfig.h | 19 ++------- c10/cuda/CUDACachingAllocator.cpp | 47 ++++++++++----------- c10/xpu/XPUCachingAllocator.cpp | 3 +- torch/csrc/cuda/Module.cpp | 5 ++- 5 files changed, 32 insertions(+), 44 deletions(-) diff --git a/aten/src/ATen/cuda/CachingHostAllocator.cpp b/aten/src/ATen/cuda/CachingHostAllocator.cpp index b5e5f84cde13f..6a80342e10240 100644 --- a/aten/src/ATen/cuda/CachingHostAllocator.cpp +++ b/aten/src/ATen/cuda/CachingHostAllocator.cpp @@ -162,7 +162,7 @@ struct CUDACachingHostAllocatorImpl } bool pinned_use_background_threads() override { - return c10::CachingAllocator::AcceleratorAllocatorConfig:: + return c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig:: pinned_use_background_threads(); } diff --git a/c10/cuda/CUDAAllocatorConfig.h b/c10/cuda/CUDAAllocatorConfig.h index 9aa128c26bd0c..8b425e77b9f0d 100644 --- a/c10/cuda/CUDAAllocatorConfig.h +++ b/c10/cuda/CUDAAllocatorConfig.h @@ -2,7 +2,6 @@ #include #include -#include #include #include @@ -17,13 +16,9 @@ enum class Expandable_Segments_Handle_Type : int { // Environment config parser class C10_CUDA_API CUDAAllocatorConfig { public: - C10_DEPRECATED_MESSAGE( - "c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::max_split_size() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::max_split_size() instead.") static size_t max_split_size() { return c10::CachingAllocator::AcceleratorAllocatorConfig::max_split_size(); } - C10_DEPRECATED_MESSAGE( - "c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::garbage_collection_threshold() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::garbage_collection_threshold() instead.") static double garbage_collection_threshold() { return c10::CachingAllocator::AcceleratorAllocatorConfig:: garbage_collection_threshold(); @@ -64,8 +59,6 @@ class C10_CUDA_API CUDAAllocatorConfig { return instance().m_pinned_num_register_threads; } - C10_DEPRECATED_MESSAGE( - "c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::pinned_use_background_threads() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::pinned_use_background_threads() instead.") static bool pinned_use_background_threads() { return c10::CachingAllocator::AcceleratorAllocatorConfig:: pinned_use_background_threads(); @@ -78,29 +71,25 @@ class C10_CUDA_API CUDAAllocatorConfig { return 128; } - C10_DEPRECATED_MESSAGE( - "c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::roundup_power2_divisions() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::roundup_power2_divisions() instead.") + // This is used to round-up allocation size to nearest power of 2 divisions. + // More description below in function roundup_power2_next_division + // As an example, if we want 4 divisions between 2's power, this can be done + // using env variable: PYTORCH_CUDA_ALLOC_CONF=roundup_power2_divisions:4 static size_t roundup_power2_divisions(size_t size) { return c10::CachingAllocator::AcceleratorAllocatorConfig:: roundup_power2_divisions(size); } - C10_DEPRECATED_MESSAGE( - "c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::roundup_power2_divisions() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::roundup_power2_divisions() instead.") static std::vector roundup_power2_divisions() { return c10::CachingAllocator::AcceleratorAllocatorConfig:: roundup_power2_divisions(); } - C10_DEPRECATED_MESSAGE( - "c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::max_non_split_rounding_size() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::max_non_split_rounding_size() instead.") static size_t max_non_split_rounding_size() { return c10::CachingAllocator::AcceleratorAllocatorConfig:: max_non_split_rounding_size(); } - C10_DEPRECATED_MESSAGE( - "c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::last_allocator_settings() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::last_allocator_settings() instead.") static std::string last_allocator_settings() { return c10::CachingAllocator::getAllocatorSettings(); } diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp index 5ae04bcd3f53c..ed6914c350599 100644 --- a/c10/cuda/CUDACachingAllocator.cpp +++ b/c10/cuda/CUDACachingAllocator.cpp @@ -1226,7 +1226,7 @@ class DeviceCachingAllocator { DeviceCachingAllocator() : large_blocks(/*small=*/false), small_blocks(/*small=*/true) { stats.max_split_size = - static_cast(AcceleratorAllocatorConfig::max_split_size()); + static_cast(CUDAAllocatorConfig::max_split_size()); context_recorder_.store(nullptr); } @@ -1351,8 +1351,7 @@ class DeviceCachingAllocator { // Do garbage collection if the flag is set. if (C10_UNLIKELY( set_fraction && - AcceleratorAllocatorConfig::garbage_collection_threshold() > - 0.0)) { + CUDAAllocatorConfig::garbage_collection_threshold() > 0.0)) { garbage_collect_cached_blocks(context); } // Attempt allocate @@ -1604,7 +1603,7 @@ class DeviceCachingAllocator { stats.active_bytes[stat_type].increase(block->size); stats.requested_bytes[stat_type].increase(block->requested_size); }); - if (block->size >= AcceleratorAllocatorConfig::max_split_size()) + if (block->size >= CUDAAllocatorConfig::max_split_size()) stats.oversize_allocations.increase(1); auto allocated_bytes_gauge = @@ -1655,7 +1654,7 @@ class DeviceCachingAllocator { block->pool->owner_MempoolId(), context ? context : block->context_when_allocated); - if (block->size >= AcceleratorAllocatorConfig::max_split_size()) + if (block->size >= CUDAAllocatorConfig::max_split_size()) stats.oversize_allocations.decrease(1); if (!block->stream_uses.empty()) { @@ -2205,8 +2204,7 @@ class DeviceCachingAllocator { if (size < kMinBlockSize) { return kMinBlockSize; } else { - auto divisions = - AcceleratorAllocatorConfig::roundup_power2_divisions(size); + auto divisions = CUDAAllocatorConfig::roundup_power2_divisions(size); if (divisions > 1 && size > (kMinBlockSize * divisions)) { return roundup_power2_next_division(size, divisions); } else { @@ -2696,7 +2694,7 @@ class DeviceCachingAllocator { if (block->pool->is_small || CUDAAllocatorConfig::expandable_segments()) { return remaining >= kMinBlockSize; } else { - return (size < AcceleratorAllocatorConfig::max_split_size()) && + return (size < CUDAAllocatorConfig::max_split_size()) && (remaining > kSmallSize); } } @@ -2716,7 +2714,7 @@ class DeviceCachingAllocator { if (C10_UNLIKELY( set_fraction && - AcceleratorAllocatorConfig::garbage_collection_threshold() > 0.0)) { + CUDAAllocatorConfig::garbage_collection_threshold() > 0.0)) { // Track block reuse interval only when garbage collection is enabled. ++pool.get_free_blocks_call_count; } @@ -2758,13 +2756,13 @@ class DeviceCachingAllocator { } // Do not return an oversized block for a large request - if ((p.size() < AcceleratorAllocatorConfig::max_split_size()) && - ((*it)->size >= AcceleratorAllocatorConfig::max_split_size())) + if ((p.size() < CUDAAllocatorConfig::max_split_size()) && + ((*it)->size >= CUDAAllocatorConfig::max_split_size())) return false; // Allow oversized block size to be rounded up but within a limit - if ((p.size() >= AcceleratorAllocatorConfig::max_split_size()) && + if ((p.size() >= CUDAAllocatorConfig::max_split_size()) && ((*it)->size >= - p.size() + AcceleratorAllocatorConfig::max_non_split_rounding_size())) + p.size() + CUDAAllocatorConfig::max_non_split_rounding_size())) return false; p.block = *it; pool.blocks.erase(it); @@ -2787,7 +2785,7 @@ class DeviceCachingAllocator { // therefore should be of less overheads. size_t gc_threshold = static_cast( - AcceleratorAllocatorConfig::garbage_collection_threshold() * + CUDAAllocatorConfig::garbage_collection_threshold() * static_cast(allowed_memory_maximum)); // No need to trigger GC yet if (total_allocated_memory <= gc_threshold) { @@ -2935,7 +2933,7 @@ class DeviceCachingAllocator { stats.segment[stat_type].increase(1); stats.reserved_bytes[stat_type].increase(size); }); - if (size >= AcceleratorAllocatorConfig::max_split_size()) + if (size >= CUDAAllocatorConfig::max_split_size()) stats.oversize_segments.increase(1); auto reserved_bytes_gauge = STATIC_GAUGE(pytorch.CUDACachingAllocator.reserved_bytes); @@ -2964,7 +2962,7 @@ class DeviceCachingAllocator { bool release_available_cached_blocks( const AllocParams& p, const std::shared_ptr& context) { - if (AcceleratorAllocatorConfig::max_split_size() == + if (CUDAAllocatorConfig::max_split_size() == std::numeric_limits::max()) return false; BlockPool& pool = *p.pool; @@ -2972,8 +2970,8 @@ class DeviceCachingAllocator { // because of std::unique_ptr, block cannot be trivially copied // Use constructor for search key. Block key(p.search_key.device, p.search_key.stream, p.search_key.size); - key.size = (key.size < AcceleratorAllocatorConfig::max_split_size()) - ? AcceleratorAllocatorConfig::max_split_size() + key.size = (key.size < CUDAAllocatorConfig::max_split_size()) + ? CUDAAllocatorConfig::max_split_size() : key.size; auto it = pool.blocks.lower_bound(&key); if (it == pool.blocks.end() || (*it)->stream != p.stream() || @@ -2986,7 +2984,7 @@ class DeviceCachingAllocator { --it; // Back up one item. Now on the largest block for the correct // stream while ((totalReleased < key.size) && - ((*it)->size >= AcceleratorAllocatorConfig::max_split_size()) && + ((*it)->size >= CUDAAllocatorConfig::max_split_size()) && ((*it)->stream == p.stream())) { auto cur = it; bool is_first = cur == pool.blocks.begin(); @@ -3111,7 +3109,7 @@ class DeviceCachingAllocator { stats.reserved_bytes[static_cast(StatType::AGGREGATE)] .current); - if (block->size >= AcceleratorAllocatorConfig::max_split_size()) + if (block->size >= CUDAAllocatorConfig::max_split_size()) stats.oversize_segments.decrease(1); pool->blocks.erase(block); delete block; @@ -3738,8 +3736,8 @@ class NativeCachingAllocator : public CUDAAllocator { auto& md = result.config_metadata; md.garbage_collection_threshold = - AcceleratorAllocatorConfig::garbage_collection_threshold(); - md.max_split_size = AcceleratorAllocatorConfig::max_split_size(); + CUDAAllocatorConfig::garbage_collection_threshold(); + md.max_split_size = CUDAAllocatorConfig::max_split_size(); md.pinned_num_register_threads = CUDAAllocatorConfig::pinned_num_register_threads(); md.expandable_segments = CUDAAllocatorConfig::expandable_segments(); @@ -3747,10 +3745,9 @@ class NativeCachingAllocator : public CUDAAllocator { CUDAAllocatorConfig::release_lock_on_cudamalloc(); md.pinned_use_host_register = CUDAAllocatorConfig::pinned_use_cuda_host_register(); - md.last_allocator_settings = - AcceleratorAllocatorConfig::last_allocator_settings(); + md.last_allocator_settings = CUDAAllocatorConfig::last_allocator_settings(); md.roundup_power2_divisions = - AcceleratorAllocatorConfig::roundup_power2_divisions(); + CUDAAllocatorConfig::roundup_power2_divisions(); return result; } diff --git a/c10/xpu/XPUCachingAllocator.cpp b/c10/xpu/XPUCachingAllocator.cpp index afae32d92a4b4..543b48f081135 100644 --- a/c10/xpu/XPUCachingAllocator.cpp +++ b/c10/xpu/XPUCachingAllocator.cpp @@ -1,4 +1,3 @@ -#include #include #include #include @@ -21,6 +20,8 @@ constexpr size_t kMinBlockSize = 512; constexpr size_t kSmallSize = 1048576; // "small" allocations are packed in 2 MiB blocks constexpr size_t kSmallBuffer = 2097152; +// "large" allocations may be packed in 20 MiB blocks +constexpr size_t kLargeBuffer = 20971520; // allocations between 1 and 10 MiB may use kLargeBuffer constexpr size_t kMinLargeAlloc = 10485760; // round up large allocations to 2 MiB diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp index ead46337ff090..b44ce311ecd92 100644 --- a/torch/csrc/cuda/Module.cpp +++ b/torch/csrc/cuda/Module.cpp @@ -20,8 +20,8 @@ #include #include #include -#include #include +#include #include #include #include @@ -426,7 +426,8 @@ PyObject* THCPModule_cudaCachingAllocator_set_allocator_settings( PyObject* _unused, PyObject* env) { HANDLE_TH_ERRORS - c10::CachingAllocator::setAllocatorSettings(THPUtils_unpackString(env)); + c10::cuda::CUDACachingAllocator::setAllocatorSettings( + THPUtils_unpackString(env)); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } From 6fe7456aa1a2d025d1d06e15ba3896e6adba94b8 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 14 Jul 2025 16:33:48 +0000 Subject: [PATCH 006/457] Revert "Refactor CUDAAllocatorConfig to reuse AcceleratorAllocatorConfig (#150312)" This reverts commit 03b307575a98dc1d953c9d3521a9489e0e61e70c. Reverted https://github.com/pytorch/pytorch/pull/150312 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it is failing to build PyTorch internally ([comment](https://github.com/pytorch/pytorch/pull/150312#issuecomment-3070218901)) --- c10/cuda/CUDAAllocatorConfig.cpp | 471 ++++++++++++++++++++++++------ c10/cuda/CUDAAllocatorConfig.h | 129 ++++---- c10/cuda/CUDACachingAllocator.cpp | 50 +++- c10/cuda/CUDACachingAllocator.h | 4 +- 4 files changed, 495 insertions(+), 159 deletions(-) diff --git a/c10/cuda/CUDAAllocatorConfig.cpp b/c10/cuda/CUDAAllocatorConfig.cpp index ac59059c2cc21..d2efb8c593e44 100644 --- a/c10/cuda/CUDAAllocatorConfig.cpp +++ b/c10/cuda/CUDAAllocatorConfig.cpp @@ -1,121 +1,389 @@ #include -#include -#include +#include +#include #if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) #include #endif -#include - namespace c10::cuda::CUDACachingAllocator { -size_t CUDAAllocatorConfig::parseAllocatorConfig( - const c10::CachingAllocator::ConfigTokenizer& tokenizer, +constexpr size_t kRoundUpPowerOfTwoIntervals = 16; + +CUDAAllocatorConfig::CUDAAllocatorConfig() + : m_max_split_size(std::numeric_limits::max()), + m_max_non_split_rounding_size(kLargeBuffer), + m_garbage_collection_threshold(0), + m_pinned_num_register_threads(1), + m_expandable_segments(false), +#if CUDA_VERSION >= 12030 + m_expandable_segments_handle_type( + Expandable_Segments_Handle_Type::UNSPECIFIED), +#else + m_expandable_segments_handle_type( + Expandable_Segments_Handle_Type::POSIX_FD), +#endif + m_release_lock_on_cudamalloc(false), + m_pinned_use_cuda_host_register(false), + m_pinned_use_background_threads(false) { + m_roundup_power2_divisions.assign(kRoundUpPowerOfTwoIntervals, 0); +} + +size_t CUDAAllocatorConfig::roundup_power2_divisions(size_t size) { + size_t log_size = (63 - llvm::countLeadingZeros(size)); + + // Our intervals start at 1MB and end at 64GB + const size_t interval_start = + 63 - llvm::countLeadingZeros(static_cast(1048576)); + const size_t interval_end = + 63 - llvm::countLeadingZeros(static_cast(68719476736)); + TORCH_CHECK( + (interval_end - interval_start == kRoundUpPowerOfTwoIntervals), + "kRoundUpPowerOfTwoIntervals mismatch"); + + int index = static_cast(log_size) - static_cast(interval_start); + + index = std::max(0, index); + index = std::min(index, static_cast(kRoundUpPowerOfTwoIntervals) - 1); + return instance().m_roundup_power2_divisions[index]; +} + +void CUDAAllocatorConfig::lexArgs( + const std::string& env, + std::vector& config) { + std::vector buf; + + for (char ch : env) { + if (ch == ',' || ch == ':' || ch == '[' || ch == ']') { + if (!buf.empty()) { + config.emplace_back(buf.begin(), buf.end()); + buf.clear(); + } + config.emplace_back(1, ch); + } else if (ch != ' ') { + buf.emplace_back(ch); + } + } + if (!buf.empty()) { + config.emplace_back(buf.begin(), buf.end()); + } +} + +void CUDAAllocatorConfig::consumeToken( + const std::vector& config, + size_t i, + const char c) { + TORCH_CHECK( + i < config.size() && config[i] == std::string(1, c), + "Error parsing CachingAllocator settings, expected ", + c, + ""); +} + +size_t CUDAAllocatorConfig::parseMaxSplitSize( + const std::vector& config, + size_t i) { + consumeToken(config, ++i, ':'); + constexpr int mb = 1024 * 1024; + if (++i < config.size()) { + size_t val1 = stoi(config[i]); + TORCH_CHECK( + val1 > kLargeBuffer / mb, + "CachingAllocator option max_split_size_mb too small, must be > ", + kLargeBuffer / mb, + ""); + val1 = std::max(val1, kLargeBuffer / mb); + val1 = std::min(val1, (std::numeric_limits::max() / mb)); + m_max_split_size = val1 * 1024 * 1024; + } else { + TORCH_CHECK(false, "Error, expecting max_split_size_mb value", ""); + } + return i; +} + +size_t CUDAAllocatorConfig::parseMaxNonSplitRoundingSize( + const std::vector& config, + size_t i) { + consumeToken(config, ++i, ':'); + constexpr int mb = 1024 * 1024; + if (++i < config.size()) { + size_t val1 = stoi(config[i]); + TORCH_CHECK( + val1 > kLargeBuffer / mb, + "CachingAllocator option max_non_split_rounding_mb too small, must be > ", + kLargeBuffer / mb, + ""); + val1 = std::max(val1, kLargeBuffer / mb); + val1 = std::min(val1, (std::numeric_limits::max() / mb)); + m_max_non_split_rounding_size = val1 * 1024 * 1024; + } else { + TORCH_CHECK(false, "Error, expecting max_non_split_rounding_mb value", ""); + } + return i; +} + +size_t CUDAAllocatorConfig::parseGarbageCollectionThreshold( + const std::vector& config, + size_t i) { + consumeToken(config, ++i, ':'); + if (++i < config.size()) { + double val1 = stod(config[i]); + TORCH_CHECK( + val1 > 0, "garbage_collect_threshold too small, set it 0.0~1.0", ""); + TORCH_CHECK( + val1 < 1.0, "garbage_collect_threshold too big, set it 0.0~1.0", ""); + m_garbage_collection_threshold = val1; + } else { + TORCH_CHECK( + false, "Error, expecting garbage_collection_threshold value", ""); + } + return i; +} + +size_t CUDAAllocatorConfig::parseRoundUpPower2Divisions( + const std::vector& config, size_t i) { + consumeToken(config, ++i, ':'); + bool first_value = true; + + if (++i < config.size()) { + if (std::string_view(config[i]) == "[") { + size_t last_index = 0; + // NOLINTNEXTLINE(bugprone-inc-dec-in-conditions) + while (++i < config.size() && std::string_view(config[i]) != "]") { + const std::string& val1 = config[i]; + size_t val2 = 0; + + consumeToken(config, ++i, ':'); + if (++i < config.size()) { + val2 = stoi(config[i]); + } else { + TORCH_CHECK( + false, "Error parsing roundup_power2_divisions value", ""); + } + TORCH_CHECK( + val2 == 0 || llvm::isPowerOf2_64(val2), + "For roundups, the divisions has to be power of 2 or 0 to disable roundup ", + ""); + + if (std::string_view(val1) == ">") { + std::fill( + std::next( + m_roundup_power2_divisions.begin(), + static_cast::difference_type>( + last_index)), + m_roundup_power2_divisions.end(), + val2); + } else { + size_t val1_long = stoul(val1); + TORCH_CHECK( + llvm::isPowerOf2_64(val1_long), + "For roundups, the intervals have to be power of 2 ", + ""); + + size_t index = 63 - llvm::countLeadingZeros(val1_long); + index = std::max((size_t)0, index); + index = std::min(index, m_roundup_power2_divisions.size() - 1); + + if (first_value) { + std::fill( + m_roundup_power2_divisions.begin(), + std::next( + m_roundup_power2_divisions.begin(), + static_cast::difference_type>( + index)), + val2); + first_value = false; + } + if (index < m_roundup_power2_divisions.size()) { + m_roundup_power2_divisions[index] = val2; + } + last_index = index; + } + + if (std::string_view(config[i + 1]) != "]") { + consumeToken(config, ++i, ','); + } + } + } else { // Keep this for backwards compatibility + size_t val1 = stoi(config[i]); + TORCH_CHECK( + llvm::isPowerOf2_64(val1), + "For roundups, the divisions has to be power of 2 ", + ""); + std::fill( + m_roundup_power2_divisions.begin(), + m_roundup_power2_divisions.end(), + val1); + } + } else { + TORCH_CHECK(false, "Error, expecting roundup_power2_divisions value", ""); + } + return i; +} + +size_t CUDAAllocatorConfig::parseAllocatorConfig( + const std::vector& config, + size_t i, + bool& used_cudaMallocAsync) { // For ease of maintenance and understanding, the CUDA and ROCm // implementations of this function are separated. This avoids having many // #ifdef's throughout. +#ifdef USE_ROCM // Ease burden on ROCm users by allowing either cuda or hip tokens. // cuda token is broken up to prevent hipify matching it. #define PYTORCH_TOKEN1 \ "cud" \ "aMallocAsync" #define PYTORCH_TOKEN2 "hipMallocAsync" - tokenizer.checkToken(++i, ":"); - i++; // Move to the value after the colon - TORCH_CHECK( - ((tokenizer[i] == "native") || (tokenizer[i] == PYTORCH_TOKEN1) || - (tokenizer[i] == PYTORCH_TOKEN2)), - "Unknown allocator backend, " - "options are native, " PYTORCH_TOKEN1 ", and " PYTORCH_TOKEN2); - if (m_is_allocator_loaded) { - bool aync_allocator_at_runtime = (tokenizer[i] != "native"); + consumeToken(config, ++i, ':'); + if (++i < config.size()) { TORCH_CHECK( - aync_allocator_at_runtime == m_use_async_allocator, - "Allocator async backend parsed at runtime != allocator async backend parsed at load time, ", - aync_allocator_at_runtime, + ((config[i] == "native") || (config[i] == PYTORCH_TOKEN1) || + (config[i] == PYTORCH_TOKEN2)), + "Unknown allocator backend, " + "options are native, " PYTORCH_TOKEN1 ", and " PYTORCH_TOKEN2); + used_cudaMallocAsync = + (config[i] == PYTORCH_TOKEN1 || config[i] == PYTORCH_TOKEN2); + TORCH_INTERNAL_ASSERT( + config[i] == get()->name() || + (config[i] == PYTORCH_TOKEN1 && get()->name() == PYTORCH_TOKEN2), + "Allocator backend parsed at runtime != " + "allocator backend parsed at load time, ", + config[i], " != ", - m_use_async_allocator); + get()->name()); + } else { + TORCH_CHECK(false, "Error parsing backend value", ""); } - m_use_async_allocator = - (tokenizer[i] == PYTORCH_TOKEN1 || tokenizer[i] == PYTORCH_TOKEN2); - // CUDA allocator is always loaded at the start of the program - m_is_allocator_loaded = true; - -#if defined(CUDA_VERSION) - if (m_use_async_allocator) { -#if CUDA_VERSION >= 11040 - int version = 0; - C10_CUDA_CHECK(cudaDriverGetVersion(&version)); + return i; +#undef PYTORCH_TOKEN1 +#undef PYTORCH_TOKEN2 +#else // USE_ROCM + consumeToken(config, ++i, ':'); + if (++i < config.size()) { TORCH_CHECK( - version >= 11040, - "backend:cudaMallocAsync requires CUDA runtime " - "11.4 or newer, but cudaDriverGetVersion returned ", - version); + ((config[i] == "native") || (config[i] == "cudaMallocAsync")), + "Unknown allocator backend, " + "options are native and cudaMallocAsync"); + used_cudaMallocAsync = (config[i] == "cudaMallocAsync"); + if (used_cudaMallocAsync) { +#if CUDA_VERSION >= 11040 + int version = 0; + C10_CUDA_CHECK(cudaDriverGetVersion(&version)); + TORCH_CHECK( + version >= 11040, + "backend:cudaMallocAsync requires CUDA runtime " + "11.4 or newer, but cudaDriverGetVersion returned ", + version); #else - TORCH_CHECK( - false, - "backend:cudaMallocAsync requires PyTorch to be built with " - "CUDA 11.4 or newer, but CUDA_VERSION is ", - CUDA_VERSION); + TORCH_CHECK( + false, + "backend:cudaMallocAsync requires PyTorch to be built with " + "CUDA 11.4 or newer, but CUDA_VERSION is ", + CUDA_VERSION); #endif + } + TORCH_INTERNAL_ASSERT( + config[i] == get()->name(), + "Allocator backend parsed at runtime != " + "allocator backend parsed at load time"); + } else { + TORCH_CHECK(false, "Error parsing backend value", ""); } -#endif - return i; -#undef PYTORCH_TOKEN1 -#undef PYTORCH_TOKEN2 +#endif // USE_ROCM } -void CUDAAllocatorConfig::parseArgs(const std::string& env) { +void CUDAAllocatorConfig::parseArgs(const std::optional& env) { // If empty, set the default values + m_max_split_size = std::numeric_limits::max(); + m_roundup_power2_divisions.assign(kRoundUpPowerOfTwoIntervals, 0); + m_garbage_collection_threshold = 0; + bool used_cudaMallocAsync = false; bool used_native_specific_option = false; - c10::CachingAllocator::ConfigTokenizer tokenizer(env); - for (size_t i = 0; i < tokenizer.size(); i++) { - const auto& key = tokenizer[i]; - if (key == "backend") { - i = parseAllocatorConfig(tokenizer, i); + if (!env.has_value()) { + return; + } + { + std::lock_guard lock(m_last_allocator_settings_mutex); + m_last_allocator_settings = env.value(); + } + + std::vector config; + lexArgs(env.value(), config); + + for (size_t i = 0; i < config.size(); i++) { + std::string_view config_item_view(config[i]); + if (config_item_view == "max_split_size_mb") { + i = parseMaxSplitSize(config, i); + used_native_specific_option = true; + } else if (config_item_view == "max_non_split_rounding_mb") { + i = parseMaxNonSplitRoundingSize(config, i); + used_native_specific_option = true; + } else if (config_item_view == "garbage_collection_threshold") { + i = parseGarbageCollectionThreshold(config, i); + used_native_specific_option = true; + } else if (config_item_view == "roundup_power2_divisions") { + i = parseRoundUpPower2Divisions(config, i); + used_native_specific_option = true; + } else if (config_item_view == "backend") { + i = parseAllocatorConfig(config, i, used_cudaMallocAsync); + } else if (config_item_view == "expandable_segments") { + used_native_specific_option = true; + consumeToken(config, ++i, ':'); + ++i; + TORCH_CHECK( + i < config.size() && + (std::string_view(config[i]) == "True" || + std::string_view(config[i]) == "False"), + "Expected a single True/False argument for expandable_segments"); + config_item_view = config[i]; + m_expandable_segments = (config_item_view == "True"); } else if ( // ROCm build's hipify step will change "cuda" to "hip", but for ease of // use, accept both. We must break up the string to prevent hipify here. - key == "release_lock_on_hipmalloc" || - key == + config_item_view == "release_lock_on_hipmalloc" || + config_item_view == "release_lock_on_c" "udamalloc") { used_native_specific_option = true; - tokenizer.checkToken(++i, ":"); - m_release_lock_on_cudamalloc = tokenizer.toBool(++i); + consumeToken(config, ++i, ':'); + ++i; + TORCH_CHECK( + i < config.size() && + (std::string_view(config[i]) == "True" || + std::string_view(config[i]) == "False"), + "Expected a single True/False argument for release_lock_on_cudamalloc"); + config_item_view = config[i]; + m_release_lock_on_cudamalloc = (config_item_view == "True"); } else if ( // ROCm build's hipify step will change "cuda" to "hip", but for ease of // use, accept both. We must break up the string to prevent hipify here. - key == "pinned_use_hip_host_register" || - key == + config_item_view == "pinned_use_hip_host_register" || + config_item_view == "pinned_use_c" "uda_host_register") { - i = parsePinnedUseCudaHostRegister(tokenizer, i); + i = parsePinnedUseCudaHostRegister(config, i); used_native_specific_option = true; - } else if (key == "pinned_num_register_threads") { - i = parsePinnedNumRegisterThreads(tokenizer, i); + } else if (config_item_view == "pinned_num_register_threads") { + i = parsePinnedNumRegisterThreads(config, i); + used_native_specific_option = true; + } else if (config_item_view == "pinned_use_background_threads") { + i = parsePinnedUseBackgroundThreads(config, i); used_native_specific_option = true; } else { - const auto& keys = - c10::CachingAllocator::AcceleratorAllocatorConfig::getKeys(); TORCH_CHECK( - keys.find(key) != keys.end(), - "Unrecognized key '", - key, - "' in Accelerator allocator config."); - i = tokenizer.skipKey(i); + false, "Unrecognized CachingAllocator option: ", config_item_view); } - if (i + 1 < tokenizer.size()) { - tokenizer.checkToken(++i, ","); + if (i + 1 < config.size()) { + consumeToken(config, ++i, ','); } } - if (m_use_async_allocator && used_native_specific_option) { + if (used_cudaMallocAsync && used_native_specific_option) { TORCH_WARN( "backend:cudaMallocAsync ignores max_split_size_mb," "roundup_power2_divisions, and garbage_collect_threshold."); @@ -123,33 +391,64 @@ void CUDAAllocatorConfig::parseArgs(const std::string& env) { } size_t CUDAAllocatorConfig::parsePinnedUseCudaHostRegister( - const c10::CachingAllocator::ConfigTokenizer& tokenizer, + const std::vector& config, size_t i) { - tokenizer.checkToken(++i, ":"); - m_pinned_use_cuda_host_register = tokenizer.toBool(++i); - + consumeToken(config, ++i, ':'); + if (++i < config.size()) { + TORCH_CHECK( + (config[i] == "True" || config[i] == "False"), + "Expected a single True/False argument for pinned_use_cuda_host_register"); + m_pinned_use_cuda_host_register = (config[i] == "True"); + } else { + TORCH_CHECK( + false, "Error, expecting pinned_use_cuda_host_register value", ""); + } return i; } size_t CUDAAllocatorConfig::parsePinnedNumRegisterThreads( - const c10::CachingAllocator::ConfigTokenizer& tokenizer, + const std::vector& config, size_t i) { - tokenizer.checkToken(++i, ":"); - size_t val2 = tokenizer.toSizeT(++i); - TORCH_CHECK( - llvm::isPowerOf2_64(val2), - "Number of register threads has to be power of 2 ", - ""); - auto maxThreads = CUDAAllocatorConfig::pinned_max_register_threads(); - TORCH_CHECK( - val2 <= maxThreads, - "Number of register threads should be less than or equal to " + - std::to_string(maxThreads), - ""); - m_pinned_num_register_threads = val2; + consumeToken(config, ++i, ':'); + if (++i < config.size()) { + size_t val2 = stoi(config[i]); + TORCH_CHECK( + llvm::isPowerOf2_64(val2), + "Number of register threads has to be power of 2 ", + ""); + auto maxThreads = CUDAAllocatorConfig::pinned_max_register_threads(); + TORCH_CHECK( + val2 <= maxThreads, + "Number of register threads should be less than or equal to " + + std::to_string(maxThreads), + ""); + m_pinned_num_register_threads = val2; + } else { + TORCH_CHECK( + false, "Error, expecting pinned_num_register_threads value", ""); + } + return i; +} + +size_t CUDAAllocatorConfig::parsePinnedUseBackgroundThreads( + const std::vector& config, + size_t i) { + consumeToken(config, ++i, ':'); + if (++i < config.size()) { + TORCH_CHECK( + (config[i] == "True" || config[i] == "False"), + "Expected a single True/False argument for pinned_use_background_threads"); + m_pinned_use_background_threads = (config[i] == "True"); + } else { + TORCH_CHECK( + false, "Error, expecting pinned_use_background_threads value", ""); + } return i; } -REGISTER_ALLOCATOR_CONFIG_PARSE_HOOK(CUDAAllocatorConfig) +// General caching allocator utilities +void setAllocatorSettings(const std::string& env) { + CUDACachingAllocator::CUDAAllocatorConfig::instance().parseArgs(env.c_str()); +} } // namespace c10::cuda::CUDACachingAllocator diff --git a/c10/cuda/CUDAAllocatorConfig.h b/c10/cuda/CUDAAllocatorConfig.h index 8b425e77b9f0d..fda3cc02e5d0a 100644 --- a/c10/cuda/CUDAAllocatorConfig.h +++ b/c10/cuda/CUDAAllocatorConfig.h @@ -1,10 +1,16 @@ #pragma once -#include #include #include #include +#include +#include +#include +#include +#include +#include + namespace c10::cuda::CUDACachingAllocator { enum class Expandable_Segments_Handle_Type : int { @@ -17,23 +23,20 @@ enum class Expandable_Segments_Handle_Type : int { class C10_CUDA_API CUDAAllocatorConfig { public: static size_t max_split_size() { - return c10::CachingAllocator::AcceleratorAllocatorConfig::max_split_size(); + return instance().m_max_split_size; } static double garbage_collection_threshold() { - return c10::CachingAllocator::AcceleratorAllocatorConfig:: - garbage_collection_threshold(); + return instance().m_garbage_collection_threshold; } static bool expandable_segments() { - bool enabled = c10::CachingAllocator::AcceleratorAllocatorConfig:: - use_expandable_segments(); #ifndef PYTORCH_C10_DRIVER_API_SUPPORTED - if (enabled) { + if (instance().m_expandable_segments) { TORCH_WARN_ONCE("expandable_segments not supported on this platform") } return false; #else - return enabled; + return instance().m_expandable_segments; #endif } @@ -60,8 +63,7 @@ class C10_CUDA_API CUDAAllocatorConfig { } static bool pinned_use_background_threads() { - return c10::CachingAllocator::AcceleratorAllocatorConfig:: - pinned_use_background_threads(); + return instance().m_pinned_use_background_threads; } static size_t pinned_max_register_threads() { @@ -75,97 +77,88 @@ class C10_CUDA_API CUDAAllocatorConfig { // More description below in function roundup_power2_next_division // As an example, if we want 4 divisions between 2's power, this can be done // using env variable: PYTORCH_CUDA_ALLOC_CONF=roundup_power2_divisions:4 - static size_t roundup_power2_divisions(size_t size) { - return c10::CachingAllocator::AcceleratorAllocatorConfig:: - roundup_power2_divisions(size); - } + static size_t roundup_power2_divisions(size_t size); static std::vector roundup_power2_divisions() { - return c10::CachingAllocator::AcceleratorAllocatorConfig:: - roundup_power2_divisions(); + return instance().m_roundup_power2_divisions; } static size_t max_non_split_rounding_size() { - return c10::CachingAllocator::AcceleratorAllocatorConfig:: - max_non_split_rounding_size(); + return instance().m_max_non_split_rounding_size; } static std::string last_allocator_settings() { - return c10::CachingAllocator::getAllocatorSettings(); - } - - static bool use_async_allocator() { - return instance().m_use_async_allocator; - } - - static const std::unordered_set& getKeys() { - return instance().keys_; + std::lock_guard lock( + instance().m_last_allocator_settings_mutex); + return instance().m_last_allocator_settings; } static CUDAAllocatorConfig& instance() { static CUDAAllocatorConfig* s_instance = ([]() { auto inst = new CUDAAllocatorConfig(); - auto env = c10::utils::get_env("PYTORCH_ALLOC_CONF"); - if (!env.has_value()) { - // For backward compatibility, check for the old environment variable - // PYTORCH_CUDA_ALLOC_CONF. - env = c10::utils::get_env("PYTORCH_CUDA_ALLOC_CONF"); - } + auto env = c10::utils::get_env("PYTORCH_CUDA_ALLOC_CONF"); #ifdef USE_ROCM // convenience for ROCm users, allow alternative HIP token if (!env.has_value()) { env = c10::utils::get_env("PYTORCH_HIP_ALLOC_CONF"); } #endif - if (env.has_value()) { - inst->parseArgs(env.value()); - } + inst->parseArgs(env); return inst; })(); return *s_instance; } - void parseArgs(const std::string& env); + void parseArgs(const std::optional& env); private: - CUDAAllocatorConfig() = default; - - size_t parseAllocatorConfig( - const c10::CachingAllocator::ConfigTokenizer& tokenizer, + CUDAAllocatorConfig(); + + static void lexArgs(const std::string& env, std::vector& config); + static void consumeToken( + const std::vector& config, + size_t i, + const char c); + size_t parseMaxSplitSize(const std::vector& config, size_t i); + size_t parseMaxNonSplitRoundingSize( + const std::vector& config, + size_t i); + size_t parseGarbageCollectionThreshold( + const std::vector& config, + size_t i); + size_t parseRoundUpPower2Divisions( + const std::vector& config, size_t i); + size_t parseAllocatorConfig( + const std::vector& config, + size_t i, + bool& used_cudaMallocAsync); size_t parsePinnedUseCudaHostRegister( - const c10::CachingAllocator::ConfigTokenizer& tokenizer, + const std::vector& config, size_t i); size_t parsePinnedNumRegisterThreads( - const c10::CachingAllocator::ConfigTokenizer& tokenizer, + const std::vector& config, + size_t i); + size_t parsePinnedUseBackgroundThreads( + const std::vector& config, size_t i); - std::atomic m_pinned_num_register_threads{1}; - std::atomic m_expandable_segments_handle_type -#if CUDA_VERSION >= 12030 - {Expandable_Segments_Handle_Type::UNSPECIFIED}; -#else - {Expandable_Segments_Handle_Type::POSIX_FD}; -#endif - std::atomic m_release_lock_on_cudamalloc{false}; - std::atomic m_pinned_use_cuda_host_register{false}; - std::atomic m_use_async_allocator{false}; - std::atomic m_is_allocator_loaded{false}; - std::unordered_set keys_{ - "backend", - // keep BC for Rocm: `cuda` -> `cud` `a`, to avoid hipify issues - // NOLINTBEGIN(bugprone-suspicious-missing-comma,-warnings-as-errors) - "release_lock_on_cud" - "amalloc", - "pinned_use_cud" - "a_host_register", - // NOLINTEND(bugprone-suspicious-missing-comma,-warnings-as-errors) - "release_lock_on_hipmalloc", - "pinned_use_hip_host_register", - "pinned_num_register_threads"}; + std::atomic m_max_split_size; + std::atomic m_max_non_split_rounding_size; + std::vector m_roundup_power2_divisions; + std::atomic m_garbage_collection_threshold; + std::atomic m_pinned_num_register_threads; + std::atomic m_expandable_segments; + std::atomic + m_expandable_segments_handle_type; + std::atomic m_release_lock_on_cudamalloc; + std::atomic m_pinned_use_cuda_host_register; + std::atomic m_pinned_use_background_threads; + std::string m_last_allocator_settings; + std::mutex m_last_allocator_settings_mutex; }; -// Keep this for backwards compatibility -using c10::CachingAllocator::setAllocatorSettings; +// General caching allocator utilities +C10_CUDA_API void setAllocatorSettings(const std::string& env); } // namespace c10::cuda::CUDACachingAllocator diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp index ed6914c350599..4d58c11c5c9bc 100644 --- a/c10/cuda/CUDACachingAllocator.cpp +++ b/c10/cuda/CUDACachingAllocator.cpp @@ -1,6 +1,7 @@ #include #include +#include #include #include #include @@ -63,6 +64,10 @@ namespace cuda::CUDACachingAllocator { using namespace c10::CachingAllocator; using namespace c10::CachingDeviceAllocator; +// Included here as this is externally used in CUDAAllocatorConfig +const size_t kLargeBuffer = + 20971520; // "large" allocations may be packed in 20 MiB blocks + namespace Native { // @@ -4125,10 +4130,49 @@ CUDAAllocator* allocator(); } // namespace CudaMallocAsync struct BackendStaticInitializer { + // Parses env for backend at load time, duplicating some logic from + // CUDAAllocatorConfig. CUDAAllocatorConfig double-checks it later (at + // runtime). Defers verbose exceptions and error checks, including Cuda + // version checks, to CUDAAllocatorConfig's runtime doublecheck. If this + // works, maybe we should move all of CUDAAllocatorConfig here? CUDAAllocator* parseEnvForBackend() { - // If the environment variable is set, we use the CudaMallocAsync allocator. - if (CUDAAllocatorConfig::use_async_allocator()) { - return CudaMallocAsync::allocator(); + auto val = c10::utils::get_env("PYTORCH_CUDA_ALLOC_CONF"); +#ifdef USE_ROCM + // convenience for ROCm users to allow either CUDA or HIP env var + if (!val.has_value()) { + val = c10::utils::get_env("PYTORCH_HIP_ALLOC_CONF"); + } +#endif + if (val.has_value()) { + const std::string& config = val.value(); + + std::regex exp("[\\s,]+"); + std::sregex_token_iterator it(config.begin(), config.end(), exp, -1); + std::sregex_token_iterator end; + std::vector options(it, end); + + for (auto option : options) { + std::regex exp2("[:]+"); + std::sregex_token_iterator it2(option.begin(), option.end(), exp2, -1); + std::sregex_token_iterator end2; + std::vector kv(it2, end2); + if (kv.size() >= 2) { + if (kv[0] == "backend") { +#ifdef USE_ROCM + // convenience for ROCm users to allow either CUDA or HIP env var + if (kv[1] == + "cud" + "aMallocAsync" || + kv[1] == "hipMallocAsync") +#else + if (kv[1] == "cudaMallocAsync") +#endif + return CudaMallocAsync::allocator(); + if (kv[1] == "native") + return &Native::allocator; + } + } + } } return &Native::allocator; } diff --git a/c10/cuda/CUDACachingAllocator.h b/c10/cuda/CUDACachingAllocator.h index 956411fe22827..a6fa61110d675 100644 --- a/c10/cuda/CUDACachingAllocator.h +++ b/c10/cuda/CUDACachingAllocator.h @@ -1,7 +1,6 @@ #pragma once #include -#include #include #include #include @@ -50,9 +49,10 @@ namespace c10::cuda::CUDACachingAllocator { // Preserved only for BC reasons // NOLINTNEXTLINE(misc-unused-using-decls) -using c10::CachingAllocator::kLargeBuffer; using c10::CachingDeviceAllocator::DeviceStats; +extern const size_t kLargeBuffer; + typedef std::shared_ptr (*CreateContextFn)(); // Struct containing info of an allocation block (i.e. a fractional part of a From 6ea91f067256447cda6fae533f806c1f8baafbe2 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 14 Jul 2025 16:58:13 +0000 Subject: [PATCH 007/457] Revert "[Inductor] Set the default value of min_chunk_size to 512 (#150762)" This reverts commit 3321acc92e24859dbe2ac6499067d1afde5622c3. Reverted https://github.com/pytorch/pytorch/pull/150762 on behalf of https://github.com/huydhn due to Sorry for reverting your change, but an inductor compilation error shows up in trunk ([comment](https://github.com/pytorch/pytorch/pull/150762#issuecomment-3070286787)) --- .../_inductor/codegen/cpp_template_kernel.py | 23 +++---------------- torch/_inductor/config.py | 2 +- 2 files changed, 4 insertions(+), 21 deletions(-) diff --git a/torch/_inductor/codegen/cpp_template_kernel.py b/torch/_inductor/codegen/cpp_template_kernel.py index 184c0fe889af9..b7a830a501051 100644 --- a/torch/_inductor/codegen/cpp_template_kernel.py +++ b/torch/_inductor/codegen/cpp_template_kernel.py @@ -2,7 +2,6 @@ import itertools from collections.abc import Iterable from typing import Any, Callable, Optional, Union -from unittest.mock import patch import sympy from sympy.parsing.sympy_parser import parse_expr @@ -19,7 +18,7 @@ from ..utils import sympy_index_symbol, sympy_index_symbol_with_prefix from ..virtualized import V from .common import REMOVED -from .cpp import CppKernel, CppKernelProxy, KernelGroup, ParallelDepth +from .cpp import CppKernel, CppKernelProxy, KernelGroup from .cpp_utils import cexpr_index, DTYPE_TO_CPP, LocalBufferContext @@ -289,15 +288,7 @@ def fn(*args): var_sizes_list.append(var_sizes) cpp_kernel_proxy.codegen_loop_bodies(bodies, var_sizes_list) - - def max_parallel_depth(): - return ParallelDepth(parallel_depth=0, start_depth=0) - - # This loop is not parallelized since it is not the outermost loop. - with patch.object( - cpp_kernel_proxy.loop_nest, "max_parallel_depth", max_parallel_depth - ): - kernel_group.finalize_kernel(cpp_kernel_proxy, []) + kernel_group.finalize_kernel(cpp_kernel_proxy, []) return kernel_group.loops_code.getvalue() def store_grouped_gemm_pointwise_nodes( @@ -351,15 +342,7 @@ def fn(*args): var_sizes_list.append(var_sizes) cpp_kernel_proxy.codegen_loop_bodies(bodies, var_sizes_list) - - def max_parallel_depth(): - return ParallelDepth(parallel_depth=0, start_depth=0) - - # This loop is not parallelized since it is not the outermost loop. - with patch.object( - cpp_kernel_proxy.loop_nest, "max_parallel_depth", max_parallel_depth - ): - kernel_group.finalize_kernel(cpp_kernel_proxy, []) + kernel_group.finalize_kernel(cpp_kernel_proxy, []) return kernel_group.loops_code.getvalue() def store_output( diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 293c1b2333436..60e8b259368cf 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -1019,7 +1019,7 @@ class cpp: dynamic_threads = os.environ.get("TORCHINDUCTOR_CPP_DYNAMIC_THREADS", "0") == "1" simdlen: Optional[int] = None - min_chunk_size = int(os.environ.get("TORCHINDUCTOR_CPP_MIN_CHUNK_SIZE", "512")) + min_chunk_size = int(os.environ.get("TORCHINDUCTOR_CPP_MIN_CHUNK_SIZE", "4096")) cxx: tuple[Literal[None], str] = ( None, # download gcc12 from conda-forge if conda is installed From 9b0013c6bb98d7161e921d03be76c81bbc0eebef Mon Sep 17 00:00:00 2001 From: Catherine Lee Date: Mon, 14 Jul 2025 17:35:58 +0000 Subject: [PATCH 008/457] [CI] Update mobile build docker image (#158153) The docker image got removed and then the job started building its own -> takes a long time I don't know why it uses the asan image image Pull Request resolved: https://github.com/pytorch/pytorch/pull/158153 Approved by: https://github.com/Skylion007 --- .github/workflows/pull.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index 53a4f6357e5c2..59a7265173800 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -315,14 +315,14 @@ jobs: test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.test-matrix }} secrets: inherit - linux-jammy-py3-clang12-mobile-build: - name: linux-jammy-py3-clang12-mobile-build + linux-jammy-py3-clang18-mobile-build: + name: linux-jammy-py3-clang18-mobile-build uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-jammy-py3-clang12-mobile-build - docker-image-name: ci-image:pytorch-linux-jammy-py3-clang15-asan + docker-image-name: ci-image:pytorch-linux-jammy-py3-clang18-asan build-generates-artifacts: false test-matrix: | { include: [ From fb462cec8d8674ad547c55dbe90710bde1dc2019 Mon Sep 17 00:00:00 2001 From: James Wu Date: Sun, 13 Jul 2025 12:51:11 -0700 Subject: [PATCH 009/457] Normalize placeholder names in AOTAutogradCache (#157916) This PR adds a pass to sanitize_gm_for_cache which normalizes all placeholder names across input dynamo graphs to AOTAutogradCache. This is safe because nothing underneath AOTAutograd uses the node names on the original dynamo graph: AOTAutograd re-traces with its own nodes, and guards are in terms of original sources rather than placeholder names. Note that the dynamo output graphs traced by tlparse will not show this change because it's done before this sanitization step. The aot autograd outputs also will not change because AOTAutograd's own traced graphs don't use the original placeholders of the dynamo graph. Thus, this change is essentially a no-op from everyone's perspective except for cache key checks. Fixes #157792 Pull Request resolved: https://github.com/pytorch/pytorch/pull/157916 Approved by: https://github.com/zou3519 --- test/inductor/test_codecache.py | 9 +++ .../_aot_autograd/autograd_cache.py | 73 ++++++++++++++++--- torch/_functorch/config.py | 4 + 3 files changed, 75 insertions(+), 11 deletions(-) diff --git a/test/inductor/test_codecache.py b/test/inductor/test_codecache.py index 68626aafcc0d9..51af64153500d 100644 --- a/test/inductor/test_codecache.py +++ b/test/inductor/test_codecache.py @@ -1854,6 +1854,7 @@ def f(x): @config.patch({"fx_graph_cache": True}) @config.patch({"fx_graph_remote_cache": False}) @functorch_config.patch({"enable_autograd_cache": True}) + @functorch_config.patch({"autograd_cache_normalize_inputs": True}) def test_split_module(self): class Mod(torch.nn.Module): def forward(self, x, a0, a1, b0, b1, c0, c1): @@ -1900,6 +1901,14 @@ def t(): y = ca0(a0, x, a1) y = ca1(b0, y, b1) y = ca2(c0, y, c1) + self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0) + self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1) + self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 2) + # TODO: split_module causes ca1 and ca2 to have different type annotations + # for the parameter x, so we can only AOTAutogradCache cache hit once instead of twice + self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2) + self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1) + self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 2) expected = Mod()(*example_inputs) self.assertEqual(y, expected) diff --git a/torch/_functorch/_aot_autograd/autograd_cache.py b/torch/_functorch/_aot_autograd/autograd_cache.py index 954dc399f96be..7c06f22905b2e 100644 --- a/torch/_functorch/_aot_autograd/autograd_cache.py +++ b/torch/_functorch/_aot_autograd/autograd_cache.py @@ -384,6 +384,57 @@ def _reduce_tensor(self, tensor): return (_ident, (metadata,)) +@contextlib.contextmanager +def normalize_placeholder_names(gm: torch.fx.GraphModule): + """ + Context manager that normalizes the placeholder names in the graph module. + This is used while generating a cache key for AOTAutogradCache, so that two graphs + that are isomorphic when normalizing names can hit the same cache entry. + This is safe because nothing underneath AOTAutograd uses the node names on the + original dynamo graph: AOTAutograd re-traces with its own nodes, and guards are + in terms of original sources rather than placeholder names. + """ + # Standalone inductor: we're bypassing AOTAutogradCache anyway, so return the graph + # as-is + if not config.autograd_cache_normalize_inputs or not hasattr(gm, "graph"): + yield + return + + # Track all the old state of placeholders + old_placeholder_names = [] + old_used_names = copy(gm.graph._graph_namespace._used_names) + i = 0 + for n in gm.graph.find_nodes(op="placeholder", sort=True): + if n.type != torch.SymInt: + # _rename renames the node in the body of the function, + # but it doesn't change the raw name from node.target + # So we also set the raw_name of node.target to a new placeholder name + new_placeholder_name = f"p_{i}" + old_placeholder_names.append((n.name, n.target)) + n.target = new_placeholder_name + n._rename(new_placeholder_name) + i += 1 + gm.recompile() + try: + yield + finally: + # Used_names contains all our old placeholder names, + # so we clear it temporarily when we put them back + gm.graph._graph_namespace._used_names = set() + # Restore the placeholder names + i = 0 + for n in gm.graph.find_nodes(op="placeholder", sort=True): + if n.type != torch.SymInt: + (name, target) = old_placeholder_names[i] + n.target = target + n._rename(name) + i += 1 + assert i == len(old_placeholder_names) + # Now restore the old namespace's used names + gm.graph._graph_namespace._used_names = old_used_names + gm.recompile() + + def autograd_cache_key( gm: torch.fx.GraphModule, example_inputs, @@ -407,7 +458,6 @@ def autograd_cache_key( if triton.__version__ < "3.2.0": raise BypassAOTAutogradCache("AOTAutogradCache requires triton 3.2.0") - details = AOTAutogradCacheDetails(gm, example_inputs, config, fx_config) pickler = AOTAutogradCachePickler(gm) # The prefix distinguishes among the other kinds of objects we cache @@ -924,21 +974,22 @@ def sanitize_gm_for_cache(gm: torch.fx.GraphModule): and then put them back before returning. This way, we generate a cache key based off of a canonical graph without these fields, and also guarantee they aren't used to affect the cache's output. """ - IGNORED_FIELDS = ( - "meta", # metadata used by export - "compile_subgraph_reason", # Used by dynamo only for logging, no change in inductor/autograd behavior - "_param_name_to_source", # Encapsulated by aot_config.aot_autograd_arg_pos_to_source - "_backend_id", - ) + # Mapping from each field to a default value + IGNORED_FIELDS: dict[str, Any] = { + "meta": {}, # metadata used by export + "compile_subgraph_reason": None, # Used by dynamo only for logging, no change in inductor/autograd behavior + "_param_name_to_source": None, # Encapsulated by aot_config.aot_autograd_arg_pos_to_source + "_backend_id": None, + } saved_fields = {} - for field in IGNORED_FIELDS: + for field, default_value in IGNORED_FIELDS.items(): saved_fields[field] = getattr(gm, field, None) # Clear the field - setattr(gm, field, None) + setattr(gm, field, default_value) try: - yield + with normalize_placeholder_names(gm): + yield finally: - # Put the fields back after dispatch_and_compile is complete for field, value in saved_fields.items(): setattr(gm, field, value) diff --git a/torch/_functorch/config.py b/torch/_functorch/config.py index e8778f31889dc..2833a2b1631a1 100644 --- a/torch/_functorch/config.py +++ b/torch/_functorch/config.py @@ -61,6 +61,10 @@ # need to add env vars or make it configurable bundled_autograd_cache: bool = False +# Whether or not to normalize placeholder names in graphs +# from dynaom in AOTAutogradCache +autograd_cache_normalize_inputs = not is_fbcode() + def remote_autograd_cache_default() -> Optional[bool]: if os.environ.get("TORCHINDUCTOR_AUTOGRAD_REMOTE_CACHE") == "1": From 5b10b0a96f9abf8c2751db324f0773aa433ec783 Mon Sep 17 00:00:00 2001 From: Sudarshan Raghunathan Date: Mon, 14 Jul 2025 17:55:14 +0000 Subject: [PATCH 010/457] Slightly improve error message from repeat_interleave kernel (#157996) Summary: In many investigations relating to invalid feature values, the three-argument form of `repeat_interleave` currently prints the following message if there is an inconsistency between `sum(repeats)` and `output_size`: ``` Assertion `result_size == cumsum_ptr[size - 1]` failed. ``` This is a bit hard for model authors to understand so I made the error slightly more comprehensible. After the fix the stdout contains the actual values of these parameters: https://fburl.com/mlhub/cfyyhh3q ``` Invalid input! In `repeat_interleave`, the `output_size` argument (949487) must be the same as the sum of the elements in the `repeats` tensor (949687). ``` In many cases, this is potentially useful information since we know for example that the difference between the two values above (949687-949487=200) happens to be the lengths of one of the features. ## What are my concerns with this change? 1. Outputs from `__assert_fail` go to `stderr` whereas `printf` writes to `stdout`. This is not the usual debugging flow where all logs can be found in `stderr`. I could not find a way to redirect `printf` to stderr or `__assert_fail` to stdout 2. Two checks happen instead of one in the error path. I wanted to preserve the semantics of what happens inside `__assert_fail`. 3. I have not seen this pattern in other PyTorch kernels but `repeat_interleave` with three arguments seems special in other ways too. Test Plan: * Built an ephemeral package with my changes: https://www.internalfb.com/intern/servicelab/build/736441058/ * Verified that a job with these changes indeed prints out the expected message to stdout: https://fburl.com/mlhub/jgbqk8eg * I will export to GH and run CI/CD tests. Rollback Plan: steps: - manual.note: content: >- Just reverting this diff should be sufficient. Since this change is in CUDA kernels, I do not believe there is a way to change the error message via a JK. Reviewed By: mradmila Differential Revision: D77904753 Pull Request resolved: https://github.com/pytorch/pytorch/pull/157996 Approved by: https://github.com/ngimel, https://github.com/eqy --- aten/src/ATen/native/cuda/Repeat.cu | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/native/cuda/Repeat.cu b/aten/src/ATen/native/cuda/Repeat.cu index aa351baaf9c02..1e2364ae50913 100644 --- a/aten/src/ATen/native/cuda/Repeat.cu +++ b/aten/src/ATen/native/cuda/Repeat.cu @@ -17,7 +17,13 @@ __global__ static void compute_cuda_kernel( index_t* result_ptr, int64_t size, int64_t result_size) { - CUDA_KERNEL_ASSERT(result_size == cumsum_ptr[size - 1]); + if (C10_UNLIKELY((result_size != cumsum_ptr[size - 1]))) { + printf("%s:%d:%s: block: [%d,%d,%d], thread: [%d,%d,%d] " + "Invalid input! In `repeat_interleave`, the `output_size` argument (%ld) must be the same as the sum of the elements in the `repeats` tensor (%ld).\n", + __FILE__, __LINE__, __func__,blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, threadIdx.z, result_size, cumsum_ptr[size - 1 ]); + CUDA_KERNEL_ASSERT(result_size == cumsum_ptr[size - 1]) + } + int64_t idx = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x; int64_t stride = (blockDim.x * gridDim.x) / C10_WARP_SIZE; int warp_id = idx / C10_WARP_SIZE; From 5633283574c458bd6a3cbb6a0a890f0cb9c8b2b5 Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Mon, 14 Jul 2025 18:07:21 +0000 Subject: [PATCH 011/457] [reland][DTensor][FSDP2] necessary changes to FSDP and TP to unblock EP (#158204) This PR is identical to https://github.com/pytorch/pytorch/pull/157216, which got reverted because of removing an outdated import of `torch._dynamo` https://www.internalfb.com/diff/D78021229?transaction_fbid=1713683499308113 The issue has been fixed by @weifengpy by D78199546, so this PR should be good to re-land. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158204 Approved by: https://github.com/weifengpy --- .../test_2d_composability.py | 15 ----- .../fsdp/_fully_shard/_fsdp_param.py | 19 +++--- torch/distributed/tensor/parallel/_utils.py | 67 ------------------- torch/distributed/tensor/parallel/api.py | 2 - 4 files changed, 10 insertions(+), 93 deletions(-) delete mode 100644 torch/distributed/tensor/parallel/_utils.py diff --git a/test/distributed/_composable/test_composability/test_2d_composability.py b/test/distributed/_composable/test_composability/test_2d_composability.py index 5ad96979717c4..3ab0b6269b2da 100644 --- a/test/distributed/_composable/test_composability/test_2d_composability.py +++ b/test/distributed/_composable/test_composability/test_2d_composability.py @@ -554,21 +554,6 @@ def _compare_params(self, m1, m2): p2 = p2.redistribute(p2.device_mesh, [Replicate()]).to_local() self.assertTrue(torch.allclose(p1, p2), f"{p1} vs {p2}") - @with_comms - @skip_if_lt_x_gpu(4) - def test_raise_invalid_tp_composition(self): - with self.assertRaisesRegex( - RuntimeError, r"Found TP device_mesh on the \d dimension of its parent mesh" - ): - mesh_2d = init_device_mesh( - self.device_type, (2, self.world_size // 2), mesh_dim_names=("tp", "dp") - ) - parallelize_plan = { - "net1": ColwiseParallel(), - "net2": RowwiseParallel(), - } - parallelize_module(SimpleModel().cuda(), mesh_2d["tp"], parallelize_plan) - @with_comms @skip_if_lt_x_gpu(4) def test_2d_fsdp_state_enable_extension(self): diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_param.py b/torch/distributed/fsdp/_fully_shard/_fsdp_param.py index 855a706e6d304..7649c32ec1c0e 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_param.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_param.py @@ -292,21 +292,22 @@ def _init_sharded_param( dp_global_mesh is None or tp_global_mesh is None ): raise AssertionError( - "FSDP requires the DP and TP mesh to have the same parent mesh but got: \n" - f"DP's global mesh: {dp_global_mesh}\nTP's global mesh: {tp_global_mesh}" + "FSDP requires the DP and model parallel TP/EP mesh to have the same parent mesh but got: \n" + f"DP's global mesh: {dp_global_mesh}\nTP/EP's global mesh: {tp_global_mesh}" ) name_dims_error = "FSDP requires named DeviceMesh dims for ND parallelism" assert dp_mesh.mesh_dim_names is not None, name_dims_error assert tp_mesh.mesh_dim_names is not None, name_dims_error submesh_names = dp_mesh.mesh_dim_names + tp_mesh.mesh_dim_names self._spmd_mesh = dp_global_mesh[submesh_names] - if len(self._tp_spec.placements) != 1: + if len(self._tp_spec.placements) > 2: raise NotImplementedError( - f"FSDP only supports 1D TP, not {self._tp_spec.placements}" + f"FSDP only supports 1D TP/EP or 2D EP+TP, not {self._tp_spec.placements}" ) split_factor = self._tp_spec.num_shards_map[shard_dim] - assert 2 <= self._spmd_mesh.ndim <= 3, ( - f"_spmd_mesh.ndim can only be 2 or 3 but got {self._spmd_mesh.ndim}." + assert 2 <= self._spmd_mesh.ndim <= 4, ( + "_spmd_mesh.ndim can only be 2 (FSDP+TP/EP), 3 (FSDP+EP+TP, HSDP+TP/EP), " + f"or 4 (HSDP+EP+TP) but got {self._spmd_mesh.ndim}." ) self._spmd_placements: tuple[Placement, ...] dp_shard_tp_placement = ( @@ -315,11 +316,11 @@ def _init_sharded_param( if split_factor > 1 else fsdp_placement ), - self._tp_spec.placements[0], + *self._tp_spec.placements, ) - if self._spmd_mesh.ndim == 2: + if dp_mesh.ndim == 1: # FSDP self._spmd_placements = dp_shard_tp_placement - else: + else: # HSDP assert self.mesh_info.replicate_mesh_dim == 0 self._spmd_placements = (Replicate(),) + dp_shard_tp_placement self._sharding_spec = DTensorSpec( diff --git a/torch/distributed/tensor/parallel/_utils.py b/torch/distributed/tensor/parallel/_utils.py deleted file mode 100644 index 0a78872f57d8b..0000000000000 --- a/torch/distributed/tensor/parallel/_utils.py +++ /dev/null @@ -1,67 +0,0 @@ -# mypy: allow-untyped-defs -import warnings -from typing import Union - -from torch.distributed.device_mesh import _mesh_resources -from torch.distributed.tensor import DeviceMesh -from torch.distributed.tensor.placement_types import Placement - - -try: - from torch._dynamo.external_utils import is_compiling as is_torchdynamo_compiling -except Exception: - - def is_torchdynamo_compiling(): # type: ignore[misc] - return False - - -LayoutsType = Union[Placement, tuple[Placement, ...]] - - -def _deprecate_warnings(func_name: str, extra_msg: str) -> None: - """ - Inject common validation logics for `_prepare_input` funcs via this decorator. - - Include verifying that input needs to be either a :class:`Tensor` or :class:`DTensor` - and only 1D :class:`DeviceMesh` is passed in. - """ - # TODO: Will follow up with dynamo POC to make warnings.warn working with dynamo. - if not is_torchdynamo_compiling(): - warnings.warn( - f"{func_name} is deprecated and will be removed soon. {extra_msg}", - FutureWarning, - stacklevel=3, - ) - - -def _validate_tp_mesh_dim( - device_mesh: DeviceMesh, -) -> None: - """ - Check whether TP mesh dimension is valid or not. - - Args: - device_mesh (:class:`DeviceMesh`): - The `device_mesh` where we perform - Tensor Parallelism on. - - Return: - `True` if the mesh dimension - is valid, `False` otherwise. - """ - if device_mesh.ndim > 1: - raise ValueError( - f"Tensor Parallel only accepts a 1D DeviceMesh, but found {device_mesh.ndim}D!" - 'If you have a 2-D or N-D device_mesh, consider passing in device_mesh["tp"]' - ) - - root_mesh = _mesh_resources.get_root_mesh(device_mesh) - # if a root mesh is not the same as device_mesh, - # meaning the device_mesh is sliced out from the root mesh. - if root_mesh and root_mesh != device_mesh: - tp_mesh_dim_in_root = _mesh_resources.get_root_mesh_dim(device_mesh) - if tp_mesh_dim_in_root != root_mesh.ndim - 1: - raise RuntimeError( - f"Found TP device_mesh on the {tp_mesh_dim_in_root} dimension of its parent mesh.", - "Currently we only support intranode TP and TP needs to be the innermost dimension on its parent mesh.", - ) diff --git a/torch/distributed/tensor/parallel/api.py b/torch/distributed/tensor/parallel/api.py index 81c005000a855..2a3369a8edda0 100644 --- a/torch/distributed/tensor/parallel/api.py +++ b/torch/distributed/tensor/parallel/api.py @@ -6,7 +6,6 @@ import torch import torch.nn as nn from torch.distributed.device_mesh import _mesh_resources, DeviceMesh -from torch.distributed.tensor.parallel._utils import _validate_tp_mesh_dim from torch.distributed.tensor.parallel.style import ParallelStyle @@ -71,7 +70,6 @@ def parallelize_module( # type: ignore[return] torch._C._log_api_usage_once("torch.distributed.tensor.parallel.parallelize_module") device_mesh = device_mesh or _mesh_resources.get_current_mesh() - _validate_tp_mesh_dim(device_mesh) if parallelize_plan is None: warnings.warn( From f87d1179391d66854e3c6ca20717803cfa22f878 Mon Sep 17 00:00:00 2001 From: henrylhtsang Date: Sun, 13 Jul 2025 17:58:25 -0700 Subject: [PATCH 012/457] redo of [Inductor][Cutlass] verify cutlass has cache_file attribute before moving...resolves cutlass cute exception (#158206) trying to land https://github.com/pytorch/pytorch/pull/156672 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158206 Approved by: https://github.com/lessw2020, https://github.com/Skylion007 --- torch/_inductor/codegen/cuda/cutlass_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torch/_inductor/codegen/cuda/cutlass_utils.py b/torch/_inductor/codegen/cuda/cutlass_utils.py index e2251b42fc7e9..eb479e477ea20 100644 --- a/torch/_inductor/codegen/cuda/cutlass_utils.py +++ b/torch/_inductor/codegen/cuda/cutlass_utils.py @@ -45,7 +45,10 @@ def move_cutlass_compiled_cache() -> None: else: import cutlass as python_cutlass # type: ignore[import-not-found] # noqa: F401 - if not os.path.exists(python_cutlass.CACHE_FILE): + # Check if the CACHE_FILE attribute exists in python_cutlass and if the file exists + if not hasattr(python_cutlass, "CACHE_FILE") or not os.path.exists( + python_cutlass.CACHE_FILE + ): return try: From 725c3272848c408d0fa2cba4de76affe90f793b5 Mon Sep 17 00:00:00 2001 From: dolpm <34420038+dolpm@users.noreply.github.com> Date: Mon, 14 Jul 2025 19:12:41 +0000 Subject: [PATCH 013/457] [nativert] add memory overlap debug assertion (#157290) Summary: better safe than sorry. will throw if memory overlap detected when using planned tensors and debug mode is enabled -- this will make our planning unit tests more robust. Test Plan: ci Rollback Plan: Differential Revision: D77327841 Pull Request resolved: https://github.com/pytorch/pytorch/pull/157290 Approved by: https://github.com/SherlockNoMad, https://github.com/zhxchen17 --- test/cpp/nativert/CMakeLists.txt | 9 + test/cpp/nativert/test_alias_analyzer.cpp | 182 ++++++++++++++++++ torch/nativert/executor/ExecutionFrame.h | 7 +- torch/nativert/executor/Executor.cpp | 35 ++-- torch/nativert/executor/Executor.h | 18 +- torch/nativert/executor/GraphExecutorBase.cpp | 4 +- torch/nativert/executor/GraphExecutorBase.h | 2 +- .../executor/ParallelGraphExecutor.cpp | 26 +-- .../nativert/executor/ParallelGraphExecutor.h | 4 +- .../nativert/executor/SerialGraphExecutor.cpp | 8 +- .../executor/memory/AliasAnalyzer.cpp | 81 +++++++- .../nativert/executor/memory/AliasAnalyzer.h | 47 ++++- .../executor/memory/LayoutManager.cpp | 163 ++++++++++++++++ .../nativert/executor/memory/LayoutManager.h | 41 +++- .../executor/memory/LayoutPlanner.cpp | 13 +- .../nativert/executor/memory/LayoutPlanner.h | 16 +- .../kernels/AutoFunctionalizeKernel.cpp | 5 +- torch/nativert/kernels/KernelFactory.cpp | 15 +- torch/nativert/kernels/KernelFactory.h | 4 +- 19 files changed, 604 insertions(+), 76 deletions(-) create mode 100644 test/cpp/nativert/test_alias_analyzer.cpp diff --git a/test/cpp/nativert/CMakeLists.txt b/test/cpp/nativert/CMakeLists.txt index 10b750d8b39ad..b6e6cd20ced7e 100644 --- a/test/cpp/nativert/CMakeLists.txt +++ b/test/cpp/nativert/CMakeLists.txt @@ -24,6 +24,15 @@ set(NATIVERT_TEST_SRCS ${TORCH_ROOT}/torch/nativert/executor/memory/LayoutPlanner.cpp ${TORCH_ROOT}/torch/nativert/executor/memory/LayoutManager.cpp ${TORCH_ROOT}/torch/nativert/executor/memory/AliasAnalyzer.cpp + ${TORCH_ROOT}/torch/nativert/executor/Executor.cpp + ${TORCH_ROOT}/torch/nativert/kernels/KernelFactory.cpp + ${TORCH_ROOT}/torch/nativert/executor/ConstantFolder.cpp + ${TORCH_ROOT}/torch/nativert/executor/GraphExecutorBase.cpp + ${TORCH_ROOT}/torch/nativert/executor/SerialGraphExecutor.cpp + ${TORCH_ROOT}/torch/nativert/executor/ParallelGraphExecutor.cpp + ${TORCH_ROOT}/torch/nativert/kernels/AutoFunctionalizeKernel.cpp + ${TORCH_ROOT}/torch/nativert/kernels/CallTorchBindKernel.cpp + ${TORCH_ROOT}/torch/nativert/kernels/HigherOrderKernel.cpp ) add_executable(test_nativert diff --git a/test/cpp/nativert/test_alias_analyzer.cpp b/test/cpp/nativert/test_alias_analyzer.cpp new file mode 100644 index 0000000000000..afa469f58c8b2 --- /dev/null +++ b/test/cpp/nativert/test_alias_analyzer.cpp @@ -0,0 +1,182 @@ +#include + +#include + +#include +#include + +#include +#include + +using namespace ::testing; +using namespace torch::nativert; + +using AliasTestCase = std::tuple< + std::string /* value */, + AllocationLifetime, + bool /* is_alias */, + bool /* is_storage_associated_with_output */, + c10::FastSet /* source(s) */>; + +class AliasAnalyzerTests : public testing::Test { + void SetUp() override {} + + void TearDown() override { + test_cases.clear(); + model.clear(); + } + + public: + void setTestCases(std::vector cases) { + test_cases = std::move(cases); + } + + void setModel(std::string m) { + model = std::move(m); + } + + void run() { + EXPECT_FALSE(test_cases.empty()); + EXPECT_FALSE(model.empty()); + + ExecutorConfig cfg; + cfg.enableStaticCPUKernels = true; + + auto graph = stringToGraph(model); + auto kernels = KernelFactory().initializeNodeKernels( + *graph, nullptr, cfg, {}, nullptr); + auto kernelSchemas = Executor::getKernelSchemas(kernels.nodeKernels); + + AliasAnalyzer analyzer(*graph, kernelSchemas); + + for ( + auto& [value, lifetime, is_alias, is_storage_associated_with_output, srcs] : + test_cases) { + LOG(INFO) << fmt::format( + "running test: value={}, lifetime=({}, {}), is_alias={}, is_storage_associated_with_output={}, src={}", + value, + lifetime.start, + lifetime.end, + is_alias, + is_storage_associated_with_output, + srcs.empty() ? "{}" + : std::accumulate( + srcs.begin(), + srcs.end(), + std::string{}, + [](std::string cur, const std::string& src) { + cur.append(","); + cur.append(src); + return cur; + })); + auto* v = graph->getValue(value); + EXPECT_EQ(analyzer.lifetime(v), lifetime); + EXPECT_EQ(analyzer.is_alias(v), is_alias); + EXPECT_EQ( + analyzer.is_storage_associated_with_output(v), + is_storage_associated_with_output); + const auto* resolved_srcs = analyzer.get_sources_of_alias(v); + if (resolved_srcs /* ensure set equality between *resolved_srcs and srcs */) { + EXPECT_FALSE(srcs.empty()); + EXPECT_EQ(resolved_srcs->size(), srcs.size()); + for (const auto& resolved_src : *resolved_srcs) { + EXPECT_TRUE(srcs.erase(std::string(resolved_src->name())) == 1); + } + EXPECT_TRUE(srcs.empty()); + } else { + EXPECT_TRUE(srcs.empty()); + } + } + } + + private: + std::string model; + std::vector test_cases; +}; + +TEST_F(AliasAnalyzerTests, TestNoAlias) { + setModel(R"( + graph(%y0, %y1): + %out_t = torch.ops.aten.matmul.default(self=%y0, other=%y1) + %res = torch.ops.aten.clone.default(self=%out_t, memory_format=None) + return (%res))"); + + setTestCases({ + {"out_t", AllocationLifetime(1, 2), false, false, {}}, + {"res", AllocationLifetime(2, 3), false, true, {}}, + }); + + run(); +} + +TEST_F(AliasAnalyzerTests, TestSimpleAlias) { + setModel(R"( + graph(%y0, %y1): + %out_t = torch.ops.aten.matmul.default(self=%y0, other=%y1) + %res = torch.ops.aten.slice.Tensor(self=%out_t, dim=1, start=0, end=0, step=1) + return (%res))"); + + setTestCases({ + {"out_t", AllocationLifetime(1, 3), false, true, {}}, + {"res", AllocationLifetime(2, 3), true, false, {"out_t"}}, + }); + + run(); +} + +TEST_F(AliasAnalyzerTests, TestDeepAlias) { + setModel(R"( + graph(%y0, %y1): + %out_t = torch.ops.aten.matmul.default(self=%y0, other=%y1) + %a1 = torch.ops.aten.slice.Tensor(self=%out_t, dim=1, start=0, end=0, step=1) + %res = torch.ops.aten.slice.Tensor(self=%a1, dim=1, start=0, end=0, step=1) + return (%res))"); + + setTestCases({ + {"out_t", AllocationLifetime(1, 4), false, true, {}}, + {"a1", AllocationLifetime(2, 4), true, false, {"out_t"}}, + {"res", AllocationLifetime(3, 4), true, false, {"out_t"}}, + }); + + run(); +} + +TEST_F(AliasAnalyzerTests, TestPackedListUnpack) { + setModel(R"( + graph(%a, %b, %c, %d): + %input_list[] = prim.ListPack(l0=%a, l1=%b, l2=%c, l3=%d) + %x0, %x1, %x2, %x3 = prim.ListUnpack(input=%input_list) + return (%x1, %x3))"); + + setTestCases({ + {"a", AllocationLifetime(0, 2), false, false, {}}, + {"x0", AllocationLifetime(2, 2), true, false, {"a"}}, + {"b", AllocationLifetime(0, 3), false, true, {}}, + {"x1", AllocationLifetime(2, 3), true, false, {"b"}}, + {"c", AllocationLifetime(0, 2), false, false, {}}, + {"x2", AllocationLifetime(2, 2), true, false, {"c"}}, + {"d", AllocationLifetime(0, 3), false, true, {}}, + {"x3", AllocationLifetime(2, 3), true, false, {"d"}}, + }); + + run(); +} + +TEST_F(AliasAnalyzerTests, TestAmbiguousSourceOfAlias) { + setModel(R"( + graph(%y0, %y1): + %out_t = torch.ops.aten.matmul.default(self=%y0, other=%y1) + %out_t2 = torch.ops.aten.matmul.default(self=%y0, other=%y1) + %a1 = prim.VarStack(l0=%out_t, l1=%out_t2) + %res = torch.ops.aten.slice.Tensor(self=%a1, dim=1, start=0, end=0, step=1) + return (%res))"); + + setTestCases({ + {"out_t", AllocationLifetime(1, 5), false, true, {}}, + {"out_t2", AllocationLifetime(2, 5), false, true, {}}, + {"a1", AllocationLifetime(3, 5), true, false, {"out_t", "out_t2"}}, + {"res", AllocationLifetime(4, 5), true, false, {"out_t", "out_t2"}}, + }); + + run(); +} diff --git a/torch/nativert/executor/ExecutionFrame.h b/torch/nativert/executor/ExecutionFrame.h index ae8821a6e58b0..945c3b0c5036d 100644 --- a/torch/nativert/executor/ExecutionFrame.h +++ b/torch/nativert/executor/ExecutionFrame.h @@ -46,13 +46,14 @@ class ExecutionFrame { } template - auto withMemoryPlanner(CB&& cb) { + auto withManagedMemory(CB&& cb) { if (!layoutManager_) { - return std::forward(cb)(); + return std::forward(cb)(nullptr); } LayoutManagerGuard guard(*layoutManager_); - return std::forward(cb)(); + return std::forward(cb)( + const_cast(layoutManager_.get())); } std::vector tryMoveUserOutputs(); diff --git a/torch/nativert/executor/Executor.cpp b/torch/nativert/executor/Executor.cpp index 285b6dea00dd7..eb25342f65df2 100644 --- a/torch/nativert/executor/Executor.cpp +++ b/torch/nativert/executor/Executor.cpp @@ -19,30 +19,31 @@ namespace torch::nativert { Executor::Executor( torch::nativert::ExecutorConfig executorConfig, std::shared_ptr graph, - std::shared_ptr weights, - const Placement& placement, - std::shared_ptr pytorchStreamReader, - const MakeProxyExecutorFn& makeProxyExecutorFunc) + const std::shared_ptr& weights, + Placement placement, + const std::shared_ptr& + pytorchStreamReader, + MakeProxyExecutorFn makeProxyExecutorFunc) : executorConfig_(std::move(executorConfig)), graph_(std::move(graph)), - placement_(placement), + placement_(std::move(placement)), constantFolder_( executorConfig_.runConstFolding ? std::optional(*graph_) : std::nullopt), - makeProxyExecutorFunc_(makeProxyExecutorFunc), + makeProxyExecutorFunc_(std::move(makeProxyExecutorFunc)), executionFrames_(executorConfig_.maxNumConcurrentThreads), clearedExecutionFrames_(executorConfig_.maxNumConcurrentThreads), numExecutionFrames_(0), lastClearedTimestamp_(getCurrentTimestampSeconds()) { if (weights) { - initialize(std::move(weights), std::move(pytorchStreamReader)); + initialize(weights, pytorchStreamReader); } } void Executor::initialize( - std::shared_ptr weights, - std::shared_ptr + const std::shared_ptr& weights, + const std::shared_ptr& pytorchStreamReader) { auto start = std::chrono::high_resolution_clock::now(); @@ -51,7 +52,7 @@ void Executor::initialize( weights, executorConfig_, placement_, - std::move(pytorchStreamReader), + pytorchStreamReader, makeProxyExecutorFunc_); if (constantFolder_.has_value()) { @@ -113,13 +114,14 @@ void Executor::atomicSwapWeights(std::shared_ptr weights) { } } -void Executor::maybeRunConstantFolding(std::shared_ptr weights) { +void Executor::maybeRunConstantFolding( + const std::shared_ptr& weights) { for (auto& execution : constFoldingExecutions_) { ExecutionFrame constFoldingFrame(execution.executor->graph()); std::vector inputs; inputs.reserve(graph_->signature().inputsToWeights().size()); for (const auto& [_, name] : graph_->signature().inputsToWeights()) { - inputs.push_back(weights->at(name)); + inputs.emplace_back(weights->at(name)); } auto outputs = execution.executor->execute(constFoldingFrame, inputs); @@ -130,7 +132,7 @@ void Executor::maybeRunConstantFolding(std::shared_ptr weights) { } } -void Executor::processWeights(std::shared_ptr weights) { +void Executor::processWeights(const std::shared_ptr& weights) { maybeRunConstantFolding(weights); if (constantFolder_.has_value()) { constantFolder_->evaluate(*weights); @@ -352,10 +354,10 @@ std::vector Executor::execute( } ProfileMetrics Executor::benchmarkIndividualNodes( - std::vector> inputsList, + const std::vector>& inputsList, const uint32_t warmupRuns, const uint32_t mainRuns) { - CHECK(inputsList.size() > 0) << "Need at least one input to benchmark"; + CHECK(!inputsList.empty()) << "Need at least one input to benchmark"; CHECK(warmupRuns >= 1 && mainRuns >= 1) << "Need at least one run"; for (const auto& inputs : inputsList) { @@ -378,8 +380,9 @@ int64_t Executor::getCurrentTimestampSeconds() const { std::vector Executor::getDelegates() { std::vector delegates; + delegates.reserve(delegateExecutors_.size()); for (const auto& delegateExecutor : delegateExecutors_) { - delegates.push_back(delegateExecutor.get()); + delegates.emplace_back(delegateExecutor.get()); } return delegates; } diff --git a/torch/nativert/executor/Executor.h b/torch/nativert/executor/Executor.h index db496ace926ee..3ab206b01e0c1 100644 --- a/torch/nativert/executor/Executor.h +++ b/torch/nativert/executor/Executor.h @@ -79,11 +79,11 @@ class Executor { Executor( torch::nativert::ExecutorConfig executorConfig, std::shared_ptr graph, - std::shared_ptr weights, - const Placement& placement = Placement(), - std::shared_ptr + const std::shared_ptr& weights, + Placement placement = Placement(), + const std::shared_ptr& pytorchStreamReader = nullptr, - const MakeProxyExecutorFn& makeProxyExecutorFunc = nullptr); + MakeProxyExecutorFn makeProxyExecutorFunc = nullptr); std::shared_ptr getWeights() { std::shared_ptr ret; @@ -91,7 +91,7 @@ class Executor { return ret; } - void processWeights(std::shared_ptr weights); + void processWeights(const std::shared_ptr& weights); void atomicSwapWeights(std::shared_ptr weights); // This API only returns the flattened UserOutputs, @@ -106,7 +106,7 @@ class Executor { const ITreeSpec& inputTreeSpec); ProfileMetrics benchmarkIndividualNodes( - std::vector> inputsList, + const std::vector>& inputsList, const uint32_t warmupRuns, const uint32_t mainRuns); @@ -141,8 +141,8 @@ class Executor { c10::Synchronized> weights_; void initialize( - std::shared_ptr weights, - std::shared_ptr + const std::shared_ptr& weights, + const std::shared_ptr& pytorchStreamReader); ExecutorFramePtr getExecutorFrameFromPool(); @@ -171,7 +171,7 @@ class Executor { ExecutionFrameEntry& operator=(const ExecutionFrameEntry&) = delete; }; - void maybeRunConstantFolding(std::shared_ptr weights); + void maybeRunConstantFolding(const std::shared_ptr& weights); void validateInputs(const std::vector& inputs) const; // Helper method to get current timestamp in seconds diff --git a/torch/nativert/executor/GraphExecutorBase.cpp b/torch/nativert/executor/GraphExecutorBase.cpp index 5ad31a7dacabe..7796575aad291 100644 --- a/torch/nativert/executor/GraphExecutorBase.cpp +++ b/torch/nativert/executor/GraphExecutorBase.cpp @@ -32,7 +32,7 @@ void GraphExecutorBase::fillUserInputs( ProfileMetrics GraphExecutorBase::benchmarkIndividualNodes( ExecutionFrame& executionFrame, - std::vector> inputsList, + const std::vector>& inputsList, const uint32_t warmupRuns, const uint32_t mainRuns) { // TODO: add support for memory profiling @@ -112,7 +112,7 @@ ProfileMetrics GraphExecutorBase::benchmarkIndividualNodes( results.totalNodesCount = numNodes; for (const auto& r : results.timePerNodeType) { const std::string& target = r.first; - results.percentPerNodeType[target] = r.second * 100.0 / results.totalTime; + results.percentPerNodeType[target] = r.second * 100.0f / results.totalTime; } return results; } diff --git a/torch/nativert/executor/GraphExecutorBase.h b/torch/nativert/executor/GraphExecutorBase.h index 86c6ed61c1f9a..8d659f1588c2b 100644 --- a/torch/nativert/executor/GraphExecutorBase.h +++ b/torch/nativert/executor/GraphExecutorBase.h @@ -51,7 +51,7 @@ class GraphExecutorBase { ProfileMetrics benchmarkIndividualNodes( ExecutionFrame& executionFrame, - std::vector> inputs, + const std::vector>& inputs, const uint32_t warmup_runs, const uint32_t main_runs); diff --git a/torch/nativert/executor/ParallelGraphExecutor.cpp b/torch/nativert/executor/ParallelGraphExecutor.cpp index c147d23873d3d..b54b22228f977 100644 --- a/torch/nativert/executor/ParallelGraphExecutor.cpp +++ b/torch/nativert/executor/ParallelGraphExecutor.cpp @@ -22,11 +22,13 @@ ThreadPoolExecutor::~ThreadPoolExecutor() { } C10_ALWAYS_INLINE moodycamel::ProducerToken& ThreadPoolExecutor::ptok() { + // NOLINTNEXTLINE(misc-use-internal-linkage) thread_local moodycamel::ProducerToken ptok(*work_); return ptok; } C10_ALWAYS_INLINE moodycamel::ConsumerToken& ThreadPoolExecutor::ctok() { + // NOLINTNEXTLINE(misc-use-internal-linkage) thread_local moodycamel::ConsumerToken ctok(*work_); return ctok; } @@ -39,7 +41,7 @@ void ThreadPoolExecutor::execute_inline(SessionState* session, WorkUnit* unit) { void ThreadPoolExecutor::start(int32_t numThreads) { stopped_ = false; for (int32_t i = 0; i < numThreads; ++i) { - threads_.emplace_back(std::thread(&ThreadPoolExecutor::loop, this)); + threads_.emplace_back(&ThreadPoolExecutor::loop, this); } } @@ -62,16 +64,17 @@ void ThreadPoolExecutor::loop() { void ThreadPoolExecutor::add(SessionState* session, WorkUnit* unit) { session->addWork(); - work_->enqueue(ptok(), std::bind(&WorkUnit::run, unit, this, session)); + work_->enqueue(ptok(), [unit, this, session] { unit->run(this, session); }); sem_->release(); } void ThreadPoolExecutor::add( SessionState* session, - std::vector::const_iterator&& begin, - const std::vector::const_iterator&& end) { + std::vector::const_iterator begin, + const std::vector::const_iterator& end) { const auto count = end - begin; + // NOLINTNEXTLINE(bugprone-switch-missing-default-case) switch (count) { case 0: { return; @@ -86,16 +89,17 @@ void ThreadPoolExecutor::add( std::vector runnables; runnables.reserve(count); for (; begin != end; ++begin) { - runnables.push_back(std::bind(&WorkUnit::run, *begin, this, session)); + runnables.emplace_back( + [capture0 = *begin, this, session] { capture0->run(this, session); }); } work_->enqueue_bulk(ptok(), runnables.begin(), count); - sem_->release(count); + sem_->release(static_cast(count)); } void ThreadPoolExecutor::stop() { stopped_ = true; - sem_->release(threads_.size()); + sem_->release(static_cast(threads_.size())); std::for_each(threads_.begin(), threads_.end(), [](auto& t) { t.join(); }); threads_.clear(); @@ -136,10 +140,10 @@ void ThreadPoolExecutor::run( } void WorkUnit::run(ThreadPoolExecutor* executor, SessionState* session) { - thread_local std::vector newWorkUnits; - thread_local c10::InferenceMode mode; + /* thread_local */ std::vector newWorkUnits; + /* thread_local */ c10::InferenceMode mode; - WorkUnit* unit = this; + /* thread_local */ WorkUnit* unit = this; while (true) { unit->kernel->compute(session->frame()); @@ -219,7 +223,7 @@ ParallelGraphExecutor::ParallelGraphExecutor( } } - executor_.start(executorConfig.maxParallelOps); + executor_.start(static_cast(executorConfig.maxParallelOps)); } std::vector ParallelGraphExecutor::execute( diff --git a/torch/nativert/executor/ParallelGraphExecutor.h b/torch/nativert/executor/ParallelGraphExecutor.h index 747e6993770ac..1810ffb3b7b14 100644 --- a/torch/nativert/executor/ParallelGraphExecutor.h +++ b/torch/nativert/executor/ParallelGraphExecutor.h @@ -46,8 +46,8 @@ class ThreadPoolExecutor { void add(SessionState* session, WorkUnit* unit); void add( SessionState* session, - std::vector::const_iterator&& begin, - const std::vector::const_iterator&& end); + std::vector::const_iterator begin, + const std::vector::const_iterator& end); C10_ALWAYS_INLINE moodycamel::ProducerToken& ptok(); C10_ALWAYS_INLINE moodycamel::ConsumerToken& ctok(); diff --git a/torch/nativert/executor/SerialGraphExecutor.cpp b/torch/nativert/executor/SerialGraphExecutor.cpp index 017f4f178c8b5..58a7cd1c4307c 100644 --- a/torch/nativert/executor/SerialGraphExecutor.cpp +++ b/torch/nativert/executor/SerialGraphExecutor.cpp @@ -14,11 +14,17 @@ std::vector SerialGraphExecutor::execute( std::vector SerialGraphExecutor::executeWithPrefilledFrame( ExecutionFrame& executionFrame) { - executionFrame.withMemoryPlanner([&]() { + executionFrame.withManagedMemory([&](const LayoutManager* layout_manager) { // Execute kernels for all nodes except prim.Input and prim.Output for (NodeIndex nodeIdx = 1; nodeIdx < nodeKernels_.size() - 1; ++nodeIdx) { nodeKernels_[nodeIdx]->compute(executionFrame); +#ifndef NDEBUG + if (layout_manager != nullptr) { + layout_manager->assert_no_overlapping_storages(nodeIdx); + } +#endif + // don't free intermediate values when static memory planning is enabled if (executorConfig_.tryFreeUnmanagedValuesAfterUse) { // Free the intermediate values that are no used anymore diff --git a/torch/nativert/executor/memory/AliasAnalyzer.cpp b/torch/nativert/executor/memory/AliasAnalyzer.cpp index e56eb40853169..0bef32545d14b 100644 --- a/torch/nativert/executor/memory/AliasAnalyzer.cpp +++ b/torch/nativert/executor/memory/AliasAnalyzer.cpp @@ -23,18 +23,32 @@ AliasAnalyzer::AliasAnalyzer( maybe_update_aliases_from_schema(node, schemas); } + maybe_extend_lifetimes(graph); + + // squash_deep_aliases this will populate aliases_ + // with a mapping from each alias to its backed + // source (i.e., the value that owns the underlying + // dataptr for said alias) + squash_deep_aliases(graph); + // set all non-aliasing outputs. outputs // that are aliased will be set later when // lifetimes are extended for (const auto* output : graph.outputs()) { if (!is_alias(output)) { - values_associated_with_outputs_.insert(output); + values_associated_with_outputs_.emplace(output); } } - maybe_extend_lifetimes(graph); log_state(); -} + + alive_values_at_time_.resize(graph.nodes().size()); + for (const auto& [v, lifetime] : lifetimes_) { + for (const auto t : c10::irange(lifetime.start, lifetime.end + 1)) { + alive_values_at_time_[t].emplace_back(v); + } + } +} // namespace torch::nativert bool /* applied */ AliasAnalyzer::update_aliases_if_packed_listunpack( const Node& node, @@ -63,7 +77,7 @@ bool /* applied */ AliasAnalyzer::update_aliases_if_packed_listunpack( create_or_update_lifetime(input, i); create_or_update_lifetime(output, i); - aliases_[output].insert(input); + aliases_[output].emplace(input); } return true; @@ -96,7 +110,7 @@ void AliasAnalyzer::maybe_update_aliases_from_schema( VLOG(1) << node.target() << " may contain input/output alias: " << input->id() << " -> " << output->id(); - aliases_[output].insert(input); + aliases_[output].emplace(input); } } } @@ -109,6 +123,56 @@ void AliasAnalyzer::create_or_update_lifetime(const Value* value, size_t i) { } } +void AliasAnalyzer::squash_deep_aliases(const Graph& graph) { + for (auto& node : graph.nodes()) { + for (const auto& output : node.outputs()) { + auto aliasIt = aliases_.find(output); + if (aliasIt == aliases_.end()) { + continue; + } + + c10::FastSet filtered_srcs; + + auto& srcs = aliasIt->second; + for (const auto* src : srcs) { + // check if this source is an alias itself, + // making 'output' a deep alias (i.e., + // an alias of an alias) + + // we want aliases_[x] to return the value from which x + // inherits its dataptr. + // as such, we want to add values that do not meet this + // criteria (i.e., those that are aliases). + // in practice, there can only be 1 value that meets this + // criteria (at a time), but there are some cases where + // this is ambiguous (e.g., where the spec doesn't exist, + // dealing with variadics) + auto srcAliasIt = aliases_.find(src); + if (srcAliasIt == aliases_.end()) { + filtered_srcs.emplace(src); + continue; + } + + // since we are going from the beginning of the graph + // to the end of the graph we can assume that these + // aliases, which have already been visited, have already + // been squashed. + auto& srcs_of_src = srcAliasIt->second; + for (const auto* src_of_src : srcs_of_src) { + // if the source of the source is not an alias + // (i.e., it has ownership over it's data ptr) + // then we want to add it as a source of 'output' + if (aliases_.find(src_of_src) == aliases_.end()) { + filtered_srcs.emplace(src_of_src); + } + } + } + + srcs = std::move(filtered_srcs); + } + } +} + void AliasAnalyzer::maybe_extend_lifetimes(const Graph& graph) { c10::FastSet extended; @@ -129,10 +193,11 @@ void AliasAnalyzer::maybe_extend_lifetimes(const Graph& graph) { VLOG(1) << "extended EOL of value " << src->id() << " to " << eol; - extended.insert(src); + extended.emplace(src); - if (eol == graph.nodes().size() - 1 /* aliases output */) { - values_associated_with_outputs_.insert(src); + if (aliases_.find(src) == aliases_.end() && + eol == graph.nodes().size() - 1 /* aliases output */) { + values_associated_with_outputs_.emplace(src); } } } diff --git a/torch/nativert/executor/memory/AliasAnalyzer.h b/torch/nativert/executor/memory/AliasAnalyzer.h index c9784d5d84ab9..4fd3b1261b3d7 100644 --- a/torch/nativert/executor/memory/AliasAnalyzer.h +++ b/torch/nativert/executor/memory/AliasAnalyzer.h @@ -14,26 +14,38 @@ class AliasAnalyzer { const Graph& graph, const c10::FastMap& schemas); - C10_ALWAYS_INLINE const AllocationLifetime& lifetime( + const c10::FastSet* get_sources_of_alias( const Value* value) const { + const auto it = aliases_.find(value); + if (it == aliases_.end()) { + return nullptr; + } + return &it->second; + } + + const AllocationLifetime& lifetime(const Value* value) const { return lifetimes_.at(value); } - C10_ALWAYS_INLINE bool is_alias(const Value* value) const { + bool is_alias(const Value* value) const { return aliases_.find(value) != aliases_.end(); } - C10_ALWAYS_INLINE bool is_storage_associated_with_output( - const Value* value) const { + bool is_storage_associated_with_output(const Value* value) const { return values_associated_with_outputs_.find(value) != values_associated_with_outputs_.end(); } - C10_ALWAYS_INLINE const c10::FastSet& - values_associated_with_output_storage() const { + const c10::FastSet& values_associated_with_output_storage() + const { return values_associated_with_outputs_; } + const std::vector& alive_values_at_time(size_t time) const { + TORCH_CHECK_LT(time, alive_values_at_time_.size()); + return alive_values_at_time_[time]; + } + private: // listunpack operations who take a list that has // been created with a listpack operation should @@ -72,14 +84,35 @@ class AliasAnalyzer { // even if they aren't explicitly considered outputs) void maybe_extend_lifetimes(const Graph& graph); + // in the event that we have aliases-of-aliases + // we want to make sure that the 'sources' + // are propagated + // + // e.g., + // %x0 = ... + // %x1 = some_aliasing_op(x0) + // %x2 = some_aliasing_op(x1) + // + // we want aliases_[x2] = x0 + // instead of aliases[x2] = x1 + // + // the result is aliases_ will contain a + // mapping from each alias to its backed + // source (i.e., the value that owns its + // associated dataptr) + void squash_deep_aliases(const Graph& graph); + void log_state() const; - // mapping from alias to the set of values that it aliases + // mapping from alias to its source c10::FastMap> aliases_; c10::FastMap lifetimes_; // non-aliasing outputs or non-aliasing intermediates that are aliased by // outputs c10::FastSet values_associated_with_outputs_; + // alive_values_at_time_[i] = values that are "alive" during the + // computation of node i + std::vector> alive_values_at_time_; }; } // namespace torch::nativert diff --git a/torch/nativert/executor/memory/LayoutManager.cpp b/torch/nativert/executor/memory/LayoutManager.cpp index 7b5062d7993ff..acfe990c38bd1 100644 --- a/torch/nativert/executor/memory/LayoutManager.cpp +++ b/torch/nativert/executor/memory/LayoutManager.cpp @@ -4,6 +4,7 @@ #include #include +#include namespace torch::nativert { @@ -147,6 +148,9 @@ void LayoutManager::populate_tensor_values() { planned_tensors_max_nbytes_local_.resize(value_ids.size()); for (const auto&& [i, v] : c10::enumerate(value_ids)) { +#ifndef NDEBUG + value_to_vector_idx_map_[v] = i; +#endif planned_tensors_[i] = &parent_frame_.getIValue(v).toTensor(); } @@ -157,6 +161,165 @@ void LayoutManager::populate_tensor_values() { } } +#ifndef NDEBUG +void LayoutManager::assert_no_overlapping_storages( + size_t graph_node_idx) const { + if (state_ != LayoutManagerState::Running) { + return; + } + + /* + for each value + (either an input or output) + ensure that the associated storage + slice lies within the allocated slice + if it is managed (or if it is an alias, + we can use the slice allocated to its source) + --- + also ensure that the current index lies + within the lifetime of this value + */ + + const auto& alias_analyzer = planner_.get_alias_analyzer(); + // get the 'active' values during the execution of nodes[graph_node_idx] + const auto& alive_values = + alias_analyzer.alive_values_at_time(graph_node_idx); + + // make sure active memory intervals are non-overlapping + // by sorting them by start, and ensuring + // cur.start > prev.end for each + // + // by default, the pairs are compared lexicographically. + // ref: https://cplusplus.com/reference/utility/pair/operators/ + // + // in our case, this means that leftmost (on the number line) intervals will + // come first, and if the start point of two intervals is the same, they will + // be sorted by their relative widths (in increasing order) + // + // e.g., the ordering for the following usage intervals + // + // |######1######| + // |######2######| + // |######3#####| + // + // would be [1,3,2] + + std::multiset> intervals; + + planner_.with_plan([&](const LayoutPlan& plan) { + // prevent recomputation from occuring + c10::FastSet checked_values; + + // check that some arbitrary storage (defined by the allocation start and + // the size in bytes) lies within the slice allocated for value_id during + // planning. + // + // if the checks pass, add the interval [alloc_start, alloc_start + + // alloc_nbytes) to the set of intervals + auto check_allocation_bounds = + [&](ValueId value_id, size_t alloc_start, size_t alloc_end) -> void { + if (!checked_values.emplace(value_id).second /* already checked */) { + return; + } + auto& alloc = plan.allocations[value_to_vector_idx_map_.at(value_id)]; + TORCH_CHECK_GE(alloc_start, alloc.offset); + TORCH_CHECK_LT(alloc_end, alloc.offset + alloc.size); + intervals.emplace(alloc_start, alloc_end); + }; + + // get the inclusive storage interval for some value (i.e., + // [buffer_storage_start_offset, buffer_storage_start_offset + + // storage_nbytes]) that represents the sub-slice of the runtime-managed + // buffer allocated to this tensor + auto try_get_interval = + [&](ValueId value_id) -> std::optional> { + const auto& iv = parent_frame_.getIValue(value_id); + if (!iv.isTensor()) { + return std::nullopt; + } + + const auto& storage_impl = iv.toTensor().storage().unsafeGetStorageImpl(); + const auto storage_nbytes = storage_impl->nbytes(); + + if (const auto start = layout_buffer_.get_offset_from_ptr( + storage_impl->data_ptr().get()); + start.has_value()) { + return std::make_pair(*start, *start + storage_nbytes - 1); + } + + return std::nullopt; + }; + + for (auto v : alive_values) { + // sanity check lifetimes to ensure this + // value ~should~ be alive at this point + const auto& lt = alias_analyzer.lifetime(v); + TORCH_CHECK_GE(graph_node_idx, lt.start); + TORCH_CHECK_LE(graph_node_idx, lt.end); + + const auto interval = try_get_interval(v->id()); + if (C10_UNLIKELY(!interval.has_value())) { + continue; + } + + auto& [v_start, v_end] = *interval; + + // it's possible that v is an alias, in which case + // we want to try to get the source (i.e., the value) + // that actually owns the storage + // + // NOTE: it's possible the source is ambiguous, hence + // why get_sources_of_alias returns a set (although it's usually a + // singleton set) + if (const auto* srcs_of_v = alias_analyzer.get_sources_of_alias(v); + srcs_of_v != nullptr /* v is an alias */) { + // 1. v's interval is a sub-interval of ~a~ source's interval and we + // want to add the source's interval to the set of intervals + // 2. v possibly got re-alloc'd / is not actually aliasing anything + // and we want to add v's interval to the set of intervals + bool found_viable_source = false; + + for (const auto* src_of_v : *srcs_of_v) { + const auto src_interval = try_get_interval(src_of_v->id()); + if (C10_UNLIKELY(!src_interval.has_value())) { + continue; + } + + auto& [src_of_v_start, src_of_v_end] = *src_interval; + + if (v_start >= src_of_v_start && v_end <= src_of_v_end) { + check_allocation_bounds( + src_of_v->id(), src_of_v_start, src_of_v_end); + found_viable_source = true; + break; + } + } + + if (!found_viable_source) { + check_allocation_bounds(v->id(), v_start, v_end); + } + } else /* if v isn't an alias */ { + check_allocation_bounds(v->id(), v_start, v_end); + } + } + }); + + // if we only have less than two active intervals, + // it isn't possible to have overlap... + if (intervals.size() < 2) { + return; + } + + // ensure that no 'active' buffer intervals are overlapping + auto it = intervals.begin(); + size_t prev_end = it->second; + while (++it != intervals.end()) { + TORCH_CHECK_LT(prev_end, it->first /* cur_start */); + prev_end = it->second; + } +} +#endif + void LayoutManager::try_update_historical_max_nbytes() { for (const auto i : c10::irange(planned_tensors_.size())) { auto nbytes = get_aligned_nbytes(planned_tensors_[i]->nbytes()); diff --git a/torch/nativert/executor/memory/LayoutManager.h b/torch/nativert/executor/memory/LayoutManager.h index 76f658e09d08b..aa1b1c4630cc0 100644 --- a/torch/nativert/executor/memory/LayoutManager.h +++ b/torch/nativert/executor/memory/LayoutManager.h @@ -24,6 +24,20 @@ struct ContiguousLayoutBuffer { ContiguousLayoutBuffer& operator=(const ContiguousLayoutBuffer& other) = delete; + std::optional get_offset_from_ptr(void* offset_ptr) const { + void* raw_ptr = data_ptr_.get(); + if (!raw_ptr || !offset_ptr) { + return std::nullopt; + } + + auto offset = reinterpret_cast(offset_ptr) - + reinterpret_cast(raw_ptr); + + return offset < 0 || static_cast(offset) >= size_ + ? std::nullopt + : std::optional(offset); + } + void* get_ptr_with_offset(size_t offset) { void* raw_ptr = data_ptr_.get(); TORCH_CHECK_NOTNULL(raw_ptr); @@ -148,10 +162,32 @@ class LayoutManager { torch::nativert::LayoutManagerSettings settings = {}); ~LayoutManager() = default; +// this is a debugging function. it will slow thing down SIGNIFICANTLY +// so please ensure this isn't called unless you really need it +// +// it checks a few things in between node executions... +// +// 1. ensures all 'alive' values are within the bounds of thier lifetimes +// - this is the definition of a sanity check since the live-sets are built +// from the lifetimes lol. if this fails, something is very very wrong +// 2. ensures that all planned values are within the bounds of their +// allocated storage buffer slices +// - if the value is an alias, ensure the alias is within the bounds +// of the source value +// 3. ensures that all planned value data-ptrs are non-overlapping +#ifndef NDEBUG + void assert_no_overlapping_storages( + size_t + graph_node_idx /* the graph node that is currently being computed */) + const; +#endif + + private: + friend class LayoutManagerGuard; + void allocate(); void deallocate_and_plan(); - private: #ifdef LayoutPlannerTests_TEST_FRIENDS LayoutPlannerTests_TEST_FRIENDS; #endif @@ -178,6 +214,9 @@ class LayoutManager { std::vector planned_tensors_; std::vector planned_tensors_max_nbytes_local_; +#ifndef NDEBUG + c10::FastMap value_to_vector_idx_map_; +#endif ContiguousLayoutBuffer layout_buffer_; ContiguousStorageImplBuffer storage_impl_buffer_; diff --git a/torch/nativert/executor/memory/LayoutPlanner.cpp b/torch/nativert/executor/memory/LayoutPlanner.cpp index 5c45a08ea6f14..5fb0b8fced6f7 100644 --- a/torch/nativert/executor/memory/LayoutPlanner.cpp +++ b/torch/nativert/executor/memory/LayoutPlanner.cpp @@ -16,9 +16,18 @@ LayoutPlanner::LayoutPlanner( const c10::FastMap& kernelSchemas, const std::vector& persistentValues, const torch::nativert::LayoutPlannerSettings& settings) - : managed_values_(graph.values().size()), settings_(settings) { - auto value_to_allocation_spec = c10::FastMap{}; + : managed_values_(graph.values().size()), +#ifndef NDEBUG + alias_analyzer_(graph, kernelSchemas), +#endif + settings_(settings) { +#ifndef NDEBUG + auto& alias_analyzer = alias_analyzer_; +#else auto alias_analyzer = AliasAnalyzer(graph, kernelSchemas); +#endif + + auto value_to_allocation_spec = c10::FastMap{}; std::set input_values_set_; for (const auto* nv : graph.userInputs()) { diff --git a/torch/nativert/executor/memory/LayoutPlanner.h b/torch/nativert/executor/memory/LayoutPlanner.h index 6382fdbba01b5..83a2386c6dacf 100644 --- a/torch/nativert/executor/memory/LayoutPlanner.h +++ b/torch/nativert/executor/memory/LayoutPlanner.h @@ -8,6 +8,7 @@ #include #include +#include #include #include #include @@ -61,7 +62,17 @@ class LayoutPlanner { const std::vector& get_planned_values() const; const std::vector& get_unplanned_values() const; - C10_ALWAYS_INLINE bool is_managed(ValueId id) { +#ifndef NDEBUG + const AliasAnalyzer& get_alias_analyzer() const { + return alias_analyzer_; + } +#endif + + size_t num_values() const { + return managed_values_.size(); + } + + bool is_managed(ValueId id) { TORCH_CHECK_LT(static_cast(id), managed_values_.size()); return managed_values_[id]; } @@ -120,6 +131,9 @@ class LayoutPlanner { LayoutPlannerAlgorithm* algorithm_; c10::LeftRight plan_; +#ifndef NDEBUG + AliasAnalyzer alias_analyzer_; +#endif torch::nativert::LayoutPlannerSettings settings_; }; diff --git a/torch/nativert/kernels/AutoFunctionalizeKernel.cpp b/torch/nativert/kernels/AutoFunctionalizeKernel.cpp index cbbd502d82152..76589b52c56ee 100644 --- a/torch/nativert/kernels/AutoFunctionalizeKernel.cpp +++ b/torch/nativert/kernels/AutoFunctionalizeKernel.cpp @@ -11,15 +11,14 @@ UnsafeAutoFunctionalizeKernel::UnsafeAutoFunctionalizeKernel(const Node* node) op_(getOperatorForTarget( std::get(node->attributes()[0].value))), schema_(op_.schema()), - arguments_(prefillStackWithStaticArgs(node, schema_)) { + arguments_(prefillStackWithStaticArgs(node, schema_)), + numOutputs_(static_cast(schema_.returns().size())) { for (const auto& [idx, schemaArg] : c10::enumerate(schema_.arguments())) { if (schemaArg.alias_info() != nullptr && schemaArg.alias_info()->isWrite()) { mutatingInputArgs_.push_back(node->getInput(schemaArg.name()).value); } } - - numOutputs_ = schema_.returns().size(); } void UnsafeAutoFunctionalizeKernel::computeInternal( diff --git a/torch/nativert/kernels/KernelFactory.cpp b/torch/nativert/kernels/KernelFactory.cpp index 1f72fef810d6c..0720c28a7b6a5 100644 --- a/torch/nativert/kernels/KernelFactory.cpp +++ b/torch/nativert/kernels/KernelFactory.cpp @@ -62,7 +62,7 @@ c10::Device inferTargetDevice( } // namespace -inline constexpr std::string_view kSymIntOps[] = { +inline constexpr std::array kSymIntOps = { "_operator.floordiv", "_operator.mod", "torch.sym_int", @@ -72,7 +72,7 @@ inline constexpr std::string_view kSymIntOps[] = { "torch.sym_min", }; -inline constexpr std::string_view kSymBoolOps[] = { +inline constexpr std::array kSymBoolOps = { "_operator.eq", "_operator.ne", "_operator.le", @@ -83,14 +83,14 @@ inline constexpr std::string_view kSymBoolOps[] = { "torch.sym_not", }; -inline constexpr std::string_view kSymFloatOps[] = { +inline constexpr std::array kSymFloatOps = { "torch._sym_sqrt", "math.trunc", "_operator.neg", "_operator.truediv", }; -inline constexpr std::string_view kScalarBinaryOps[] = { +inline constexpr std::array kScalarBinaryOps = { "_operator.mul", "_operator.add", "_operator.sub", @@ -124,10 +124,11 @@ void KernelFactory::registerHandler( ExecutionKernels KernelFactory::initializeNodeKernels( const Graph& graph, - std::shared_ptr weights, + const std::shared_ptr& weights, const torch::nativert::ExecutorConfig& executorConfig, const Placement& placement, - std::shared_ptr pytorchStreamReader, + const std::shared_ptr& + pytorchStreamReader, const MakeProxyExecutorFn& makeProxyExecutorFunc) { std::vector> nodeKernels; std::vector> delegateExecutors; @@ -216,7 +217,7 @@ ExecutionKernels KernelFactory::initializeNodeKernels( *subgraph, weights, executorConfig, placement); CHECK(executionKernels.delegateExecutors.empty()) << "HigherOrderKernel does not support delegates"; - CHECK(executionKernels.constFoldingExecutions.size() == 0) + CHECK(executionKernels.constFoldingExecutions.empty()) << "HigherOrderKernel does not support const folding"; if (executorConfig.maxParallelOps > 1) { graphExecutors.emplace_back( diff --git a/torch/nativert/kernels/KernelFactory.h b/torch/nativert/kernels/KernelFactory.h index c01d64c3a0178..8c5d5fc661d1d 100644 --- a/torch/nativert/kernels/KernelFactory.h +++ b/torch/nativert/kernels/KernelFactory.h @@ -74,10 +74,10 @@ class KernelFactory { ExecutionKernels initializeNodeKernels( const Graph& graph, - std::shared_ptr weights, + const std::shared_ptr& weights, const torch::nativert::ExecutorConfig& executorConfig, const Placement& placement, - std::shared_ptr + const std::shared_ptr& pytorchStreamReader = nullptr, const MakeProxyExecutorFn& makeProxyExecutorFunc = nullptr); From 1e4d8b5a4a67220473bf0027c58baaa08a036714 Mon Sep 17 00:00:00 2001 From: Huy Do Date: Mon, 14 Jul 2025 20:55:13 +0000 Subject: [PATCH 014/457] Fix land race typos from #157290 (#158272) TSIA, this is a new grammar linter being added recently. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158272 Approved by: https://github.com/clee2000 --- torch/nativert/executor/memory/LayoutManager.cpp | 2 +- torch/nativert/executor/memory/LayoutManager.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/nativert/executor/memory/LayoutManager.cpp b/torch/nativert/executor/memory/LayoutManager.cpp index acfe990c38bd1..a75070095caf7 100644 --- a/torch/nativert/executor/memory/LayoutManager.cpp +++ b/torch/nativert/executor/memory/LayoutManager.cpp @@ -207,7 +207,7 @@ void LayoutManager::assert_no_overlapping_storages( std::multiset> intervals; planner_.with_plan([&](const LayoutPlan& plan) { - // prevent recomputation from occuring + // prevent recomputation from occurring c10::FastSet checked_values; // check that some arbitrary storage (defined by the allocation start and diff --git a/torch/nativert/executor/memory/LayoutManager.h b/torch/nativert/executor/memory/LayoutManager.h index aa1b1c4630cc0..347c51fe2edec 100644 --- a/torch/nativert/executor/memory/LayoutManager.h +++ b/torch/nativert/executor/memory/LayoutManager.h @@ -167,7 +167,7 @@ class LayoutManager { // // it checks a few things in between node executions... // -// 1. ensures all 'alive' values are within the bounds of thier lifetimes +// 1. ensures all 'alive' values are within the bounds of their lifetimes // - this is the definition of a sanity check since the live-sets are built // from the lifetimes lol. if this fails, something is very very wrong // 2. ensures that all planned values are within the bounds of their From 6b2bef10afae4acb18f230a496392b673c954ce7 Mon Sep 17 00:00:00 2001 From: fduwjj Date: Mon, 14 Jul 2025 10:00:40 -0700 Subject: [PATCH 015/457] [c10d] Prototype of `group_split` for dist2 work (#157716) This is to implement group_split as proposed in [docs.google.com/document/d/13R-1t_yESTvmAjcCN-wQjQQadIEu0JNIdS65uZawZzY/edit?tab=t.0#heading=h.3ctbqqopzc89](https://docs.google.com/document/d/13R-1t_yESTvmAjcCN-wQjQQadIEu0JNIdS65uZawZzY/edit?tab=t.0#heading=h.3ctbqqopzc89) Pull Request resolved: https://github.com/pytorch/pytorch/pull/157716 Approved by: https://github.com/d4l3k --- test/cpp/c10d/ProcessGroupNCCLTest.cpp | 8 +-- test/distributed/test_dist2.py | 11 ++++ torch/_C/_distributed_c10d.pyi | 7 +++ torch/csrc/distributed/c10d/Backend.hpp | 19 ++++++ torch/csrc/distributed/c10d/NCCLUtils.cpp | 21 +++++++ torch/csrc/distributed/c10d/NCCLUtils.hpp | 1 + torch/csrc/distributed/c10d/ProcessGroup.cpp | 58 +++++++++++++++++++ torch/csrc/distributed/c10d/ProcessGroup.hpp | 13 +++++ .../distributed/c10d/ProcessGroupGloo.cpp | 29 ++++++++++ .../distributed/c10d/ProcessGroupGloo.hpp | 13 ++++- .../distributed/c10d/ProcessGroupNCCL.cpp | 39 +++++++++++++ .../distributed/c10d/ProcessGroupNCCL.hpp | 11 +++- .../csrc/distributed/c10d/PyProcessGroup.hpp | 15 +++++ torch/csrc/distributed/c10d/init.cpp | 8 +++ 14 files changed, 246 insertions(+), 7 deletions(-) diff --git a/test/cpp/c10d/ProcessGroupNCCLTest.cpp b/test/cpp/c10d/ProcessGroupNCCLTest.cpp index 56f67035a5fb1..a1360c8dd40fd 100644 --- a/test/cpp/c10d/ProcessGroupNCCLTest.cpp +++ b/test/cpp/c10d/ProcessGroupNCCLTest.cpp @@ -28,7 +28,7 @@ class NCCLTestBase { NCCLTestBase(NCCLTestBase&& other) noexcept = default; - std::shared_ptr<::c10d::ProcessGroupNCCL> getProcessGroup() { + ::c10::intrusive_ptr<::c10d::ProcessGroupNCCL> getProcessGroup() { return pg_; } @@ -39,7 +39,7 @@ class NCCLTestBase { void initialize( int rank, size_t size, - std::optional<::std::shared_ptr<::c10d::ProcessGroupNCCL>> split_from = + std::optional<::c10::intrusive_ptr<::c10d::ProcessGroupNCCL>> split_from = std::nullopt) { store_ = c10::make_intrusive<::c10d::FileStore>(path_, size); @@ -52,13 +52,13 @@ class NCCLTestBase { opts->split_color = ++color_; } #endif - pg_ = std::make_unique<::c10d::ProcessGroupNCCL>( + pg_ = c10::make_intrusive<::c10d::ProcessGroupNCCL>( store_, rank, size, std::move(opts)); } protected: std::string path_; - std::shared_ptr<::c10d::ProcessGroupNCCL> pg_; + ::c10::intrusive_ptr<::c10d::ProcessGroupNCCL> pg_; std::chrono::milliseconds pgTimeout_; ::c10::intrusive_ptr<::c10d::Store> store_; int color_{1}; diff --git a/test/distributed/test_dist2.py b/test/distributed/test_dist2.py index 52ffd34e2a48e..d5e925b4b2d0c 100644 --- a/test/distributed/test_dist2.py +++ b/test/distributed/test_dist2.py @@ -201,6 +201,17 @@ def test_alltoall_base(self) -> None: out_range = out[i * 10 : (i + 1) * 10] self.assertEqual(out_range, torch.full_like(out_range, i + 1)) + def test_group_split(self) -> None: + group = self.new_group() + subgroup = group.split_group([0], timeout=timedelta(seconds=30)) + if self.rank == 0: + assert subgroup is not None + self.assertEqual(subgroup.size(), 1) + backend = subgroup._get_backend(self.device) + self.assertEqual(backend.options._timeout, timedelta(seconds=30)) + else: + self.assertEqual(subgroup, None) + class ProcessGroupGlooTest(Dist2MultiProcessTestCase): device = torch.device("cpu") diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index 2efe44c86b555..f57bcb3472cc4 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -350,6 +350,13 @@ class ProcessGroup: ) -> None: ... def rank(self) -> int: ... def size(self) -> int: ... + def split_group( + self, + new_ranks: list[int], + timeout: Optional[timedelta] = None, + pg_options: Optional[Backend.Options] = None, + group_desc: Optional[str] = None, + ) -> Optional[ProcessGroup]: ... def abort(self) -> None: ... def set_timeout(self, timeout: timedelta) -> None: ... def shutdown(self) -> None: ... diff --git a/torch/csrc/distributed/c10d/Backend.hpp b/torch/csrc/distributed/c10d/Backend.hpp index acece3d8c718b..0f1c5116803f2 100644 --- a/torch/csrc/distributed/c10d/Backend.hpp +++ b/torch/csrc/distributed/c10d/Backend.hpp @@ -46,6 +46,7 @@ class TORCH_API Backend : public torch::CustomClassHolder { // backend name // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) const std::string backend; + std::string group_name; }; explicit Backend(int rank, int size); @@ -105,6 +106,14 @@ class TORCH_API Backend : public torch::CustomClassHolder { TORCH_INTERNAL_ASSERT(false, "getBackendName is not implemented."); } + // Subclasses must override this method to return the backend name + virtual c10::intrusive_ptr getBackendOptions() { + TORCH_CHECK( + false, + c10::str( + "Backend ", getBackendName(), " does not implement endCoalescing")); + } + virtual c10::intrusive_ptr broadcast( std::vector& /* tensors */, const BroadcastOptions& /* opts */ = BroadcastOptions()) { @@ -379,6 +388,16 @@ class TORCH_API Backend : public torch::CustomClassHolder { " is missing implementation of enableCollectivesTiming."); } + virtual c10::intrusive_ptr splitBackend( + const std::vector& ranks, + const c10::intrusive_ptr opts) { + TORCH_CHECK( + false, + "Backend ", + getBackendName(), + " is missing implementation of splitBackend."); + } + bool hasHooks() const { return onCompletionHook_ != nullptr; } diff --git a/torch/csrc/distributed/c10d/NCCLUtils.cpp b/torch/csrc/distributed/c10d/NCCLUtils.cpp index 60bb0f2d879e1..8074cc98a04f1 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.cpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.cpp @@ -573,6 +573,27 @@ size_t hashTensors(const std::vector& tensors) { return hash; } +// NCCL uses Non-negative int to represent in-group according to API +// requirement. We take a list of ranks and generate a hash value based on the +// list and ensure its range of 32-bit int. +int genNcclSplitColor(const std::vector& ranks) { + // Combine the hash values using a simple reducer (std::hash + fold) + std::size_t combined_hash = std::accumulate( + ranks.begin(), + ranks.end(), + std::size_t(0), + [](std::size_t acc, int rank) { + return acc ^ + (std::hash{}(rank) + 0x9e3779b9 + (acc << 6) + (acc >> 2)); + }); + + // max positive value of int32_t + constexpr int32_t max_c_int = std::numeric_limits::max(); + int color = static_cast( + std::abs(static_cast(combined_hash)) % max_c_int); + return color; +} + // Default value: 30 minutes int nccl_nonblocking_timeout() { static int timeout = -2; // -2 means not initialized diff --git a/torch/csrc/distributed/c10d/NCCLUtils.hpp b/torch/csrc/distributed/c10d/NCCLUtils.hpp index 5e61837c2353b..fcd55b6a655ef 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.hpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.hpp @@ -231,6 +231,7 @@ static std::map ncclDataType = { }; TORCH_API size_t hashTensors(const std::vector& tensors); +TORCH_API int genNcclSplitColor(const std::vector& ranks); TORCH_API std::string getNcclVersion(); TORCH_API std::tuple getNcclVersionTuple(); TORCH_API int getNcclVersionNumber(); diff --git a/torch/csrc/distributed/c10d/ProcessGroup.cpp b/torch/csrc/distributed/c10d/ProcessGroup.cpp index 83418d17acdcb..197fd9014b3a9 100644 --- a/torch/csrc/distributed/c10d/ProcessGroup.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroup.cpp @@ -4,6 +4,7 @@ #include #include +#include #include #include @@ -158,6 +159,63 @@ void ProcessGroup::release_resources() { backendTypeToBackend_.clear(); } +c10::intrusive_ptr ProcessGroup::splitGroup( + const std::vector& ranks, + const std::optional timeout, + const std::optional> opts, + const std::optional& desc) { + TORCH_CHECK( + ranks.size() > 0, + "Split ranks cannot be empty. Please provide a non-empty list of ranks to split the group."); + TORCH_CHECK( + ranks.size() < static_cast(size_), + "the split group's size should be less than the world_size set by init_process_group"); + std::set ranks_set(ranks.begin(), ranks.end()); + TORCH_CHECK( + ranks_set.size() == ranks.size(), + "Split ranks should not have duplicates. Please provide a list of unique ranks to split the group."); + std::vector sorted_ranks = ranks; + std::sort(sorted_ranks.begin(), sorted_ranks.end()); + c10::intrusive_ptr newGroup; + // TODO: Figure out a better way for split group name. + std::string groupName = + c10::str(getGroupName(), ":split:", fmt::format("{}", sorted_ranks)); + for (const auto& pair : deviceTypeToBackendType_) { + c10::DeviceType deviceType = pair.first; + BackendType backendType = pair.second; + + auto parentBackend = getBackend(deviceType); + auto backendOpts = + opts.has_value() ? opts.value() : parentBackend->getBackendOptions(); + backendOpts->group_name = groupName; + backendOpts->timeout = + timeout.has_value() ? timeout.value() : backendOpts->timeout; + auto splitBackend = parentBackend->splitBackend(sorted_ranks, backendOpts); + if (splitBackend == nullptr) { + continue; + } + + // TODO: Figure out a better way for split group desc. + // TODO: We can add a new field in Backend::Options to specify the group + // desc + std::string groupDesc = desc.has_value() + ? desc.value() + : c10::str(getGroupDesc(), ":split:", incrementSplitCount()); + splitBackend->setGroupDesc(groupDesc); + + if (!newGroup) { + newGroup = c10::make_intrusive( + store_->clone(), splitBackend->getRank(), splitBackend->getSize()); + newGroup->setDefaultBackend(backendType_); + newGroup->setGroupName(groupName); + newGroup->setGroupDesc(groupDesc); + } + newGroup->setBackend(deviceType, backendType, splitBackend); + } + + return newGroup; +} + } // namespace c10d namespace { diff --git a/torch/csrc/distributed/c10d/ProcessGroup.hpp b/torch/csrc/distributed/c10d/ProcessGroup.hpp index da4bf65f4f39d..5939f23e2972b 100644 --- a/torch/csrc/distributed/c10d/ProcessGroup.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroup.hpp @@ -170,6 +170,10 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { } } + int64_t incrementSplitCount() { + return splitCounter_++; + } + virtual void startCoalescing(c10::DeviceType deviceType) { // only nccl has implemented startCoalescing so only execute for nccl // backends @@ -955,6 +959,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { bound_device_id_ = device; } + // This creates a new subgroup using the specified ranks. + // The current rank must be included in the list of new_ranks. + virtual c10::intrusive_ptr splitGroup( + const std::vector& ranks, + const std::optional timeout, + const std::optional> opts, + const std::optional& groupDesc); + protected: // Implementations of this interface need to call this to setup // appropriate logging etc. @@ -968,6 +980,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) BackendType backendType_; std::string pg_desc_; + int64_t splitCounter_; // Debug level setting. It is parsed once when ProcessGroup is constructed and // remains the same across use of this process group. diff --git a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp index 0df6073c5d2d5..30301524bc575 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp @@ -697,6 +697,35 @@ const std::vector& ProcessGroupGloo::groupRanks() const { return options_->global_ranks_in_group; } +c10::intrusive_ptr ProcessGroupGloo::splitBackend( + const std::vector& ranks, + const c10::intrusive_ptr opts) { + auto it = std::find(ranks.begin(), ranks.end(), rank_); + int groupRank; + if (it == ranks.end()) { + return nullptr; + } else { + groupRank = std::distance(ranks.begin(), it); + } + + auto glooOpts = c10::dynamic_intrusive_pointer_cast(opts); + TORCH_CHECK(glooOpts != nullptr, "opts not a ProcessGroupGloo::Options."); + + // TODO: we need to get rid of globalRanksInGroup eventually. + std::vector globalRanksInGroup; + for (auto rank : ranks) { + globalRanksInGroup.emplace_back(groupRanks()[rank]); + } + glooOpts->global_ranks_in_group = std::move(globalRanksInGroup); + auto store = std::dynamic_pointer_cast(store_); + TORCH_CHECK( + store != nullptr, + "store inside ProcessGroupGloo not a ProcessGroupGloo::GlooStore."); + auto pg = c10::make_intrusive( + store->_getStore()->clone(), groupRank, ranks.size(), glooOpts); + return c10::static_intrusive_pointer_cast(pg); +} + void ProcessGroupGloo::enqueue(c10::intrusive_ptr work) { std::unique_lock lock(workMutex_); pgStatus_->lastEnqueuedSeq = static_cast(work->seq_); diff --git a/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp b/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp index e5f1ca7402882..0ba2d416aedff 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp @@ -188,6 +188,10 @@ class TORCH_API ProcessGroupGloo : public Backend { } #endif + const c10::intrusive_ptr<::c10d::Store>& _getStore() const { + return store_; + } + protected: c10::intrusive_ptr<::c10d::Store> store_; }; @@ -252,7 +256,6 @@ class TORCH_API ProcessGroupGloo : public Backend { } std::vector global_ranks_in_group; - std::string group_name; std::vector> devices; int threads; }; @@ -301,6 +304,14 @@ class TORCH_API ProcessGroupGloo : public Backend { } } + c10::intrusive_ptr getBackendOptions() override { + return c10::static_intrusive_pointer_cast(options_); + } + + c10::intrusive_ptr splitBackend( + const std::vector& ranks, + const c10::intrusive_ptr opts) override; + const std::vector& groupRanks() const; c10::intrusive_ptr broadcast( diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index fda3879a8e8ca..3dc7abbb7e54c 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -1311,6 +1311,45 @@ void ProcessGroupNCCL::enableCollectivesTiming() { enableTiming_.store(true); } +c10::intrusive_ptr ProcessGroupNCCL::splitBackend( + const std::vector& ranks, + const c10::intrusive_ptr opts) { + auto deviceIdx = guessDeviceId(); + TORCH_CHECK( + deviceIdx >= 0, + "ProcessGroupNCCL::splitBackend: rank ", + rank_, + " has no device is bound to this rank."); + auto device = at::Device(at::DeviceType::CUDA, deviceIdx); + auto it = std::find(ranks.begin(), ranks.end(), rank_); + int groupRank; + if (it == ranks.end()) { + // This rank is not in the new group, so no_color split should be called + performNocolorSplit(device); + return nullptr; + } else { + groupRank = std::distance(ranks.begin(), it); + } + + auto ncclOpts = c10::dynamic_intrusive_pointer_cast(opts); + TORCH_CHECK(ncclOpts != nullptr, "opts not a ProcessGroupNCCL::Options."); + + // TODO: we need to get rid of globalRanksInGroup eventually. + std::vector globalRanksInGroup; + for (auto rank : ranks) { + globalRanksInGroup.emplace_back(groupRanks()[rank]); + } + ncclOpts->split_from = + c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this); + ncclOpts->global_ranks_in_group = std::move(globalRanksInGroup); + auto color = genNcclSplitColor(ranks); + ncclOpts->split_color = color; + auto pg = c10::make_intrusive( + store_->clone(), groupRank, ranks.size(), ncclOpts); + pg->eagerConnectSingleDevice(device); + return c10::static_intrusive_pointer_cast(pg); +} + bool ProcessGroupNCCL::waitForFutureOrTimeout( std::future& fut, const std::chrono::milliseconds& timeOutMilSec, diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index bf7ac47d8ed18..d7bb02e912c81 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -541,7 +541,7 @@ class TORCH_API ProcessGroupNCCL : public Backend { // Optional "parent" backend and color to create communicators from // via `ncclCommSplit` - std::shared_ptr split_from; + c10::intrusive_ptr split_from; // Color to use for `ncclCommSplit`, values: // * Non-negative value: in group; // * NCCL_SPLIT_NOCOLOR (-1): not in group; @@ -562,7 +562,6 @@ class TORCH_API ProcessGroupNCCL : public Backend { int split_color{-2}; #endif std::vector global_ranks_in_group; - std::string group_name; }; // Helper class related to TORCH_NCCL_DESYNC_DEBUG @@ -804,6 +803,10 @@ class TORCH_API ProcessGroupNCCL : public Backend { return options_; } + c10::intrusive_ptr getBackendOptions() override { + return c10::static_intrusive_pointer_cast(options_); + } + const std::string getBackendName() const override { return std::string(NCCL_BACKEND_NAME); } @@ -972,6 +975,10 @@ class TORCH_API ProcessGroupNCCL : public Backend { void enableCollectivesTiming() override; + c10::intrusive_ptr splitBackend( + const std::vector& ranks, + const c10::intrusive_ptr opts) override; + // Helper function for iteratively aborting communicators in the provided map void abortCommsFromMap( std::unordered_map>& ncclCommsMap, diff --git a/torch/csrc/distributed/c10d/PyProcessGroup.hpp b/torch/csrc/distributed/c10d/PyProcessGroup.hpp index 3355d0feebfbf..854ea596aba8f 100644 --- a/torch/csrc/distributed/c10d/PyProcessGroup.hpp +++ b/torch/csrc/distributed/c10d/PyProcessGroup.hpp @@ -151,6 +151,21 @@ class PyProcessGroup : public ProcessGroup { group_desc); } + c10::intrusive_ptr splitGroup( + const std::vector& ranks, + const std::optional timeout, + const std::optional> opts, + const std::optional& group_desc) override { + PYBIND11_OVERRIDE( + c10::intrusive_ptr, /* Return type */ + ProcessGroup, /* Parent class */ + splitGroup, /* Name of function in C++ */ + ranks, + timeout, + opts, + group_desc); + } + c10::intrusive_ptr allgather( std::vector>& outputTensors, std::vector& inputTensors, diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 0121bd6fd94bd..5dfc99a893c7d 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -2063,6 +2063,14 @@ communication mechanism. .def("rank", &::c10d::ProcessGroup::getRank, R"(Get the rank of this process group.)") .def("size", &::c10d::ProcessGroup::getSize, R"(Get the size of this process group.)") .def("name", &::c10d::ProcessGroup::getBackendName, R"(Get the name of this process group.)") + .def( + "split_group", + &::c10d::ProcessGroup::splitGroup, + py::arg("ranks"), + py::arg("timeout") = std::nullopt, + py::arg("opts") = std::nullopt, + py::arg("groupDesc") = std::nullopt, + py::call_guard()) .def( "abort", &::c10d::ProcessGroup::abort, From bcf50636ba1b93a833267c645d887888df06e9ea Mon Sep 17 00:00:00 2001 From: Ethan Wee Date: Mon, 14 Jul 2025 21:09:38 +0000 Subject: [PATCH 016/457] [CI] Removing --user flag from all pip install commands (#154900) Related to https://github.com/pytorch/pytorch/issues/148335 python virtualenv doesn't support using `--user` flag: ``` ERROR: Can not perform a '--user' install. User site-packages are not visible in this virtualenv. + python3 -m pip install --progress-bar off --user ninja==1.10.2 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/154900 Approved by: https://github.com/jeffdaily Co-authored-by: Jithun Nair --- .ci/caffe2/test.sh | 6 +++--- .ci/onnx/test.sh | 2 +- .ci/pytorch/common_utils.sh | 18 +++++++++--------- .ci/pytorch/test.sh | 8 ++++---- 4 files changed, 17 insertions(+), 17 deletions(-) diff --git a/.ci/caffe2/test.sh b/.ci/caffe2/test.sh index eaef1e3ebf88a..7d1ce2fb4fa10 100755 --- a/.ci/caffe2/test.sh +++ b/.ci/caffe2/test.sh @@ -5,7 +5,7 @@ source "$(dirname "${BASH_SOURCE[0]}")/common.sh" if [[ ${BUILD_ENVIRONMENT} == *onnx* ]]; then pip install click mock tabulate networkx==2.0 - pip -q install --user "file:///var/lib/jenkins/workspace/third_party/onnx#egg=onnx" + pip -q install "file:///var/lib/jenkins/workspace/third_party/onnx#egg=onnx" fi # Skip tests in environments where they are not built/applicable @@ -147,8 +147,8 @@ export DNNL_MAX_CPU_ISA=AVX2 if [[ "${SHARD_NUMBER:-1}" == "1" ]]; then # TODO(sdym@meta.com) remove this when the linked issue resolved. # py is temporary until https://github.com/Teemu/pytest-sugar/issues/241 is fixed - pip install --user py==1.11.0 - pip install --user pytest-sugar + pip install py==1.11.0 + pip install pytest-sugar # NB: Warnings are disabled because they make it harder to see what # the actual erroring test is "$PYTHON" \ diff --git a/.ci/onnx/test.sh b/.ci/onnx/test.sh index a7d3b72c62a7e..d42ca2c218dec 100755 --- a/.ci/onnx/test.sh +++ b/.ci/onnx/test.sh @@ -19,7 +19,7 @@ git config --global --add safe.directory /var/lib/jenkins/workspace if [[ "$BUILD_ENVIRONMENT" == *onnx* ]]; then # TODO: This can be removed later once vision is also part of the Docker image - pip install -q --user --no-use-pep517 "git+https://github.com/pytorch/vision.git@$(cat .github/ci_commit_pins/vision.txt)" + pip install -q --no-use-pep517 "git+https://github.com/pytorch/vision.git@$(cat .github/ci_commit_pins/vision.txt)" # JIT C++ extensions require ninja, so put it into PATH. export PATH="/var/lib/jenkins/.local/bin:$PATH" # NB: ONNX test is fast (~15m) so it's ok to retry it few more times to avoid any flaky issue, we diff --git a/.ci/pytorch/common_utils.sh b/.ci/pytorch/common_utils.sh index 9c0e5242f433c..3dbc2ece9e70b 100644 --- a/.ci/pytorch/common_utils.sh +++ b/.ci/pytorch/common_utils.sh @@ -127,9 +127,9 @@ function install_torchaudio() { if [[ "$1" == "cuda" ]]; then # TODO: This is better to be passed as a parameter from _linux-test workflow # so that it can be consistent with what is set in build - TORCH_CUDA_ARCH_LIST="8.0;8.6" pip_install --no-use-pep517 --user "git+https://github.com/pytorch/audio.git@${commit}" + TORCH_CUDA_ARCH_LIST="8.0;8.6" pip_install --no-use-pep517 "git+https://github.com/pytorch/audio.git@${commit}" else - pip_install --no-use-pep517 --user "git+https://github.com/pytorch/audio.git@${commit}" + pip_install --no-use-pep517 "git+https://github.com/pytorch/audio.git@${commit}" fi } @@ -139,8 +139,8 @@ function install_torchtext() { local text_commit data_commit=$(get_pinned_commit data) text_commit=$(get_pinned_commit text) - pip_install --no-use-pep517 --user "git+https://github.com/pytorch/data.git@${data_commit}" - pip_install --no-use-pep517 --user "git+https://github.com/pytorch/text.git@${text_commit}" + pip_install --no-use-pep517 "git+https://github.com/pytorch/data.git@${data_commit}" + pip_install --no-use-pep517 "git+https://github.com/pytorch/text.git@${text_commit}" } function install_torchvision() { @@ -153,7 +153,7 @@ function install_torchvision() { echo 'char* dlerror(void) { return "";}'|gcc -fpic -shared -o "${HOME}/dlerror.so" -x c - LD_PRELOAD=${orig_preload}:${HOME}/dlerror.so fi - pip_install --no-use-pep517 --user "git+https://github.com/pytorch/vision.git@${commit}" + pip_install --no-use-pep517 "git+https://github.com/pytorch/vision.git@${commit}" if [ -n "${LD_PRELOAD}" ]; then LD_PRELOAD=${orig_preload} fi @@ -173,7 +173,7 @@ function install_torchrec_and_fbgemm() { if [[ "$BUILD_ENVIRONMENT" == *rocm* ]] ; then # install torchrec first because it installs fbgemm nightly on top of rocm fbgemm - pip_install --no-use-pep517 --user "git+https://github.com/pytorch/torchrec.git@${torchrec_commit}" + pip_install --no-use-pep517 "git+https://github.com/pytorch/torchrec.git@${torchrec_commit}" pip_uninstall fbgemm-gpu-nightly pip_install tabulate # needed for newer fbgemm @@ -190,8 +190,8 @@ function install_torchrec_and_fbgemm() { rm -rf fbgemm else # See https://github.com/pytorch/pytorch/issues/106971 - CUDA_PATH=/usr/local/cuda-12.1 pip_install --no-use-pep517 --user "git+https://github.com/pytorch/FBGEMM.git@${fbgemm_commit}#egg=fbgemm-gpu&subdirectory=fbgemm_gpu" - pip_install --no-use-pep517 --user "git+https://github.com/pytorch/torchrec.git@${torchrec_commit}" + CUDA_PATH=/usr/local/cuda-12.1 pip_install --no-use-pep517 "git+https://github.com/pytorch/FBGEMM.git@${fbgemm_commit}#egg=fbgemm-gpu&subdirectory=fbgemm_gpu" + pip_install --no-use-pep517 "git+https://github.com/pytorch/torchrec.git@${torchrec_commit}" fi } @@ -234,7 +234,7 @@ function checkout_install_torchbench() { function install_torchao() { local commit commit=$(get_pinned_commit torchao) - pip_install --no-use-pep517 --user "git+https://github.com/pytorch/ao.git@${commit}" + pip_install --no-use-pep517 "git+https://github.com/pytorch/ao.git@${commit}" } function print_sccache_stats() { diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 7b7e1970f72e6..77004a1764850 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -201,7 +201,7 @@ fi if [[ "$BUILD_ENVIRONMENT" != *-bazel-* ]] ; then # JIT C++ extensions require ninja. - pip_install --user "ninja==1.10.2" + pip_install "ninja==1.10.2" # ninja is installed in $HOME/.local/bin, e.g., /var/lib/jenkins/.local/bin for CI user jenkins # but this script should be runnable by any user, including root export PATH="$HOME/.local/bin:$PATH" @@ -496,7 +496,7 @@ DYNAMO_BENCHMARK_FLAGS=() pr_time_benchmarks() { - pip_install --user "fbscribelogger" + pip_install "fbscribelogger" TEST_REPORTS_DIR=$(pwd)/test/test-reports mkdir -p "$TEST_REPORTS_DIR" @@ -1471,8 +1471,8 @@ test_bazel() { test_benchmarks() { if [[ "$BUILD_ENVIRONMENT" == *cuda* && $TEST_CONFIG != *nogpu* ]]; then - pip_install --user "pytest-benchmark==3.2.3" - pip_install --user "requests" + pip_install "pytest-benchmark==3.2.3" + pip_install "requests" BENCHMARK_DATA="benchmarks/.data" mkdir -p ${BENCHMARK_DATA} pytest benchmarks/fastrnns/test_bench.py --benchmark-sort=Name --benchmark-json=${BENCHMARK_DATA}/fastrnns_default.json --fuser=default --executor=default From 194539e9c33dc793fe67fbb68c7cee12f399e276 Mon Sep 17 00:00:00 2001 From: Joona Havukainen Date: Mon, 14 Jul 2025 22:09:31 +0000 Subject: [PATCH 017/457] Address NaNs if SDPA is called with all values masked from query (#157727) Fixes #156707 Detect if all values along the softmax axis are infs and overwrite the outputs for those computations with zeros before the final matmul. The behavior should be aligned with the CPU implementation. These types of cases where all values along the dimension in the attention mask are false leading to the undefined outputs in softmax occur with left padded batches for generation in HF transformers according to the original issue. Pull Request resolved: https://github.com/pytorch/pytorch/pull/157727 Approved by: https://github.com/malfet --- aten/src/ATen/native/mps/operations/Attention.mm | 16 +++++++++++++++- test/test_mps.py | 12 ++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/native/mps/operations/Attention.mm b/aten/src/ATen/native/mps/operations/Attention.mm index 3f3c6b309fd66..69ec9af055baf 100644 --- a/aten/src/ATen/native/mps/operations/Attention.mm +++ b/aten/src/ATen/native/mps/operations/Attention.mm @@ -114,8 +114,22 @@ graph->maskTensor = mpsGraphRankedPlaceHolder(mpsGraph, *attn_mask); maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM secondaryTensor:graph->maskTensor name:nil]; } + + // Account for case where all values were masked causing division by 0 in softmax (issue:#156707) + // Overwrites expected NANs in sm with zeros. + auto negInfTensor = [mpsGraph constantWithScalar:-INFINITY shape:maskedMM.shape dataType:maskedMM.dataType]; + auto elem_neg_inf = [mpsGraph equalWithPrimaryTensor:maskedMM secondaryTensor:negInfTensor name:nil]; + auto all_neg_infs_along_axis = [mpsGraph reductionAndWithTensor:elem_neg_inf axis:3 name:nil]; + auto zero_mask = [mpsGraph broadcastTensor:all_neg_infs_along_axis toShape:maskedMM.shape name:nil]; + auto zeroTensor = [mpsGraph constantWithScalar:0.0 shape:maskedMM.shape dataType:maskedMM.dataType]; + auto sm = [mpsGraph softMaxWithTensor:maskedMM axis:3 name:nil]; - auto output = [mpsGraph matrixMultiplicationWithPrimaryTensor:sm secondaryTensor:vTensor name:nil]; + MPSGraphTensor* correctedSM = [mpsGraph selectWithPredicateTensor:zero_mask + truePredicateTensor:zeroTensor + falsePredicateTensor:sm + name:nil]; + + auto output = [mpsGraph matrixMultiplicationWithPrimaryTensor:correctedSM secondaryTensor:vTensor name:nil]; graph->qTensor = qTensor; graph->kTensor = kTensor; graph->vTensor = vTensor; diff --git a/test/test_mps.py b/test/test_mps.py index 89e68b1718528..d9e4b7a9f037c 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -9257,6 +9257,18 @@ def test_sdpa_mask_fp16_L6(self): def test_sdpa_mask_fp16_L6_S17_NH23_HS121(self): self._test_sdpa_mask(torch.float16, 7, 17, 23, 121) + # Regression test from: https://github.com/pytorch/pytorch/issues/156707 + @parametrize("dtype", [torch.float16, torch.float32]) + def test_sdpa_full_mask(self, dtype): + q = torch.randn(1, 1, 2, 4, dtype=dtype) + k = torch.randn(1, 1, 2, 4, dtype=dtype) + v = torch.randn(1, 1, 2, 4, dtype=dtype) + mask = torch.tensor([[[[False, False], [True, True]]]], dtype=torch.bool) + + out_cpu = F.scaled_dot_product_attention(q, k, v, attn_mask=mask) + out_mps = F.scaled_dot_product_attention(q.to('mps'), k.to('mps'), v.to('mps'), attn_mask=mask.to('mps')) + self._compare_tensors(out_mps.cpu(), out_cpu) + @parametrize("dtype", [torch.float16, torch.float32]) def test_sdpa_3d_input(self, dtype): head_num, seq_len, embed_dim = 16, 16, 80 From 9345279c6ebdbad95b7b53bc2cb6f63a4e57b2cc Mon Sep 17 00:00:00 2001 From: yuchengliu1 Date: Mon, 14 Jul 2025 22:14:52 +0000 Subject: [PATCH 018/457] skip inductor/test_torchinductor_opinfo in windows (#158225) During enabling inductor CI in Windows, `test_torchinductor_opinfo.py` cost too many time (about 12 hours). This UT was seriously exceeding the time limit of CI. The compiler building was slower 4x in Windows than Linux after analyzing. Thus, we decide to skip the UT temporary and @xuhancn will keep searching the solution of compiler building in Windows. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158225 Approved by: https://github.com/jansel Co-authored-by: Xu Han --- test/inductor/test_torchinductor_opinfo.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index 510e827705f7e..8abd17aab2f88 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -30,7 +30,9 @@ ) from torch.testing._internal.common_methods_invocations import op_db, skipOps from torch.testing._internal.common_utils import ( + IS_CI, IS_MACOS, + IS_WINDOWS, IS_X86, skipCUDAMemoryLeakCheckIf, skipIfCrossRef, @@ -67,6 +69,15 @@ sys.exit(0) raise +if IS_WINDOWS and IS_CI: + # TODO(xuhancn) : improve the compiler build performance on windows. + sys.stderr.write( + "This UT is too slow on windows, and will cause out of time in CI. So skip it now.\n" + ) + if __name__ == "__main__": + sys.exit(0) + raise unittest.SkipTest("skip slow test") + bf16 = torch.bfloat16 # not tested f64 = torch.float64 f32 = torch.float32 From c062550a3598d27c2d6572db7c0f4ff90a84cc84 Mon Sep 17 00:00:00 2001 From: Xuan Zhang Date: Mon, 14 Jul 2025 10:59:01 -0700 Subject: [PATCH 019/457] [PT2][fusion] ban fusions with large accumulated reads (#157563) **Problem:** Fusion can accumulate large amount of reads, which leads to significant increase in peak memory utilization. Imagine we have the following code snippet ``` total = torch.rand(N, N) for _ in range(r): x = torch.rand(N, N) total = total + x ``` The default execution is memory efficient as only two tensors of size N-by-N is in memory at any given time. However, with fusion, the additions are fused into a single operation and the execution becomes something like: ``` x_1 = torch.rand(N, N) x_2 = torch.rand(N, N) ... x_r = torch.rand(N, N) total = x_1 + x_2 + ... + x_r ``` Though this is run-time efficient, in the case of large `N` and/or large `r`, this is not memory efficient. [internal only] see [post](https://fb.workplace.com/groups/1075192433118967/permalink/1703374333634104/) for additional details **Solution:** Our proposed solution is to ban fusions in case where a large amount of reads are accumulated. This is in addition to some existing logics during torch compile. * During lowering (i.e., `ir.py`), the config `realize_acc_reads_threshold`, which is default to be 8, controls _the number of_ buffers can be accumulated for a single operator. However, this is oblivious to the size of the buffers. Hence, we additionally introduce a config `realize_acc_reads_size_threshold` to control _the amount of buffers_ in size that can be accumulated. * During scheduling (i.e., `scheduler.py`), additional fusion will be performed and thus we also need to capture such pattern there. The decisions are implemented under `choices.py`. **Results:** For a small example similar to be one in the test case (but with larger `N` and higher number of loop repeats), the memory snapshot before and after are shown below. Note the snapshot on the right is zoomed out so that the y-axis of the two snapshots match. image Pull Request resolved: https://github.com/pytorch/pytorch/pull/157563 Approved by: https://github.com/jansel, https://github.com/mlazos --- .../pr_time_benchmarks/expected_results.csv | 86 +++---------------- test/inductor/test_memory.py | 51 +++++++++++ test/inductor/test_online_softmax.py | 8 +- torch/_inductor/choices.py | 4 + torch/_inductor/config.py | 1 + torch/_inductor/graph.py | 21 +++++ torch/_inductor/ir.py | 11 +++ torch/_inductor/memory.py | 13 +-- torch/_inductor/scheduler.py | 29 +++---- 9 files changed, 118 insertions(+), 106 deletions(-) diff --git a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv index edc9d0f73d161..9e5521f94b43e 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv +++ b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv @@ -1,89 +1,23 @@ -add_loop_eager,compile_time_instruction_count,3017000000,0.015 - - - +add_loop_eager,compile_time_instruction_count,2996000000,0.015 add_loop_eager_dynamic,compile_time_instruction_count,4352000000,0.025 - - - -add_loop_inductor,compile_time_instruction_count,29490000000,0.015 - - - -add_loop_inductor_dynamic_gpu,compile_time_instruction_count,38760000000,0.025 - - - -add_loop_inductor_gpu,compile_time_instruction_count,26000000000,0.015 - - - +add_loop_inductor,compile_time_instruction_count,33090000000,0.015 +add_loop_inductor_dynamic_gpu,compile_time_instruction_count,42660000000,0.025 +add_loop_inductor_gpu,compile_time_instruction_count,29690000000,0.015 basic_modules_ListOfLinears_eager,compile_time_instruction_count,947600000,0.015 - - - -basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18490000000,0.015 - - - -basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17020000000,0.015 - - - -basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,10297683772,0.2 - - - +basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18830000000,0.015 +basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17460000000,0.015 +basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,11020000000,0.2 update_hint_regression,compile_time_instruction_count,1673000000,0.02 - - - sum_floordiv_regression,compile_time_instruction_count,986800000,0.015 - - - -symint_sum,compile_time_instruction_count,3166000000,0.015 - - - +symint_sum,compile_time_instruction_count,3184000000,0.015 symint_sum_loop,compile_time_instruction_count,4202000000,0.015 - - - aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2103000000,0.015 - - - aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,6004000000,0.015 - - - aotdispatcher_partitioner_cpu,compile_time_instruction_count,8783000000,0.015 - - - aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1940000000,0.015 - - - aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3885000000,0.015 - - - aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10470000000,0.015 - - - -mm_loop_inductor_gpu,compile_time_instruction_count,4324000000,0.015 - - - -mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,8116000000,0.015 - - - +mm_loop_inductor_gpu,compile_time_instruction_count,4365000000,0.015 +mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,8184000000,0.015 basic_NestedModule_eager,compile_time_instruction_count,8152524390,0.015 - - - basic_InlineMod_eager,compile_time_instruction_count,7255000000,0.015 diff --git a/test/inductor/test_memory.py b/test/inductor/test_memory.py index eaff539f7a493..489ba4ffeb0df 100644 --- a/test/inductor/test_memory.py +++ b/test/inductor/test_memory.py @@ -306,6 +306,57 @@ def f(a, b, c): expected_bound = a.size(0) * c.size(1) * a.dtype.itemsize * 2 self.assertLess(peak_mem, expected_bound) + def test_fusion_acc_large_reads(self): + def f(x, y, z): + res = torch.zeros_like(x[0]) + for i in range(4): + temp = torch.matmul(x, y) + z + res = res + temp + return res + + N = 128 + x = torch.rand(N, N, dtype=torch.float32, device=GPU_TYPE) + y = torch.rand(N, N, dtype=torch.float32, device=GPU_TYPE) + z = torch.rand(N, N, dtype=torch.float32, device=GPU_TYPE) + + # CASE 1: no restriction on the amount of accumulation + with config.patch({"realize_acc_reads_size_threshold": float("inf")}): + f_compiled = torch.compile(f) + code = run_and_get_triton_code(f_compiled, x, y, z) + ( + FileCheck() + .check("triton_poi_fused_add_0.run(buf4, arg2_1, buf1, buf2, buf3") + .run(code) + ) + + # CASE 2: for tensors with the same size as x (which is 4 * N**2 bytes) + # at most 12 / 4 = 3 reads can be accumulated during fusion + with config.patch({"realize_acc_reads_size_threshold": 12 * N**2}): + f_compiled = torch.compile(f) + code = run_and_get_triton_code(f_compiled, x, y, z) + ( + FileCheck() + .check("triton_poi_fused_add_0.run(buf3, arg2_1, buf1, buf2,") + .check("triton_poi_fused_add_1.run(buf5, buf4, arg2_1,") + .run(code) + ) + + # CASE 3: no such fusion allowed + with config.patch({"realize_acc_reads_size_threshold": N**2}): + f_compiled = torch.compile(f) + code = run_and_get_triton_code(f_compiled, x, y, z) + ( + FileCheck() + .check("triton_poi_fused_add_0.run(buf1, arg2_1,") + .check("triton_poi_fused_add_0.run(buf3, arg2_1,") + .check("triton_poi_fused_add_0.run(buf4, buf3,") + .check("triton_poi_fused_add_0.run(buf6, arg2_1,") + .check("triton_poi_fused_add_0.run(buf7, buf6,") + .check("triton_poi_fused_add_0.run(buf9, arg2_1,") + .check("triton_poi_fused_add_0.run(buf10, buf9,") + .run(code) + ) + if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/test/inductor/test_online_softmax.py b/test/inductor/test_online_softmax.py index 798d86b0dd617..37959c241113f 100644 --- a/test/inductor/test_online_softmax.py +++ b/test/inductor/test_online_softmax.py @@ -13,6 +13,7 @@ instantiate_parametrized_tests, IS_LINUX, parametrize, + serialTest, ) from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA @@ -77,12 +78,17 @@ def f(x): out, source_codes = run_and_get_code(f, x) return source_codes[0] + @serialTest() def test_codegen_3pass_softmax_due_to_disable(self): - with inductor_config.patch(online_softmax=False): + with inductor_config.patch( + online_softmax=False, + realize_acc_reads_size_threshold=float("inf"), + ): wrapper_code = self.get_softmax_wrapper() self.assertEqual(wrapper_code.count("for r0_offset in"), 3) + @serialTest() @parametrize("V", [2048, 50304]) @parametrize("use_log_softmax", [False, True]) def test_codegen_online_softmax(self, use_log_softmax, V): diff --git a/torch/_inductor/choices.py b/torch/_inductor/choices.py index b7bab02da5e4b..9096ba6dd0393 100644 --- a/torch/_inductor/choices.py +++ b/torch/_inductor/choices.py @@ -365,6 +365,10 @@ def can_fuse( WhyNoFuse(node1, node2)("Fusion will increase peak memory") return False + if scheduler.fusion_accumulate_large_reads(node1, node2): + WhyNoFuse(node1, node2)("Fusion accumulate large amount of reads") + return False + return True @staticmethod diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 60e8b259368cf..826324b6a2044 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -574,6 +574,7 @@ def use_autoheuristic(name: str) -> bool: # Threshold to prevent excessive accumulation of ops in one buffer during lowering realize_acc_reads_threshold = 8 +realize_acc_reads_size_threshold = 3 * (1024**3) # fallback to eager for random/dropout, this is slow but useful for debugging fallback_random = False diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index e2cc101533f28..ac299d5b0c2d0 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -123,6 +123,7 @@ from torch.fx.graph import Graph from .codegen.wrapper import PythonWrapperCodegen + from .dependencies import Dep from .scheduler import BaseSchedulerNode CompiledModule = Union[ModuleType, FileBackedGraphModule] @@ -485,6 +486,9 @@ def __init__( self.bw_donated_idxs = get_donated_idxs() + # Cache for dep size hints to avoid expensive recomputation + self.dep_size_hint_cache: dict[Dep, int] = {} + def freeze_runtime_asserts(self) -> None: self._shape_env.freeze_runtime_asserts() @@ -570,6 +574,23 @@ def has_feature( assert isinstance(feature, BackendFeature), feature return feature in self.get_backend_features(get_device_type(device)) + def get_dep_size_hint(self, dep: Dep) -> int: + """ + Get the size hint for a dependency with caching to avoid expensive recomputation. + """ + if dep not in self.dep_size_hint_cache: + res = 0 + try: + if not dep.has_unbacked_symbols(): + res = dep.numbytes_hint() + except KeyError: + # In at least one test (test/inductor/test_torchbind.py) we + # create a StarDep that doesn't exist in the graph and calling + # `has_unbacked_symbols()` throws an error. + pass + self.dep_size_hint_cache[dep] = res + return self.dep_size_hint_cache[dep] + def get_current_device_or_throw(self) -> torch.device: if device := self.current_device: return device diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 1edbb214ae2ad..d6dd82aa52f2d 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -7829,6 +7829,10 @@ def create(data: IRNode) -> Union[TensorBox, ShapeAsConstantBuffer]: class StorageBox(MutableBox): + """ + StorageBox allow in-place mutation of Tensors + """ + def is_input_buffer(self) -> bool: if isinstance(self.data, (InputBuffer, ReinterpretView)): return self.data.get_name() in V.graph.graph_inputs @@ -7878,10 +7882,17 @@ def realize_hint(self) -> None: ): self.realize() + def has_accumulated_enough_reads_by_size(self) -> bool: + return ( + sum(V.graph.get_dep_size_hint(dep) for dep in self.get_reads()) + > config.realize_acc_reads_size_threshold + ) + def has_exceeded_max_reads(self) -> bool: return isinstance(self.data, Pointwise) and ( self.num_reads() > config.realize_acc_reads_threshold or self.has_large_inner_fn() + or self.has_accumulated_enough_reads_by_size() ) def should_realize_on_reuse(self, users: int) -> bool: diff --git a/torch/_inductor/memory.py b/torch/_inductor/memory.py index 5601bc4adcee4..d287208419a9f 100644 --- a/torch/_inductor/memory.py +++ b/torch/_inductor/memory.py @@ -78,19 +78,8 @@ def get_freeable_input_buf( A dictionary containing all freeble input buffers, keyed by their names. """ - # this function is copied from torch/_inductor/scheduler.py - # TODO: would be nice to remove the try/except block for both places def _dep_size_hint(dep: Dep) -> int: - res = 0 - try: - if not dep.has_unbacked_symbols(): - res = dep.numbytes_hint() - except KeyError: - # In at least one test (test/inductor/test_torchbind.py) we - # create a StarDep that doesn't exist in the graph and calling - # `has_unbacked_symbols()` throws an error. - pass - return res + return V.graph.get_dep_size_hint(dep) # get freeable input buffers' successor nodes and their sizes # note that different deps can have the same name, so we use name as keys diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 5c7a16d25bc64..34f15869085f0 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -2051,15 +2051,12 @@ class Scheduler: optimizations such as fusion, reorder, and graph partition. """ - __dep_size_hint_cache: dict[Dep, int] - def __init__(self, nodes: list[ir.Operation]) -> None: with dynamo_timed("Scheduler.__init__"): self._init(nodes) def _init(self, nodes: list[ir.Operation]) -> None: super().__init__() - self.__dep_size_hint_cache = {} V.graph.scheduler = self self.backends: dict[torch.device, BaseScheduling] = {} self.post_grad_graph_id = next(_post_grad_graph_counter) @@ -3505,6 +3502,17 @@ def _find_single_user_inputs( return True return False + def fusion_accumulate_large_reads( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + all_reads = (node1.read_writes.reads | node2.read_writes.reads) - ( + node1.read_writes.writes | node2.read_writes.writes + ) + return ( + sum(self.dep_size_hint(dep) for dep in all_reads) + > config.realize_acc_reads_size_threshold + ) + def are_long_distant_nodes( self, node1: BaseSchedulerNode, node2: BaseSchedulerNode ) -> bool: @@ -4010,20 +4018,7 @@ def fusable_read_and_write(self, read: Dep, write: MemoryDep) -> bool: return False def dep_size_hint(self, dep: Dep) -> int: - res = 0 - if dep not in self.__dep_size_hint_cache: - try: - if not dep.has_unbacked_symbols(): - res = dep.numbytes_hint() - except KeyError: - # In at least one test (test/inductor/test_torchbind.py) we - # create a StarDep that doesn't exist in the graph and calling - # `has_unbacked_symbols()` throws an error. - pass - self.__dep_size_hint_cache[dep] = res - else: - res = self.__dep_size_hint_cache[dep] - return res + return V.graph.get_dep_size_hint(dep) def score_fusion_memory( self, node1: BaseSchedulerNode, node2: BaseSchedulerNode From 38371f693b07a485705119407da2e5dc64cec4eb Mon Sep 17 00:00:00 2001 From: Eli Uriegas Date: Mon, 14 Jul 2025 15:47:43 -0700 Subject: [PATCH 020/457] ci: Switch lintrunner-noclang to use linter image (#158261) This changes the image the lintrunner jobs utilizes to be the base linter image instead of the CUDA image. This is done to reduce the image size and speed up the build time. This was switched in https://github.com/pytorch/pytorch/pull/110502 when clang used to run in the lintrunner jobs but it is now split out so we can use the default image for non-clang jobs. Difference in pull time (from running job): ~5min --> ~1min (80% reduction), this should result in an overall runtime decrease of ~25min --> ~20min (20% reduction) Signed-off-by: Eli Uriegas Pull Request resolved: https://github.com/pytorch/pytorch/pull/158261 Approved by: https://github.com/Camyll, https://github.com/ZainRizvi, https://github.com/atalman, https://github.com/Skylion007 --- .ci/docker/linter/Dockerfile | 2 ++ .github/workflows/lint.yml | 3 +-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.ci/docker/linter/Dockerfile b/.ci/docker/linter/Dockerfile index 0fdfac678d40f..95d08ffea051d 100644 --- a/.ci/docker/linter/Dockerfile +++ b/.ci/docker/linter/Dockerfile @@ -27,5 +27,7 @@ COPY ./common/install_linter.sh install_linter.sh RUN bash ./install_linter.sh RUN rm install_linter.sh +RUN chown -R jenkins:jenkins /var/lib/jenkins/ci_env + USER jenkins CMD ["bash"] diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index d0a2fda509ef3..0fca34048196a 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -26,7 +26,6 @@ jobs: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} curr_branch: ${{ github.head_ref || github.ref_name }} - lintrunner-clang: uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main needs: get-label-type @@ -50,7 +49,7 @@ jobs: with: timeout: 120 runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" - docker-image: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3.9-linter + docker-image: ci-image:pytorch-linux-jammy-linter # NB: A shallow checkout won't work here because calculate-docker-image requires a full checkout # to run git rev-parse HEAD~:.ci/docker when a new image is needed fetch-depth: 0 From 48315181c75e43cab5957197d42e053d66b3fe1c Mon Sep 17 00:00:00 2001 From: Catherine Lee Date: Mon, 14 Jul 2025 23:07:45 +0000 Subject: [PATCH 021/457] [CI] Do not run inductor rocm on ciflow/inductor (#158162) Petition to only run inductor-rocm on ciflow/inductor-rocm and not ciflow/inductor because it's a long pole for TTS image Pull Request resolved: https://github.com/pytorch/pytorch/pull/158162 Approved by: https://github.com/seemethere --- .github/workflows/inductor-rocm.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/inductor-rocm.yml b/.github/workflows/inductor-rocm.yml index 4241854aa3278..b1bb7972d67de 100644 --- a/.github/workflows/inductor-rocm.yml +++ b/.github/workflows/inductor-rocm.yml @@ -7,7 +7,6 @@ on: - release/* tags: - ciflow/inductor-rocm/* - - ciflow/inductor/* workflow_dispatch: concurrency: From 08799217aeb17128d89d675ce5b537761286417a Mon Sep 17 00:00:00 2001 From: Catherine Lee Date: Mon, 14 Jul 2025 23:07:46 +0000 Subject: [PATCH 022/457] [CI] Move main branch rocm binary builds to its own workflow (#158161) Petition to move out of ciflow/trunk and into ciflow/rocm because it's a long pole for TTS image Pull Request resolved: https://github.com/pytorch/pytorch/pull/158161 Approved by: https://github.com/seemethere --- .github/scripts/generate_ci_workflows.py | 30 +++- .../generated-linux-binary-manywheel-main.yml | 92 ------------ ...rated-linux-binary-manywheel-rocm-main.yml | 137 ++++++++++++++++++ 3 files changed, 166 insertions(+), 93 deletions(-) create mode 100644 .github/workflows/generated-linux-binary-manywheel-rocm-main.yml diff --git a/.github/scripts/generate_ci_workflows.py b/.github/scripts/generate_ci_workflows.py index 55cb02504ea45..4df6150f97655 100755 --- a/.github/scripts/generate_ci_workflows.py +++ b/.github/scripts/generate_ci_workflows.py @@ -22,6 +22,7 @@ LABEL_CIFLOW_PERIODIC = "ciflow/periodic" LABEL_CIFLOW_BINARIES_LIBTORCH = "ciflow/binaries_libtorch" LABEL_CIFLOW_BINARIES_WHEEL = "ciflow/binaries_wheel" +LABEL_CIFLOW_ROCM = "ciflow/rocm" @dataclass @@ -146,13 +147,35 @@ class OperatingSystem: ), ] +ROCM_SMOKE_WORKFLOWS = [ + BinaryBuildWorkflow( + os=OperatingSystem.LINUX, + package_type="manywheel", + build_variant="rocm", + build_configs=generate_binary_build_matrix.generate_wheels_matrix( + OperatingSystem.LINUX, + arches=["6.4"], + python_versions=["3.9"], + ), + ciflow_config=CIFlowConfig( + labels={ + LABEL_CIFLOW_BINARIES, + LABEL_CIFLOW_BINARIES_WHEEL, + LABEL_CIFLOW_ROCM, + }, + isolated_workflow=True, + ), + branches="main", + ), +] + LINUX_BINARY_SMOKE_WORKFLOWS = [ BinaryBuildWorkflow( os=OperatingSystem.LINUX, package_type="manywheel", build_configs=generate_binary_build_matrix.generate_wheels_matrix( OperatingSystem.LINUX, - arches=["12.6", "12.8", "12.9", "6.4"], + arches=["12.6", "12.8", "12.9"], python_versions=["3.9"], ), branches="main", @@ -387,6 +410,11 @@ def main() -> None: jinja_env.get_template("linux_binary_build_workflow.yml.j2"), S390X_BINARY_BUILD_WORKFLOWS, ), + ( + # Give rocm it's own workflow file + jinja_env.get_template("linux_binary_build_workflow.yml.j2"), + ROCM_SMOKE_WORKFLOWS, + ), ( jinja_env.get_template("linux_binary_build_workflow.yml.j2"), LINUX_BINARY_SMOKE_WORKFLOWS, diff --git a/.github/workflows/generated-linux-binary-manywheel-main.yml b/.github/workflows/generated-linux-binary-manywheel-main.yml index 8e27aca1150b6..d1e89bb6e2d85 100644 --- a/.github/workflows/generated-linux-binary-manywheel-main.yml +++ b/.github/workflows/generated-linux-binary-manywheel-main.yml @@ -182,95 +182,3 @@ jobs: runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 and 12.9 build need sm_70+ runner secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - - manywheel-py3_9-rocm6_4-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm6.4 - GPU_ARCH_VERSION: 6.4 - GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: rocm6.4 - use_split_build: False - DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_9-rocm6_4 - build_environment: linux-binary-manywheel - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-rocm6_4-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_9-rocm6_4-build - - get-label-type - runs-on: linux.rocm.gpu.mi250 - timeout-minutes: 240 - env: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm6.4 - GPU_ARCH_VERSION: 6.4 - GPU_ARCH_TYPE: rocm - SKIP_ALL_TESTS: 1 - DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: rocm6.4 - use_split_build: False - DESIRED_PYTHON: "3.9" - steps: - - name: Setup ROCm - uses: ./.github/actions/setup-rocm - - uses: actions/download-artifact@v4.1.7 - name: Download Build Artifacts - with: - name: manywheel-py3_9-rocm6_4 - path: "${{ runner.temp }}/artifacts/" - - name: Checkout PyTorch - uses: actions/checkout@v4 - with: - ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} - submodules: recursive - path: pytorch - show-progress: false - - name: Clean PyTorch checkout - run: | - # Remove any artifacts from the previous checkouts - git clean -fxd - working-directory: pytorch - - name: ROCm set GPU_FLAG - run: | - echo "GPU_FLAG=--device=/dev/mem --device=/dev/kfd --device=/dev/dri --group-add video --group-add daemon" >> "${GITHUB_ENV}" - - name: configure aws credentials - id: aws_creds - if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') }} - uses: aws-actions/configure-aws-credentials@v4 - with: - role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only - aws-region: us-east-1 - role-duration-seconds: 18000 - - name: Calculate docker image - id: calculate-docker-image - uses: pytorch/test-infra/.github/actions/calculate-docker-image@main - with: - docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }} - docker-image-name: manylinux2_28-builder - custom-tag-prefix: rocm6.4 - docker-build-dir: .ci/docker - working-directory: pytorch - - name: Pull Docker image - uses: pytorch/test-infra/.github/actions/pull-docker-image@main - with: - docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} - - name: Test Pytorch binary - uses: ./pytorch/.github/actions/test-pytorch-binary - env: - DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }} - - name: Teardown ROCm - uses: ./.github/actions/teardown-rocm diff --git a/.github/workflows/generated-linux-binary-manywheel-rocm-main.yml b/.github/workflows/generated-linux-binary-manywheel-rocm-main.yml new file mode 100644 index 0000000000000..b6b63c4e38d5e --- /dev/null +++ b/.github/workflows/generated-linux-binary-manywheel-rocm-main.yml @@ -0,0 +1,137 @@ +# @generated DO NOT EDIT MANUALLY + +# Template is at: .github/templates/linux_binary_build_workflow.yml.j2 +# Generation script: .github/scripts/generate_ci_workflows.py +name: linux-binary-manywheel-rocm + + +on: + push: + branches: + - main + tags: + - 'ciflow/binaries/*' + - 'ciflow/binaries_wheel/*' + - 'ciflow/rocm/*' + workflow_dispatch: + +permissions: + id-token: write + +env: + # Needed for conda builds + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + AWS_DEFAULT_REGION: us-east-1 + BINARY_ENV_FILE: /tmp/env + BUILD_ENVIRONMENT: linux-binary-manywheel-rocm + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + PR_NUMBER: ${{ github.event.pull_request.number }} + PYTORCH_FINAL_PACKAGE_DIR: /artifacts + PYTORCH_ROOT: /pytorch + SHA1: ${{ github.event.pull_request.head.sha || github.sha }} + SKIP_ALL_TESTS: 0 +concurrency: + group: linux-binary-manywheel-rocm-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }} + cancel-in-progress: true + +jobs: + get-label-type: + if: github.repository_owner == 'pytorch' + name: get-label-type + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + manywheel-py3_9-rocm6_4-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: rocm6.4 + GPU_ARCH_VERSION: 6.4 + GPU_ARCH_TYPE: rocm + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: rocm6.4 + use_split_build: False + DESIRED_PYTHON: "3.9" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_9-rocm6_4 + build_environment: linux-binary-manywheel-rocm + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_9-rocm6_4-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_9-rocm6_4-build + - get-label-type + runs-on: linux.rocm.gpu.mi250 + timeout-minutes: 240 + env: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: rocm6.4 + GPU_ARCH_VERSION: 6.4 + GPU_ARCH_TYPE: rocm + SKIP_ALL_TESTS: 1 + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: rocm6.4 + use_split_build: False + DESIRED_PYTHON: "3.9" + steps: + - name: Setup ROCm + uses: ./.github/actions/setup-rocm + - uses: actions/download-artifact@v4.1.7 + name: Download Build Artifacts + with: + name: manywheel-py3_9-rocm6_4 + path: "${{ runner.temp }}/artifacts/" + - name: Checkout PyTorch + uses: actions/checkout@v4 + with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + submodules: recursive + path: pytorch + show-progress: false + - name: Clean PyTorch checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: pytorch + - name: ROCm set GPU_FLAG + run: | + echo "GPU_FLAG=--device=/dev/mem --device=/dev/kfd --device=/dev/dri --group-add video --group-add daemon" >> "${GITHUB_ENV}" + - name: configure aws credentials + id: aws_creds + if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') }} + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only + aws-region: us-east-1 + role-duration-seconds: 18000 + - name: Calculate docker image + id: calculate-docker-image + uses: pytorch/test-infra/.github/actions/calculate-docker-image@main + with: + docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }} + docker-image-name: manylinux2_28-builder + custom-tag-prefix: rocm6.4 + docker-build-dir: .ci/docker + working-directory: pytorch + - name: Pull Docker image + uses: pytorch/test-infra/.github/actions/pull-docker-image@main + with: + docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} + - name: Test Pytorch binary + uses: ./pytorch/.github/actions/test-pytorch-binary + env: + DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }} + - name: Teardown ROCm + uses: ./.github/actions/teardown-rocm From 7cf31b4a426f3791af30159cea420687f347cd7a Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Sat, 12 Jul 2025 10:23:42 -0700 Subject: [PATCH 023/457] [dynamo] fix NamedTupleVariable cloning (#158190) FIXES https://github.com/pytorch/pytorch/issues/157945 ## Explanation 1. Some VTs add additional attrs e.g. NamedTupleVariable has "dynamic_attributes" https://github.com/pytorch/pytorch/blob/a0308edb6cdfd8983e80a499890d9f320556e844/torch/_dynamo/variables/lists.py#L1048-L1051 2. VT.clone passes everything by dict, includes "dynamic_attributes" https://github.com/pytorch/pytorch/blob/a0308edb6cdfd8983e80a499890d9f320556e844/torch/_dynamo/variables/base.py#L255-L259 3. Non-handled args become kwargs in VT's `__init__`, `super().__init__()` passes kwargs to Base VT https://github.com/pytorch/pytorch/blob/a0308edb6cdfd8983e80a499890d9f320556e844/torch/_dynamo/variables/lists.py#L1048-L1051 4. Base VT's `__init__` gets unexpected "dynamic_attributes" kwarg https://github.com/pytorch/pytorch/blob/a0308edb6cdfd8983e80a499890d9f320556e844/torch/_dynamo/variables/base.py#L609-L613 You could also let Base VT's `__init__` ignore additional kwargs, but that seemed a bit too permissive, and I don't think many VT's add these derived class only attrs. ## After fix ```python ===== __compiled_fn_1_7f9541ed_e166_43fe_8322_c5225ce4207f ===== /home/xmfan/core/miniconda3/envs/0712/lib/python3.12/site-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module): def forward(self, L_x_: "f32[4, 8, 6][48, 6, 1]cpu"): l_x_ = L_x_ # File: /home/xmfan/core/a/torchtitan/wtf.py:10 in forward, code: U, S = torch.linalg.svd(x)[:2] linalg_svd = torch._C._linalg.linalg_svd(l_x_); l_x_ = None U: "f32[4, 8, 8][64, 1, 8]cpu" = linalg_svd[0] S: "f32[4, 6][6, 1]cpu" = linalg_svd[1]; linalg_svd = None # File: /home/xmfan/core/a/torchtitan/wtf.py:11 in forward, code: reduced = U[:, :, :self.k] @ torch.diag_embed(S[:, :self.k]) getitem_3: "f32[4, 8, 5][64, 1, 8]cpu" = U[(slice(None, None, None), slice(None, None, None), slice(None, 5, None))]; U = None getitem_4: "f32[4, 5][6, 1]cpu" = S[(slice(None, None, None), slice(None, 5, None))]; S = None diag_embed: "f32[4, 5, 5][25, 5, 1]cpu" = torch.diag_embed(getitem_4); getitem_4 = None reduced: "f32[4, 8, 5][40, 5, 1]cpu" = getitem_3 @ diag_embed; getitem_3 = diag_embed = None return (reduced,) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/158190 Approved by: https://github.com/StrongerXi --- test/dynamo/test_repros.py | 19 +++++++++++++++++++ torch/_dynamo/variables/lists.py | 4 ++-- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 89202b9037e5e..8636b3496d1b5 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -7569,6 +7569,25 @@ def f(x): with mock.patch("torch.cuda.is_initialized", lambda: False): self.assertEqual(f(inp), inp + 2) + def test_named_tuple_vt_clone(self): + # https://github.com/pytorch/pytorch/issues/157945 + class SVDCompressor(nn.Module): + def __init__(self, k=10): + super().__init__() + self.k = k + + def forward(self, x): + U, S = torch.linalg.svd(x)[:2] + reduced = U[:, :, : self.k] @ torch.diag_embed(S[:, : self.k]) + return reduced + + input = torch.randn(4, 8, 6) + model = SVDCompressor(k=5) + + out1 = model(input.clone()) + out2 = torch.compile(model, backend="eager")(input.clone()) + self.assertEqual(out1, out2) + instantiate_parametrized_tests(ReproTests) diff --git a/torch/_dynamo/variables/lists.py b/torch/_dynamo/variables/lists.py index 93547c79e9564..55891ce1de243 100644 --- a/torch/_dynamo/variables/lists.py +++ b/torch/_dynamo/variables/lists.py @@ -1045,10 +1045,10 @@ class NamedTupleVariable(TupleVariable): *TupleVariable._nonvar_fields, } - def __init__(self, items, tuple_cls, **kwargs) -> None: + def __init__(self, items, tuple_cls, dynamic_attributes=None, **kwargs) -> None: super().__init__(items, **kwargs) self.tuple_cls = tuple_cls - self.dynamic_attributes = {} + self.dynamic_attributes = {} if not dynamic_attributes else dynamic_attributes def is_namedtuple(self): return isinstance(getattr(self.tuple_cls, "_fields", None), tuple) and callable( From b7def5ff1ca72fbb06350ffc75117efc68e149fb Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Tue, 15 Jul 2025 00:02:50 +0000 Subject: [PATCH 024/457] dist2: add support for passing custom configs directly to PG (#158147) This is intended to make it easier to have backend specific "hints" that can be provided by the user to hint about certain options. ```py import torch.distributed._dist2 as dist2 pg = dist2.new_group(backend="my_custom_backend", device=..., timeout=..., foo=1234, bar="1234") pg.allreduce(...) ``` Test plan: ``` pytest test/distributed/test_dist2.py ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/158147 Approved by: https://github.com/fduwjj --- test/distributed/test_dist2.py | 14 ++++++-------- torch/distributed/_dist2.py | 27 +++++++++++++++------------ 2 files changed, 21 insertions(+), 20 deletions(-) diff --git a/test/distributed/test_dist2.py b/test/distributed/test_dist2.py index d5e925b4b2d0c..b4d6c6d02b35f 100644 --- a/test/distributed/test_dist2.py +++ b/test/distributed/test_dist2.py @@ -28,10 +28,14 @@ def test_context_manager(self): os.environ["MASTER_PORT"] = "29500" pg1 = dist2.new_group( - backend="gloo", timeout=timedelta(seconds=60), device="cpu", pg_options=None + backend="gloo", + timeout=timedelta(seconds=60), + device="cpu", ) pg2 = dist2.new_group( - backend="gloo", timeout=timedelta(seconds=60), device="cpu", pg_options=None + backend="gloo", + timeout=timedelta(seconds=60), + device="cpu", ) self.assertIsNone(dist2.current_process_group()) @@ -227,7 +231,6 @@ def new_group(self) -> torch.distributed.ProcessGroup: backend="gloo", timeout=timedelta(seconds=60), device=self.device, - pg_options=None, ) @@ -242,15 +245,10 @@ def new_group(self) -> torch.distributed.ProcessGroup: self.device = torch.device("cuda", self.rank) - from torch.distributed import ProcessGroupNCCL - - opts = ProcessGroupNCCL.Options() - return dist2.new_group( backend="nccl", timeout=timedelta(seconds=60), device=self.device, - pg_options=opts, ) diff --git a/torch/distributed/_dist2.py b/torch/distributed/_dist2.py index 8704b9015997d..7ef73a27a6dd3 100644 --- a/torch/distributed/_dist2.py +++ b/torch/distributed/_dist2.py @@ -16,7 +16,6 @@ from torch._C._distributed_c10d import ( _current_process_group, _set_process_group, - Backend, ProcessGroup, ReduceOp, Store, @@ -47,7 +46,7 @@ def __call__( world_size: int, timeout: timedelta, device: torch.device, - pg_options: Backend.Options, + **kwargs: object, ) -> ProcessGroup: ... @@ -71,11 +70,11 @@ def _gloo_factory( world_size: int, timeout: timedelta, device: torch.device, - pg_options: Backend.Options, + **kwargs: object, ) -> ProcessGroup: from torch.distributed import ProcessGroupGloo - assert pg_options is None, "Gloo backend does not support options" + assert len(kwargs) == 0, "Gloo backend received unexpected kwargs" backend_class = ProcessGroupGloo(store, rank, world_size, timeout) backend_class._set_sequence_number_for_group() @@ -101,15 +100,18 @@ def _nccl_factory( world_size: int, timeout: timedelta, device: torch.device, - pg_options: Backend.Options, + **kwargs: object, ) -> ProcessGroup: from torch.distributed import ProcessGroupNCCL - assert isinstance(pg_options, ProcessGroupNCCL.Options) + opts = ProcessGroupNCCL.Options() + opts._timeout = timeout + for k, v in kwargs.items(): + if not hasattr(opts, k): + raise KeyError(f"Unknown option {k}") + setattr(opts, k, v) - pg_options._timeout = timeout - - backend_class = ProcessGroupNCCL(store, rank, world_size, pg_options) + backend_class = ProcessGroupNCCL(store, rank, world_size, opts) backend_class._set_sequence_number_for_group() backend_class.eager_connect_single_device(device) @@ -128,7 +130,7 @@ def new_group( backend: str, timeout: timedelta, device: Union[str, torch.device], - pg_options: Backend.Options, + **kwargs: object, ) -> ProcessGroup: """ Create a new process group with the given backend and options. This group is @@ -139,7 +141,8 @@ def new_group( backend: The backend to use for the process group. timeout: The timeout for collective operations. device: The device to use for the process group. - pg_options: The options to use for the process group. + **kwargs: All remaining arguments are passed to the backend constructor. + See the backend specific documentation for details. Returns: A new process group. @@ -152,7 +155,7 @@ def new_group( store, rank, world_size = next(iter(rendezvous("env://"))) store.set_timeout(timeout) - return _BACKENDS[backend](store, rank, world_size, timeout, device, pg_options) + return _BACKENDS[backend](store, rank, world_size, timeout, device, **kwargs) def current_process_group() -> ProcessGroup: From 7e433d5f423248914c5e9838d3ea145db7964923 Mon Sep 17 00:00:00 2001 From: henrylhtsang Date: Mon, 14 Jul 2025 09:28:26 -0700 Subject: [PATCH 025/457] [cutlass backend] cache a few things for codegen and properties (#158158) Differential Revision: [D78193404](https://our.internmc.facebook.com/intern/diff/D78193404/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158158 Approved by: https://github.com/ColinPeppler --- torch/_inductor/codecache.py | 1 + torch/_inductor/codegen/cuda/cuda_kernel.py | 14 ++++++++++++-- torch/_inductor/codegen/cuda/cuda_template.py | 5 +++++ 3 files changed, 18 insertions(+), 2 deletions(-) diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 6ade10e63163c..78c47dcb082f6 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -3757,6 +3757,7 @@ def get_kernel_binary_remote_cache( return None @classmethod + @lru_cache(None) def write(cls, source_code: str, dst_file_ext: str) -> tuple[str, str]: """ Writes source code into a file with dst_file_ext as the file extension. diff --git a/torch/_inductor/codegen/cuda/cuda_kernel.py b/torch/_inductor/codegen/cuda/cuda_kernel.py index f419ada67e1a3..224f0d2a423dc 100644 --- a/torch/_inductor/codegen/cuda/cuda_kernel.py +++ b/torch/_inductor/codegen/cuda/cuda_kernel.py @@ -631,16 +631,26 @@ def hash_key(self) -> str: """ Return kernel hash key that does not depend on swizzle. """ + swizzle_str: str = ( + str(self.info_kwargs.get("swizzle")) + if isinstance(self.info_kwargs, dict) + else "None" + ) return "-".join( [ self.category, self.bmreq.hash_key, - str(self.info_dict().get("swizzle")), + swizzle_str, ] ) def info_dict(self) -> dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]: - """Information returned here is logged to the autotune log file when that is enabled.""" + """ + Information returned here is logged to the autotune log file when that is enabled. + + In general, we should avoid calling this function as it is expensive to compute, + and can add up very fast. + """ if self.info_kwargs is not None and "op" in self.info_kwargs: op: Any = self.info_kwargs["op"] return { diff --git a/torch/_inductor/codegen/cuda/cuda_template.py b/torch/_inductor/codegen/cuda/cuda_template.py index 7ed67b0daa49f..07ee9f127580f 100644 --- a/torch/_inductor/codegen/cuda/cuda_template.py +++ b/torch/_inductor/codegen/cuda/cuda_template.py @@ -65,6 +65,11 @@ def __init__( self.input_reorder = input_reorder self.layout = layout + @classmethod + @functools.lru_cache(None) + def _template_from_string(cls, source: str) -> Any: + return KernelTemplate._template_from_string(source) + @staticmethod def supports_epilogue_fusion(op: GemmOperation) -> bool: return False From 1c6057fd179b0373686a790b0a0b7fc68fe7f27d Mon Sep 17 00:00:00 2001 From: Songhao Jia Date: Tue, 15 Jul 2025 00:50:03 +0000 Subject: [PATCH 026/457] add eq function to NodeSource (#158170) Summary: add eq function to NodeSouce by comparing their dict representation. Test Plan: ci Rollback Plan: Differential Revision: D78200762 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158170 Approved by: https://github.com/ezyang, https://github.com/yushangdi --- test/fx/test_fx_traceback.py | 56 ++++++++++++++++++++++++++++++++++++ torch/fx/traceback.py | 5 ++++ 2 files changed, 61 insertions(+) diff --git a/test/fx/test_fx_traceback.py b/test/fx/test_fx_traceback.py index 6306daa571bd0..e11ee19daaac4 100644 --- a/test/fx/test_fx_traceback.py +++ b/test/fx/test_fx_traceback.py @@ -2,6 +2,7 @@ import torch from torch._inductor.compile_fx import aot_export_module +from torch.export import default_decompositions from torch.fx.traceback import get_graph_provenance_json, NodeSource, NodeSourceAction from torch.testing._internal.common_utils import TestCase @@ -64,6 +65,24 @@ def test_node_source(self): }, ) + # Test two node sources are same + node_source1 = NodeSource( + node=None, pass_name="test_pass", action=NodeSourceAction.CREATE + ) + node_source2 = NodeSource( + node=None, pass_name="test_pass", action=NodeSourceAction.CREATE + ) + self.assertEqual(node_source1, node_source2) + + # Test two node sources are not same + node_source3 = NodeSource( + node=None, pass_name="test_pass_1", action=NodeSourceAction.CREATE + ) + node_source4 = NodeSource( + node=None, pass_name="test_pass_2", action=NodeSourceAction.CREATE + ) + self.assertNotEqual(node_source3, node_source4) + def test_graph_provenance(self): def check_node_source(node_source_dict, name, pass_name, action): self.assertEqual(node_source_dict["name"], name) @@ -95,6 +114,43 @@ def forward(self, x): model = Model() example_inputs = (torch.randn(8, 10),) ep = torch.export.export(model, example_inputs, strict=True) + + decomposed_ep = ep.run_decompositions(default_decompositions()) + # node decomposed from same ancestor node should have same from_node info + for node in decomposed_ep.graph.nodes: + if node.op not in {"placeholder", "output"}: + assert "from_node" in node.meta + + node_name_to_from_node = { + node.name: node.meta["from_node"] + for node in decomposed_ep.graph.nodes + if node.op not in {"placeholder", "output"} + } + same_ancestor_nodes = { + "permute": "addmm", + "addmm": "permute", + "permute_1": "addmm_1", + "addmm_1": "permute_1", + } + + for node_name_1 in node_name_to_from_node: + for node_name_2 in node_name_to_from_node: + if node_name_2 in { + node_name_1, + same_ancestor_nodes[node_name_1] + if node_name_1 in same_ancestor_nodes + else None, + }: + self.assertTrue( + node_name_to_from_node[node_name_1] + == node_name_to_from_node[node_name_2] + ) + else: + self.assertTrue( + node_name_to_from_node[node_name_1] + != node_name_to_from_node[node_name_2] + ) + gm = ep.module() provenance = get_graph_provenance_json(gm.graph) self.assertEqual( diff --git a/torch/fx/traceback.py b/torch/fx/traceback.py index 3ec156005a015..9f316191a2302 100644 --- a/torch/fx/traceback.py +++ b/torch/fx/traceback.py @@ -123,6 +123,11 @@ def to_dict(self) -> dict: "from_node": [node.to_dict() for node in self.from_node], } + def __eq__(self, other: object): + if not isinstance(other, NodeSource): + return False + return self.to_dict() == other.to_dict() + @compatibility(is_backward_compatible=False) @contextmanager From ef4cca2d79eba61441da46906b30f8f6165cc455 Mon Sep 17 00:00:00 2001 From: James Wu Date: Sun, 13 Jul 2025 12:41:15 -0700 Subject: [PATCH 027/457] [precompile] Increment frame and add compile ids when loading packages (#158028) When loading a package and calling package.install(backends), we create a new frame and compile id for each package load, so that tlparse and chromium events still show compile times on warm start. There is an argument for not doing this in AOT precompile, as no "compile" occurs. So for now, we put it in `package.install`, which hopefully won't be a thing for AOT precompile. ## Recompiles Recompiles get saved to the same frame and code entry, so on warm start, each recompile will get collapsed into the same entry. Therefore, dynamo compiles that have recompiles on cold start (0/0, 0/1, 0/2, etc) will all get collapsed into a single compile id (0/0), as warm start will load all of the entries properly. ## Graph breaks Graph breaks get their own compile id, and therefore their own code entry. These are replicated on warm start, so if cold start you had 4 different graphs (and therefore 4 compile ids), you'll have 4 compile ids on warm start as well. ## Test plan Added a frame counter check to existing unit tests for automatic dynamic, showing that old and new frame counter between old and new load is the same. This is the chromium event for test_automatic_dynamo_graph_breaks_device_cuda: ``` python test/dynamo/test_package.py -k test_automatic_dynamo_graph_breaks_device_cuda ``` image Pull Request resolved: https://github.com/pytorch/pytorch/pull/158028 Approved by: https://github.com/oulgen --- test/dynamo/test_package.py | 19 ++++- torch/_dynamo/convert_frame.py | 43 ++++++----- torch/_dynamo/package.py | 127 ++++++++++++++++++++------------- 3 files changed, 120 insertions(+), 69 deletions(-) diff --git a/test/dynamo/test_package.py b/test/dynamo/test_package.py index d75ea975cb741..3160007774090 100644 --- a/test/dynamo/test_package.py +++ b/test/dynamo/test_package.py @@ -38,7 +38,9 @@ def setUp(self): DynamoCache.clear() PrecompileContext.clear() - def _save_and_reload(self, expected_backends, expected_dynamo): + def _save_and_reload( + self, expected_backends, expected_dynamo, expected_autotune=None + ): """ Serializes all artifacts, clears all caches, then reloads the serialized artifact Simulates a new process. @@ -54,6 +56,8 @@ def _save_and_reload(self, expected_backends, expected_dynamo): len(cache_info.precompile_aot_autograd_artifacts), expected_backends ) self.assertEqual(len(cache_info.precompile_dynamo_artifacts), expected_dynamo) + if expected_autotune is not None: + self.assertEqual(len(cache_info.autotune_artifacts), expected_autotune) torch._dynamo.reset() DynamoCache.clear() @@ -377,7 +381,7 @@ def fn2(x): DynamoCache.save(package1) DynamoCache.save(package2) - + total_frames = torch._dynamo.convert_frame.FRAME_COUNTER self._save_and_reload(expected_backends=2, expected_dynamo=2) # These should exist because of populate_caches @@ -388,6 +392,7 @@ def fn2(x): result1 = compiled_fn1(arg1) result2 = compiled_fn2(arg2) self.assertEqual(expected, [result1, result2]) + self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames) @parametrize("device", ("cpu", "cuda", "xpu")) @torch._dynamo.config.patch(caching_precompile=True) @@ -411,6 +416,7 @@ def fn2(x): result = [compiled_fn1(arg1), compiled_fn2(arg2)] self.assertEqual(expected, result) DynamoCache.clear() + total_frames = torch._dynamo.convert_frame.FRAME_COUNTER self._save_and_reload(expected_backends=2, expected_dynamo=2) @@ -420,6 +426,7 @@ def fn2(x): result1 = compiled_fn1(arg1) result2 = compiled_fn2(arg2) self.assertEqual(expected, [result1, result2]) + self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames) @parametrize("device", ("cpu", "cuda", "xpu")) @torch._dynamo.config.patch(caching_precompile=True) @@ -439,6 +446,7 @@ def fn(x): # Should cause a recompile expected2 = compiled_fn(arg2) + total_frames = torch._dynamo.convert_frame.FRAME_COUNTER self._save_and_reload(expected_backends=2, expected_dynamo=1) @@ -451,6 +459,7 @@ def fn(x): compiled_fn(arg3) self.assertEqual(result1, expected1) self.assertEqual(result2, expected2) + self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames) @parametrize("device", ("cpu", "cuda", "xpu")) @torch._dynamo.config.patch(caching_precompile=True) @@ -486,6 +495,7 @@ def guard_filter_fn(guards): for args in args_list: compiled_fn(*args) + total_frames = torch._dynamo.convert_frame.FRAME_COUNTER self._save_and_reload(expected_backends=8, expected_dynamo=1) compiled_fn = torch._dynamo.optimize( @@ -494,6 +504,8 @@ def guard_filter_fn(guards): with torch.compiler.set_stance("fail_on_recompile"): for args in args_list: self.assertEqual(compiled_fn(*args), args[0].sum()) + # Should have same number of frames as on cold start + self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames) @parametrize("device", ("cpu", "cuda", "xpu")) @torch._dynamo.config.patch(caching_precompile=True) @@ -512,6 +524,7 @@ def fn(x): compiled_fn = torch.compile(fn) expected1 = compiled_fn(arg1) expected1.sum().backward() + total_frames = torch._dynamo.convert_frame.FRAME_COUNTER self._save_and_reload(expected_backends=1, expected_dynamo=1) @@ -521,6 +534,8 @@ def fn(x): expected2 = compiled_fn(arg2) expected2.sum().backward() + self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index fe547691add68..8fe9e3aaf13a6 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -495,6 +495,29 @@ def _is_error_on_graph_break(tx: Optional[InstructionTranslator]) -> bool: return tx.error_on_graph_break +def get_compile_id( + frame_state: dict[str, Union[int, FrameStateSizeEntry]], +) -> CompileId: + global FRAME_COUNTER + if "_id" not in frame_state: + frame_state["_id"] = FRAME_COUNTER + FRAME_COUNTER += 1 + frame_id = frame_state["_id"] + assert isinstance(frame_id, int) + + frame_compile_id = FRAME_COMPILE_COUNTER[frame_id] + FRAME_COMPILE_COUNTER[frame_id] += 1 + + compiled_autograd_id = None + if prior := CompileContext.current_compile_id(): + compiled_autograd_id = prior.compiled_autograd_id + return CompileId( + compiled_autograd_id=compiled_autograd_id, + frame_id=frame_id, + frame_compile_id=frame_compile_id, + ) + + class ConvertFrameAssert: def __init__( self, @@ -610,24 +633,8 @@ def __call__( global initial_global_state initial_global_state = GlobalStateGuard() - global FRAME_COUNTER - if "_id" not in frame_state: - frame_state["_id"] = FRAME_COUNTER - FRAME_COUNTER += 1 - frame_id = frame_state["_id"] - assert isinstance(frame_id, int) - - frame_compile_id = FRAME_COMPILE_COUNTER[frame_id] - FRAME_COMPILE_COUNTER[frame_id] += 1 - - compiled_autograd_id = None - if prior := CompileContext.current_compile_id(): - compiled_autograd_id = prior.compiled_autograd_id - compile_id = CompileId( - compiled_autograd_id=compiled_autograd_id, - frame_id=frame_id, - frame_compile_id=frame_compile_id, - ) + compile_id = get_compile_id(frame_state) + frame_id = compile_id.frame_id signpost_event( "dynamo", diff --git a/torch/_dynamo/package.py b/torch/_dynamo/package.py index 5bf63c2544cde..2a33a019b6cca 100644 --- a/torch/_dynamo/package.py +++ b/torch/_dynamo/package.py @@ -32,6 +32,7 @@ from torch.compiler._cache import CacheArtifactFactory from .bytecode_transformation import get_code_keys +from .utils import dynamo_timed, increment_frame logger = logging.getLogger(__name__) @@ -379,60 +380,87 @@ def install(self, backends: dict[_BackendId, Any]) -> None: 3. Install the precompiled cache entries to ExtraStates on the code object. """ from torch._C._dynamo.eval_frame import _load_precompile_entry + from torch._dynamo.convert_frame import get_compile_id + from torch._guards import compile_context, CompileContext from .output_graph import get_builtins_dict self.uninstall() - for code, entry in self._codes.items(): - module = sys.modules[entry.python_module] - for alias, module_name in entry.import_sources.items(): - self._install_global( - module, alias, importlib.import_module(module_name) - ) - for function_name in entry.function_names: - fn = types.FunctionType(code, module.__dict__, function_name) - self._install_global(module, function_name, fn) - for backend_id in entry.backend_ids: - if backend_id not in backends: - raise RuntimeError( - f"Backend {backend_id} is not found in the given backends" + # Each code represents a new compile frame + # recompiles on the same frame are all saved + # under the same cache entry, so we don't have recompile ids + # i.e. If cold start had 0/0, 0/1, 1/0, 1/1, these would be + # collapsed into 0/0, 1/0 on warm. + increment_frame() + compile_id = get_compile_id(frame_state={}) + with ( + compile_context(CompileContext(compile_id)), + dynamo_timed( + "_compile.compile_inner", + phase_name="entire_frame_compile", + dynamo_compile_column_us="dynamo_cumulative_compile_time_us", + # TODO: save all relevant compilation metrics + metadata={ + "frame_key": str(torch._dynamo.utils.curr_frame), + "co_name": code.co_name, + "co_filename": code.co_filename, + "co_firstlineno": code.co_firstlineno, + }, + ), + ): + module = sys.modules[entry.python_module] + for alias, module_name in entry.import_sources.items(): + self._install_global( + module, alias, importlib.import_module(module_name) ) - backend = backends[backend_id] - self._install_global( - module, - backend_id, - torch._dynamo.disable(backend), - ) + for function_name in entry.function_names: + fn = types.FunctionType(code, module.__dict__, function_name) + self._install_global(module, function_name, fn) + for backend_id in entry.backend_ids: + if backend_id not in backends: + raise RuntimeError( + f"Backend {backend_id} is not found in the given backends" + ) + with dynamo_timed( + "after_deserialization", phase_name="backend_compile" + ): + backend = backends[backend_id].after_deserialization() + self._install_global( + module, + backend_id, + torch._dynamo.disable(backend), + ) - for code, entry in self._codes.items(): - for guarded_code in entry.guarded_codes: - guards_state = pickle.loads(guarded_code.guards_state) - runtime_global_scope = sys.modules[entry.python_module].__dict__ - # The installed builtins dict might be absent from the runtime - # while loading guards. Populate it if it's missing. - if ( - builtin_dict_name - := guards_state.output_graph.name_of_builtins_dict_key_in_fglobals - ): - builtins_dict = get_builtins_dict(runtime_global_scope) - if builtin_dict_name in runtime_global_scope: - assert runtime_global_scope[builtin_dict_name] is builtins_dict - else: - runtime_global_scope[builtin_dict_name] = builtins_dict - assert isinstance(guards_state, torch._dynamo.guards.GuardsState) - check_fn_manager = torch._dynamo.guards.CheckFunctionManager( - code, - guards_state.output_graph, - guards_serialization_mode="load", - shape_code_parts=guards_state.shape_code_parts, - runtime_global_scope=runtime_global_scope, - ) - _load_precompile_entry( - code, - check_fn_manager.guard_manager, - SerializedCode.to_code_object(guarded_code.dynamo_code), - ) + for guarded_code in entry.guarded_codes: + guards_state = pickle.loads(guarded_code.guards_state) + runtime_global_scope = sys.modules[entry.python_module].__dict__ + # The installed builtins dict might be absent from the runtime + # while loading guards. Populate it if it's missing. + if ( + builtin_dict_name + := guards_state.output_graph.name_of_builtins_dict_key_in_fglobals + ): + builtins_dict = get_builtins_dict(runtime_global_scope) + if builtin_dict_name in runtime_global_scope: + assert ( + runtime_global_scope[builtin_dict_name] is builtins_dict + ) + else: + runtime_global_scope[builtin_dict_name] = builtins_dict + assert isinstance(guards_state, torch._dynamo.guards.GuardsState) + check_fn_manager = torch._dynamo.guards.CheckFunctionManager( + code, + guards_state.output_graph, + guards_serialization_mode="load", + shape_code_parts=guards_state.shape_code_parts, + runtime_global_scope=runtime_global_scope, + ) + _load_precompile_entry( + code, + check_fn_manager.guard_manager, + SerializedCode.to_code_object(guarded_code.dynamo_code), + ) def cache_entry(self) -> _DynamoCacheEntry: self.validate() @@ -556,7 +584,7 @@ def load_cache_entry( PrecompileContext.record_artifact( backend.type(), key=backend.key, content=backend.content ) - backend_content[backend_id] = backend.after_deserialization() + backend_content[backend_id] = backend return cache_entry, backend_content @@ -683,7 +711,8 @@ def load( path = os.path.join(self.path_prefix, key) if os.path.exists(path): try: - return super().load_cache_entry(key) + result = super().load_cache_entry(key) + return result except Exception as e: logger.warning("Failed to load package from path %s: %s", path, str(e)) return None From a5e68814d556cf67c6511876410970dd08c3dd6d Mon Sep 17 00:00:00 2001 From: Arsh Zahed Date: Tue, 15 Jul 2025 00:53:57 +0000 Subject: [PATCH 028/457] Allow dynamic shapes for DTensor slice (#157953) This PR allows for symints in `gen_slice_strategy` which is the strategy for `aten.slice.Tensor`. Previously, using dynamic shapes with slicing would result in ``` File ".../pytorch/torch/distributed/tensor/_ops/_tensor_ops.py", line 348, in gen_slice_strategy assert isinstance(end, int) torch._dynamo.exc.TorchRuntimeError: Dynamo failed to run FX node with fake tensors: call_function (*(DTensor(local_tensor=FakeTensor(..., device='cuda:0', size=(s3, 2)), device_mesh=DeviceMesh('cuda', [0, 1]), placements=(Shard(dim=0),)), slice(None, (s77//2), None)), **{}): got AssertionError() ``` Questions before merge: 1. `dim` is still asserted to be int. Is this fine, or is this potentially dynamic as well? 2. I'm using argtype ignore for `normalize_dim`. Should I instead change types for `normalize_dim` and further dependency to be `IntLike` as well? Pull Request resolved: https://github.com/pytorch/pytorch/pull/157953 Approved by: https://github.com/wconstab --- .../tensor/test_dtensor_compile.py | 20 +++++++++++++++++++ torch/distributed/tensor/_ops/_tensor_ops.py | 13 ++++++------ 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/test/distributed/tensor/test_dtensor_compile.py b/test/distributed/tensor/test_dtensor_compile.py index a26cf5da144ff..54ec52ee32d41 100644 --- a/test/distributed/tensor/test_dtensor_compile.py +++ b/test/distributed/tensor/test_dtensor_compile.py @@ -276,6 +276,26 @@ def fn(x): res = opt_fn(x) self.assertEqual(res, ref) + @skipIfHpu + def test_dtensor_dynamic_slice(self): + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + + # test passing in DTensor as inputs/outputs and run some tensor computation + def fn(x): + return [ + t.redistribute( + device_mesh=x.device_mesh, placements=[Replicate()] + ).to_local()[0] + for t in torch.tensor_split(x, 2) + ] + + x = DTensor.from_local(torch.rand(4, 4), mesh, [Shard(0)], run_check=False) + ref = fn(x) + + opt_fn = torch.compile(fn, backend="aot_eager", fullgraph=True, dynamic=True) + res = opt_fn(x) + self.assertEqual(res, ref) + def test_dtensor_attribute_access_on_intermediate(self): mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) diff --git a/torch/distributed/tensor/_ops/_tensor_ops.py b/torch/distributed/tensor/_ops/_tensor_ops.py index fe957d2ccab6b..a81db1a3b124e 100644 --- a/torch/distributed/tensor/_ops/_tensor_ops.py +++ b/torch/distributed/tensor/_ops/_tensor_ops.py @@ -4,6 +4,7 @@ from typing import cast, Optional import torch +from torch._prims_common import IntLike from torch.distributed.tensor._dtensor_spec import DTensorSpec from torch.distributed.tensor._op_schema import ( OpSchema, @@ -376,14 +377,14 @@ def gen_slice_strategy(op_schema: OpSchema) -> StrategyType: start = 0 if end is None or end > input_shape[dim]: end = input_shape[dim] - assert isinstance(start, int) - assert isinstance(end, int) - assert isinstance(step, int) + assert isinstance(start, IntLike) + assert isinstance(end, IntLike) + assert isinstance(step, IntLike) # normalize args - slice_dim = normalize_dim(dim, input_ndim) - start = normalize_dim(start, input_shape[dim]) - end = normalize_dim(end, input_shape[dim]) + slice_dim = normalize_dim(dim, input_ndim) # type: ignore[arg-type] + start = normalize_dim(start, input_shape[dim]) # type: ignore[arg-type] + end = normalize_dim(end, input_shape[dim]) # type: ignore[arg-type] redundant_slice = start == 0 and end == input_shape[dim] and step == 1 From 4486a6dbfd65ef490cfe73e0630929e85f61ee16 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Mon, 14 Jul 2025 14:28:26 -0700 Subject: [PATCH 029/457] [DTensor] Fix grouped_mm strategy for invalid stride cases (#158245) local_tensor input to grouped_mm has a stride requirement. (see `_meta_grouped_mm_common` in meta_registrations.py or `check_valid_strides_and_return_transposed` in native/cuda/Blas.cpp) Don't allow sharding a tensor if its shape would result in an incompatible local_tensor stride. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158245 Approved by: https://github.com/zpcore, https://github.com/XilunWu --- test/distributed/tensor/test_matrix_ops.py | 75 ++++++++++++++---- torch/distributed/tensor/_ops/_matrix_ops.py | 53 ++++++++++++- torch/distributed/tensor/_ops/utils.py | 80 ++++++++++++++------ 3 files changed, 170 insertions(+), 38 deletions(-) diff --git a/test/distributed/tensor/test_matrix_ops.py b/test/distributed/tensor/test_matrix_ops.py index d0f8482c0cf57..523908c5e6bc4 100644 --- a/test/distributed/tensor/test_matrix_ops.py +++ b/test/distributed/tensor/test_matrix_ops.py @@ -19,7 +19,12 @@ from torch.distributed.tensor.debug import CommDebugMode from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8, SM90OrLater from torch.testing._internal.common_device_type import E4M3_MAX_POS, e4m3_type -from torch.testing._internal.common_utils import run_tests, TEST_WITH_ROCM +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, + run_tests, + TEST_WITH_ROCM, +) from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, skip_unless_torch_gpu, @@ -508,40 +513,78 @@ def test_tensordot_shampoo(self): @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90") @with_comms @skip_unless_torch_gpu - def test_grouped_mm(self): + @parametrize( + "kwargs", + [ + { + # 2D x 3D case from MoE layer + "inp_shape": (64, 16), + "w1_shape": (2, 16, 32), + "w2_shape": (2, 32, 16), + "inp_placements": [Replicate()], + "w1_placements": [Shard(2)], + "w2_placements": [Shard(1)], + "expected_comm_counts_fwd": 0, + "expected_comm_counts_bwd": 1, + "expected_out_placements": [Partial()], + }, + { + # Case that would have invalid strides on inp * mat1 when sharded + "inp_shape": (64, 16), + "w1_shape": (2, 16, 16), + "w2_shape": (2, 16, 16), + "inp_placements": [Replicate()], + "w1_placements": [Shard(2)], + "w2_placements": [Shard(1)], + "expected_comm_counts_fwd": 2, + "expected_comm_counts_bwd": 4, + "expected_out_placements": [Replicate()], + }, + ], + ) + def test_grouped_mm(self, kwargs): # TODO: torch._grouped_mm can take inputs of dimension (2D, 3D) x (2D, 3D) - # Here we only test the 2D x 3D Tensor Parallel use case in an MoE layer. # More tests need to be added. device_mesh = init_device_mesh(self.device_type, (self.world_size,)) comm_mode = CommDebugMode() dtype = torch.bfloat16 - inp = torch.rand( - 64, 16, device=self.device_type, dtype=dtype, requires_grad=True + *kwargs["inp_shape"], + device=self.device_type, + dtype=dtype, + requires_grad=True, ) w1 = torch.rand( - 2, 16, 32, device=self.device_type, dtype=dtype, requires_grad=True + *kwargs["w1_shape"], + device=self.device_type, + dtype=dtype, + requires_grad=True, ) w2 = torch.rand( - 2, 32, 16, device=self.device_type, dtype=dtype, requires_grad=True + *kwargs["w2_shape"], + device=self.device_type, + dtype=dtype, + requires_grad=True, ) offs = torch.tensor([16, 64], device=self.device_type, dtype=torch.int32) h = torch._grouped_mm(inp, w1, offs=offs) out = torch._grouped_mm(h, w2, offs=offs) - dist_inp = distribute_tensor(inp, device_mesh, [Replicate()]) + dist_inp = distribute_tensor(inp, device_mesh, kwargs["inp_placements"]) # colwise sharded - dist_w1 = distribute_tensor(w1, device_mesh, [Shard(2)]) + dist_w1 = distribute_tensor(w1, device_mesh, kwargs["w1_placements"]) # rowwise sharded - dist_w2 = distribute_tensor(w2, device_mesh, [Shard(1)]) + dist_w2 = distribute_tensor(w2, device_mesh, kwargs["w2_placements"]) dist_offs = distribute_tensor(offs, device_mesh, [Replicate()]) with comm_mode: dist_h = torch._grouped_mm(dist_inp, dist_w1, offs=dist_offs) dist_out = torch._grouped_mm(dist_h, dist_w2, offs=dist_offs) - self.assertEqual(comm_mode.get_total_counts(), 0) - self.assertTrue(dist_out.placements[0].is_partial()) + self.assertEqual( + comm_mode.get_total_counts(), kwargs["expected_comm_counts_fwd"] + ) + self.assertEqual(dist_out.placements, kwargs["expected_out_placements"]) self.assertEqual(dist_out.full_tensor(), out) out_grad = torch.ones_like(out) @@ -552,15 +595,19 @@ def test_grouped_mm(self): with comm_mode: dist_out.backward(dist_out_grad) - self.assertEqual(comm_mode.get_total_counts(), 1) + self.assertEqual( + comm_mode.get_total_counts(), kwargs["expected_comm_counts_bwd"] + ) self.assertEqual( comm_mode.get_comm_counts()[funcol.all_gather_into_tensor], - 1, + kwargs["expected_comm_counts_bwd"], ) self.assertEqual(dist_inp.grad.full_tensor(), inp.grad) self.assertEqual(dist_w1.grad.full_tensor(), w1.grad) self.assertEqual(dist_w2.grad.full_tensor(), w2.grad) +instantiate_parametrized_tests(DistMatrixOpsTest) + if __name__ == "__main__": run_tests() diff --git a/torch/distributed/tensor/_ops/_matrix_ops.py b/torch/distributed/tensor/_ops/_matrix_ops.py index b7804d318104d..6b662aca4912a 100644 --- a/torch/distributed/tensor/_ops/_matrix_ops.py +++ b/torch/distributed/tensor/_ops/_matrix_ops.py @@ -6,7 +6,7 @@ import torch from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta from torch.distributed.tensor._op_schema import ( OpSchema, OpSpec, @@ -24,6 +24,10 @@ prod, register_op_strategy, ) +from torch.distributed.tensor._utils import ( + compute_local_shape_and_global_offset, + compute_local_stride, +) from torch.distributed.tensor.placement_types import ( Partial, Placement, @@ -1035,6 +1039,51 @@ def grouped_mm_strategy(op_schema: OpSchema) -> OpStrategy: ] ) + def valid_grouped_mm_strides( + input_specs: list[DTensorSpec], output_specs: tuple[Optional[DTensorSpec], ...] + ) -> bool: + # 1. compute the local-tensor shape/strides given this sharding proposal + # 2. apply the logic from the groped_mm meta function + # UGH the input DTensorSpecs are missing their tensormetas... so i can get them another way + def local_meta(spec: OpSpec, placements: tuple[Placement, ...]) -> TensorMeta: + assert isinstance(spec.output_specs, DTensorSpec) + assert isinstance(spec.output_specs.tensor_meta, TensorMeta) + meta: TensorMeta = spec.output_specs.tensor_meta + local_stride = compute_local_stride(meta.stride, mesh, placements) + local_shape, _ = compute_local_shape_and_global_offset( + meta.shape, mesh, placements + ) + return TensorMeta(torch.Size(local_shape), local_stride, meta.dtype) + + mat1_meta = local_meta(mat1_strategy.strategies[0], input_specs[0].placements) + mat2_meta = local_meta(mat2_strategy.strategies[0], input_specs[1].placements) + + def check_valid_strides(meta: TensorMeta) -> bool: + # copied from `_meta_grouped_mm_common` in meta_registrations.py + end_dim = len(meta.shape) - 1 + alignment = 16 // meta.dtype.itemsize + if meta.stride[end_dim - 1] == 1 and meta.stride[end_dim] >= max( + 1, meta.shape[end_dim - 1] + ): + if not meta.stride[end_dim] % alignment == 0: + return False + elif meta.stride[end_dim] == 1 and meta.stride[end_dim - 1] >= max( + 1, meta.shape[end_dim] + ): + if not meta.stride[end_dim - 1] % alignment == 0: + return False + else: + return False + return True + + mat1_valid = check_valid_strides(mat1_meta) + mat2_valid = check_valid_strides(mat2_meta) + return mat1_valid and mat2_valid + return expand_to_full_mesh_op_strategy( - mesh, op_schema, single_mesh_dim_strategies, input_index=1 + mesh, + op_schema, + single_mesh_dim_strategies, + input_index=1, + is_valid_strategy_cb=valid_grouped_mm_strides, ) diff --git a/torch/distributed/tensor/_ops/utils.py b/torch/distributed/tensor/_ops/utils.py index 5215795b085d7..f6dd44cdfb08e 100644 --- a/torch/distributed/tensor/_ops/utils.py +++ b/torch/distributed/tensor/_ops/utils.py @@ -241,7 +241,36 @@ def expand_to_full_mesh_op_strategy( *, input_index: int = 1, inplace_op: bool = False, + is_valid_strategy_cb: Optional[ + Callable[[list[DTensorSpec], tuple[Optional[DTensorSpec], ...]], bool] + ] = None, ) -> OpStrategy: + """ + Convenience function to allow writing a sharding strategy considering only a single mesh dimension, + and have it expanded combinatorically to all mesh dimensions. + + Args: + mesh (DeviceMesh): the device mesh to expand the strategy to + op_schema (OpSchema): the op schema + single_mesh_dim_strategies (list[PlacementList]): the sharding strategies to expand. The outer list is over + different strategies. The inner PlacementList is over the outputs and inputs of the op. If input_index is 1, + a PlacementList looks like [output_placement, input_placement1, input_placement2, ...]. + input_index: the number of outputs of the op, defaults to 1 + inplace_op: whether the op is inplace or not, defaults to False + is_valid_strategy_cb: a callback function to filter out invalid sharding rules, defaults to None. + + Example: Let's say `my_op(tensor_x, tensor_y) - > output_tensor` can support sharding or replicating tensor_x, + but always requires tensor_y to be replicated. We can specify these valid combinations ignoring mesh dims. + Then, we can rely on `expand_to_full_mesh_op_strategy` to create every possible combination of these shardings + over multiple mesh dimensions, filtering out any combinations that are invalid based on the actual mesh dim size. + + single_mesh_dim_strategies = [ + # first strategy: return output sharded on first dim, shard tensor_x on its first dim, replicate tensor_y + [Shard(0), Shard(0), Replicate()] + # second strategy: replicate output, and both inputs + [Replicate(), Replicate(), Replicate()] + ] + """ # Expand the single_mesh_dim_strategies to full mesh dim strategies. all_mesh_dim_strategies = [single_mesh_dim_strategies] * mesh.ndim @@ -252,6 +281,7 @@ def expand_to_full_mesh_op_strategy( spec_list: list[Optional[DTensorSpec]] = [] for specs in zip(*strategy_comb): if specs[0] is not None: + # TODO: we should fill in tensor_meta here. If nothing else, it helps the filter strategy callback spec_list.append(DTensorSpec(mesh, specs)) else: spec_list.append(None) @@ -269,30 +299,36 @@ def expand_to_full_mesh_op_strategy( # input_spec matches the first argument's runtime sharding, otherwise we skip continue - # check inputs shardable - inputs_shardable = all( + output_specs: tuple[Optional[DTensorSpec], ...] + if input_index > 1: + output_specs = tuple(spec_list[:input_index]) + else: + if spec_list[0] is not None: + output_specs = spec_list[0] # type: ignore[assignment] + else: + raise RuntimeError("output spec is None") + + # check all inputs are shardable + if not all( is_tensor_shardable(inp.shape, s) for inp, s in zip(input_args_strategy, input_specs) - ) + ): + continue - # only add to the all_strategies list when all inputs are shardable - if inputs_shardable: - redistribute_cost = [ - generate_redistribute_costs(input_strategy, input_spec) - for input_strategy, input_spec in zip(input_args_strategy, input_specs) - ] - if input_index > 1: - output_specs = tuple(spec_list[:input_index]) - else: - if spec_list[0] is not None: - output_specs = spec_list[0] # type: ignore[assignment] - else: - raise RuntimeError("output spec is None") - strategy = OpSpec( - output_specs=output_specs, - input_specs=input_specs, - redistribute_cost=redistribute_cost, - ) - all_strategies.append(strategy) + # perform additional op-specific filtering + if is_valid_strategy_cb is not None: + if not is_valid_strategy_cb(input_specs, output_specs): + continue + redistribute_cost = [ + generate_redistribute_costs(input_strategy, input_spec) + for input_strategy, input_spec in zip(input_args_strategy, input_specs) + ] + + strategy = OpSpec( + output_specs=output_specs, + input_specs=input_specs, + redistribute_cost=redistribute_cost, + ) + all_strategies.append(strategy) return OpStrategy(all_strategies) From 12151c96d9202875638ea2c695d5647c38368c46 Mon Sep 17 00:00:00 2001 From: Ti-Tai Wang Date: Tue, 15 Jul 2025 00:00:06 +0000 Subject: [PATCH 030/457] [ONNX] Remove legacy io_adapter (#158255) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158255 Approved by: https://github.com/justinchuby --- torch/onnx/_internal/_exporter_legacy.py | 27 - .../_internal/fx/dynamo_graph_extractor.py | 78 +-- torch/onnx/_internal/io_adapter.py | 652 ------------------ 3 files changed, 3 insertions(+), 754 deletions(-) delete mode 100644 torch/onnx/_internal/io_adapter.py diff --git a/torch/onnx/_internal/_exporter_legacy.py b/torch/onnx/_internal/_exporter_legacy.py index b3150ef9cdeb3..1e6a9df9a9903 100644 --- a/torch/onnx/_internal/_exporter_legacy.py +++ b/torch/onnx/_internal/_exporter_legacy.py @@ -21,7 +21,6 @@ import torch import torch._ops -from torch.onnx._internal import io_adapter from torch.onnx._internal._lazy_import import onnxscript_apis from torch.onnx._internal.exporter import _constants from torch.onnx._internal.fx import ( @@ -392,8 +391,6 @@ class FXGraphExtractor(abc.ABC): def __init__(self) -> None: super().__init__() - self.input_adapter: io_adapter.InputAdapter = io_adapter.InputAdapter() - self.output_adapter: io_adapter.OutputAdapter = io_adapter.OutputAdapter() @abc.abstractmethod def generate_fx( @@ -469,28 +466,4 @@ def common_pre_export_passes( if isinstance(original_model, torch.nn.Module): module = passes.RestoreParameterAndBufferNames(module, original_model).run() - # ONNX does not support None inputs. During graph building, all None inputs - # are removed. Here we register this step to input adapter. - options.fx_tracer.input_adapter.append_step(io_adapter.RemoveNoneInputStep()) - - # NOTE: temp workaround for https://github.com/pytorch/pytorch/issues/99534 - # Dynamo doesn't support non-tensor inputs. - options.fx_tracer.input_adapter.append_step(io_adapter.RemoveNonTensorInputStep()) - - # ONNX does not support complex inputs. During graph building, all complex inputs - # are converted to real representation inputs. Here we register this step to - # input/output adapter. - options.fx_tracer.input_adapter.append_step( - io_adapter.ConvertComplexToRealRepresentationInputStep() - ) - - # ONNX can't represent collection types (e.g., dictionary, tuple of tuple of - # tensor, etc), we flatten the collection and register each element as output. - options.fx_tracer.output_adapter.append_step(io_adapter.FlattenOutputStep()) - - # Output post-processing steps should happen after `FlattenOutputStep`. - options.fx_tracer.output_adapter.append_step( - io_adapter.ConvertComplexToRealRepresentationOutputStep() - ) - return module diff --git a/torch/onnx/_internal/fx/dynamo_graph_extractor.py b/torch/onnx/_internal/fx/dynamo_graph_extractor.py index b11903619c080..73720ec39d560 100644 --- a/torch/onnx/_internal/fx/dynamo_graph_extractor.py +++ b/torch/onnx/_internal/fx/dynamo_graph_extractor.py @@ -7,15 +7,12 @@ from __future__ import annotations import contextlib -import functools import inspect from typing import Any, Callable, TYPE_CHECKING import torch._dynamo -import torch.export as torch_export import torch.fx -import torch.onnx -from torch.onnx._internal import _exporter_legacy, io_adapter +from torch.onnx._internal import _exporter_legacy from torch.utils import _pytree as pytree @@ -104,61 +101,6 @@ def model_output_unflatten( ) -class DynamoFlattenOutputStep(io_adapter.FlattenOutputStep): - """Flatten nested collection and custom python types and return a flat list of elements. - - Extended from :class:`io_adapter.FlattenOutputStep` to support flattening arbitrary - types via pytree extension. By default this supports many common user defined python - types such as :class:`ModelOutput` from HuggingFace transformers. - - The pytree extension can be customized by passing in a ``_PyTreeExtensionContext`` - object. See :meth:`_PyTreeExtensionContext.register_pytree_node`. - """ - - def __init__(self, pytree_extension_context: _PyTreeExtensionContext | None = None): - super().__init__() - self._pytree_extension_context = ( - pytree_extension_context or _PyTreeExtensionContext() - ) - - def apply( - self, - model_outputs: Any, - model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, - ) -> Sequence[Any]: - """Flatten the model outputs, under the context of pytree extension.""" - with self._pytree_extension_context: - return super().apply(model_outputs, model=model) - - -def _wrap_model_with_output_adapter( - model: torch.nn.Module | Callable, - output_adapter: DynamoFlattenOutputStep, -) -> Callable: - """Wrap model with output adapter. - - This is a helper function to enable :func:`dynamo.export` on models that produce - custom user defined types outputs. It wraps the model with an output adapter to - convert the outputs to :func:`dynamo.export` compatible types, i.e. :class:`torch.Tensor`. - - The adapting logic is controlled by ``output_adapter``. - - Args: - model: PyTorch model or function. - output_adapter: Output adapter to apply to model output. - Returns: - Wrapped model. - """ - model_func = model.forward if isinstance(model, torch.nn.Module) else model - - # Preserve original function signature. - @functools.wraps(model_func) - def wrapped(*args, **kwargs): - return output_adapter.apply(model_func(*args, **kwargs), model=model) - - return wrapped - - class DynamoExport(_exporter_legacy.FXGraphExtractor): """Generates a FX GraphModule using torch.dynamo.export API Args: @@ -183,12 +125,7 @@ def generate_fx( # `dynamo.export` does not recognize custom user defined classes as output type. # Apply wrapper to adapt the outputs back to `dynamo.export` compatible types, # i.e. :class:`torch.Tensor`. - dynamo_flatten_output_step = DynamoFlattenOutputStep() - wrapped_model = _wrap_model_with_output_adapter( - model, dynamo_flatten_output_step - ) - # Record the output adapter step. - self.output_adapter.append_step(dynamo_flatten_output_step) + wrapped_model = model # Translate callable to FX graph. # @@ -209,16 +146,7 @@ def generate_fx( del graph_guard # Unused torch._dynamo.reset() - # Export FX graph to ONNX ModelProto. - self.input_adapter.append_step( - io_adapter.FlattenInputWithTreeSpecValidationInputStep() - ) - - updated_model_args = self.input_adapter.apply( - *model_args, model=model, **model_kwargs - ) - - return self.pre_export_passes(options, model, graph_module, updated_model_args) # type: ignore[return-value] + return self.pre_export_passes(options, model, graph_module, model_args) # type: ignore[return-value] def pre_export_passes( self, diff --git a/torch/onnx/_internal/io_adapter.py b/torch/onnx/_internal/io_adapter.py deleted file mode 100644 index 6c414e8d54e78..0000000000000 --- a/torch/onnx/_internal/io_adapter.py +++ /dev/null @@ -1,652 +0,0 @@ -# mypy: allow-untyped-defs -from __future__ import annotations - -from typing import Any, Callable, TYPE_CHECKING -from typing_extensions import Protocol, runtime_checkable - -import torch -import torch.export as torch_export -from torch.utils import _pytree as pytree - - -if TYPE_CHECKING: - import inspect - from collections.abc import Mapping, Sequence - - -@runtime_checkable -class InputAdaptStep(Protocol): - """A protocol that defines a step in the input adapting process. - - The input adapting process is a sequence of steps that are applied to the - PyTorch model inputs to transform them into the inputs format expected by the - exported ONNX model. Each step takes the PyTorch model inputs as arguments and - returns the transformed inputs. - - This serves as a base formalized construct for the transformation done to model - input signature by any individual component in the exporter. - """ - - def apply( - self, - model_args: Sequence[Any], - model_kwargs: Mapping[str, Any], - model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, - ) -> tuple[Sequence[Any], Mapping[str, Any]]: ... - - -class InputAdapter: - """A class that adapts the PyTorch model inputs to exported ONNX model inputs format.""" - - def __init__(self, steps: list[InputAdaptStep] | None = None): - self._steps = steps or [] - - def append_step(self, step: InputAdaptStep) -> None: - """Appends a step to the input adapt steps. - - Args: - step: The step to append. - """ - self._steps.append(step) - - def apply( - self, - *model_args, - model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, - **model_kwargs, - ) -> Sequence[int | float | bool | str | torch.Tensor | torch.dtype | None]: - """Converts the PyTorch model inputs to exported ONNX model inputs format. - - Args: - model_args: The PyTorch model inputs. - model: The PyTorch model. - model_kwargs: The PyTorch model keyword inputs. - Returns: - A sequence of tensors converted from PyTorch model inputs. - """ - args: Sequence[Any] = model_args - kwargs: Mapping[str, Any] = model_kwargs - for step in self._steps: - args, kwargs = step.apply(args, kwargs, model=model) - assert not kwargs - return args - - -@runtime_checkable -class OutputAdaptStep(Protocol): - """A protocol that defines a step in the output adapting process. - - The output adapting process is a sequence of steps that are applied to the - PyTorch model outputs to transform them into the outputs format produced by the - exported ONNX model. Each step takes the PyTorch model outputs as arguments and - returns the transformed outputs. - - This serves as a base formalized construct for the transformation done to model - output signature by any individual component in the exporter. - """ - - def apply( - self, - model_outputs: Any, - model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, - ) -> Any: ... - - -class OutputAdapter: - """A class that adapts the PyTorch model outputs to exported ONNX model outputs format.""" - - def __init__(self, steps: list[OutputAdaptStep] | None = None): - self._steps = steps or [] - - def append_step(self, step: OutputAdaptStep) -> None: - """Appends a step to the output format steps. - - Args: - step: The step to append. - """ - self._steps.append(step) - - def apply( - self, - model_outputs: Any, - model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, - ) -> Sequence[torch.Tensor | int | float | bool | str]: - """Converts the PyTorch model outputs to exported ONNX model outputs format. - - Args: - model_outputs: The PyTorch model outputs. - model: The PyTorch model. - - Returns: - PyTorch model outputs in exported ONNX model outputs format. - """ - for step in self._steps: - model_outputs = step.apply(model_outputs, model=model) - return model_outputs - - -# TODO: make_fx lose stack info https://github.com/pytorch/pytorch/issues/90276 - - -# TODO(XuehaiPan): Dynamo does not support `dummy_leaf = object()` as a sentinel value in the frame. -class _DummyLeaf: # use a class instead. - pass - - -def _replace_list_with_tuple(spec: pytree.TreeSpec) -> pytree.TreeSpec: - def replace_list_with_tuple(x: Any) -> Any: - if type(x) is list: - return pytree.tree_map( - replace_list_with_tuple, - tuple(x), - is_leaf=lambda x: type(x) is list, - ) - return x - - dummy_leaf = _DummyLeaf() - dummy_tree = pytree.tree_unflatten([dummy_leaf] * spec.num_leaves, spec) - dummy_tree = pytree.tree_map( - replace_list_with_tuple, - dummy_tree, - is_leaf=lambda x: type(x) is list, - ) - return pytree.tree_structure(dummy_tree) - - -def _open_top_level_sequence_if_single_element( - spec: pytree.TreeSpec, -) -> pytree.TreeSpec: - if spec.type in (tuple, list) and spec.num_children == 1: - return spec.children_specs[0] - return spec - - -def _assert_identical_pytree_spec( - spec1: pytree.TreeSpec, spec2: pytree.TreeSpec, error_message: str -) -> None: - """Assert the two `TreeSpec` objects are identical. - - Args: - spec1: The first `TreeSpec` object. - spec2: The second `TreeSpec` object. - error_message: The error message to raise if the two `TreeSpec` objects are not - identical. - - Raises: - ValueError: If the two `TreeSpec` objects are not identical. - """ - pass_if_any_checks: Sequence[Callable[[], bool]] = [ - lambda: spec1 == spec2, - # FIXME: Bug in `dynamo.export`. Sometimes outputs returned in 'list' instead of 'tuple'. - lambda: _replace_list_with_tuple(spec1) == _replace_list_with_tuple(spec2), - # FIXME: Bug in `dynamo.export`. Sometimes single function return is wrapped in list. - lambda: _open_top_level_sequence_if_single_element(spec1) == spec2, - lambda: spec1 == _open_top_level_sequence_if_single_element(spec2), - ] - - if not any(check() for check in pass_if_any_checks): - raise ValueError(f"{error_message}\nExpect {spec1}.\nActual {spec2}.") - - -class BindInputStep(InputAdaptStep): - """Bind the input arguments to the model signature.""" - - def __init__(self, model_signature: inspect.Signature): - self._model_signature = model_signature - - def apply( - self, - model_args: Sequence[Any], - model_kwargs: Mapping[str, Any], - model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, - ) -> tuple[Sequence[Any], Mapping[str, Any]]: - """Bind the input arguments to the model signature. - - We hope the input kwargs will be mapped to bound.args after binding. - If not, we will raise an error. - - Args: - model_args: The model args. - model_kwargs: The model kwargs. - model: The PyTorch model. - - Returns: - A tuple of the model args and kwargs. args is always empty. - - Raises: - ValueError: If there are keyword-only arguments left after binding args and - kwargs to model signature. - """ - bound = self._model_signature.bind(*model_args, **model_kwargs) - bound.apply_defaults() - - # keyword-only arguments are not handled. - # bound.kwargs only contains keyword-only arguments after calling - # bind & apply_defaults, so we raise if it's not empty. - if bound.kwargs: - raise ValueError("Keyword-only arguments are not supported.") - return (), bound.arguments - - -class MergeKwargsIntoArgsInputStep(InputAdaptStep): - """Merge the input kwargs into the input args.""" - - def apply( - self, - model_args: Sequence[Any], - model_kwargs: Mapping[str, Any], - model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, - ) -> tuple[Sequence[Any], Mapping[str, Any]]: - """Merge the input kwargs into the input args. - - Args: - model_args: The model args. - model_kwargs: The model kwargs. - model: The PyTorch model. - - Returns: - A tuple of the model args and kwargs. kwargs is always empty. - """ - return tuple(model_args) + tuple(model_kwargs.values()), {} - - -class LiftParametersAndBuffersIntoArgsInputStep(InputAdaptStep): - """Append parameters and buffers to model's positional argument list.""" - - def __init__(self, inputs: tuple[torch.Tensor, ...]) -> None: - self.inputs = inputs - - def apply( - self, - model_args: Sequence[Any], - model_kwargs: Mapping[str, Any], - model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, - ) -> tuple[Sequence[Any], Mapping[str, Any]]: - """Append model's parameters and buffers into its input. - - Args: - model_args: The model args. - model_kwargs: The model kwargs. - model: The PyTorch model. - - Returns: - A tuple of the model args + appended inputs and kwargs. - """ - return (*model_args, *self.inputs), model_kwargs - - -class ConvertComplexToRealRepresentationInputStep(InputAdaptStep): - """Convert complex dtype tensors to real representation tensors. - - ONNX does not support complex dtype tensors. Thus, we convert complex dtype tensors - to real representation tensors (i.e., float dtype tensors with an extra dimension - representing the real and imaginary parts of the complex number). - - """ - - def apply( - self, - model_args: Sequence[Any], - model_kwargs: Mapping[str, Any], - model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, - ) -> tuple[Sequence[Any], Mapping[str, Any]]: - """Convert complex tensors to float tensors. - - Args: - model_args: The model args. - model_kwargs: The model kwargs. - model: The PyTorch model. - - Returns: - A tuple of the model args and kwargs. - """ - return ( - tuple( - torch.view_as_real(arg.resolve_conj()) - if isinstance(arg, torch.Tensor) and arg.is_complex() - else arg - for arg in model_args - ), - model_kwargs, - ) - - -class RemoveNoneInputStep(InputAdaptStep): - """Remove `None` from arguments. - - This adapt step assumes ``model_kwargs`` is empty. It also assumes ``model_args`` - is flattened, i.e. it does not check `None` inside nested collections. - """ - - def apply( - self, - model_args: Sequence[Any], - model_kwargs: Mapping[str, Any], - model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, - ) -> tuple[Sequence[Any], Mapping[str, Any]]: - """Remove `None` from arguments. - - Args: - model_args: The model args. - model_kwargs: The model kwargs. - model: The PyTorch model. - - Returns: - A tuple of the model args and kwargs. - - Raises: - ValueError: If `model_kwargs` is not empty. - """ - assert not model_kwargs - return tuple(arg for arg in model_args if arg is not None), {} - - -class RemoveNonTensorInputStep(InputAdaptStep): - """Remove the non-tensor input arguments. - - Dynamo does not support non-tensor input arguments (https://github.com/pytorch/pytorch/issues/99534). - - Specifically, it does put the input into graph with an empty node, but consumed by no ones. - The concrete value is embedded into the graph as a constant arg of a target node. Meta - suggests in this case that one should rewrite the model code to make it tensor if the - input value is supposed to change at runtime. We might need to further investigate - the feasibility of that suggestion. - - For example, - - def func(x, b=1.0): - y = x + b - z = y.relu() - return (y, z) - - x = torch.randn(1, 1, 2, dtype=torch.float32) - gm_fun, _ = dynamo.export(func, x, b=8.0, aten_graph=True, tracing_mode="real") - - # class GraphModule(torch.nn.Module): - # def forward(self, x, b): - # arg0: f32[1, 1, 2], arg1, = fx_pytree.tree_flatten_spec(([x, b], {}), self._in_spec) - # # File: path/to/pytorch/test_constant_input.py:5, code: y = x + b - # add_tensor: f32[1, 1, 2] = torch.ops.aten.add.Tensor(arg0, 8.0); arg0 = None - - # # File: path/to/pytorch/test_constant_input.py:6, code: z = y.relu() - # relu_default: f32[1, 1, 2] = torch.ops.aten.relu.default(add_tensor) - # return pytree.tree_unflatten([add_tensor, relu_default], self._out_spec) - - Empty torch.fx.Node input leading to a mismatched number of input with PyTorch, as - it's ignored in ONNX graph. Thus, we delete the useless input here. - - """ - - def apply( - self, - model_args: Sequence[Any], - model_kwargs: Mapping[str, Any], - model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, - ) -> tuple[Sequence[Any], Mapping[str, Any]]: - """Remove Constant from arguments. - - Args: - model_args: The model args. - model_kwargs: The model kwargs. - model: The PyTorch model. - - Returns: - A tuple of the model args and kwargs. - - Raises: - ValueError: If `model_kwargs` is not empty. - """ - assert not model_kwargs - return ( - tuple( - arg - for arg in model_args - if not isinstance(arg, (int, float, bool, str)) - ), - {}, - ) - - -class FlattenInputWithTreeSpecValidationInputStep(InputAdaptStep): - """Flatten nested collection types and return a flat list of elements. - - ONNX can't represent collection types (e.g., dictionary, tuple of tuple of tensor, - etc). - - This class stores the `SpecTree` output produced when `adapt` was called the first - time. It then validates the `SpecTree` output produced from later `adapt` calls. - """ - - _spec: pytree.TreeSpec | None = None - - def apply( - self, - model_args: Sequence[Any], - model_kwargs: Mapping[str, Any], - model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, - ) -> tuple[Sequence[Any], Mapping[str, Any]]: - """Flatten the model args and kwargs and validate the `SpecTree` output. - - Args: - model_args: The model args. - model_kwargs: The model kwargs. - model: The PyTorch model. - - Returns: - A tuple of the flattened model args and kwargs. The kwargs is empty, because - they are flattened and merged into the args. - - Raises: - ValueError: If the `SpecTree` output produced from the current `model_outputs` - is not identical to the `SpecTree` output produced from the first - `model_outputs` that was passed to this method. - """ - flattened_args, spec = pytree.tree_flatten((model_args, model_kwargs)) - if self._spec is None: - self._spec = spec - else: - _assert_identical_pytree_spec( - self._spec, - spec, - error_message="Model inputs incompatible with the format that was exported. ", - ) - return flattened_args, {} - - -class FlattenOutputStep(OutputAdaptStep): - """Flatten nested collection types and return a flat list of elements. - - ONNX can't represent collection types (e.g., dictionary, tuple of tuple of tensor, - etc). - - NOTE: Ideally we would want to use ``FlattenOutputWithTreeSpecValidationOutputStep``, such - that `SpecTree` can be validate for new model outputs. However, this is not possible - currently because we never have access to real PyTorch model outputs during export. - Only traced outputs may be available, but they are not an accurate reflection of the - original PyTorch model outputs format as they are typically in their own unique format, - depending on the tracing strategy. - """ - - def apply( - self, - model_outputs: Any, - model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, - ) -> Sequence[Any]: - """Flatten the model outputs. - - Args: - model_outputs: The model outputs to flatten. - model: The PyTorch model. - - Returns: - A tuple of the flattened model outputs. - """ - return pytree.tree_leaves(model_outputs) - - -class ConvertComplexToRealRepresentationOutputStep(OutputAdaptStep): - """Convert complex dtype tensors to real representation tensors. - - ONNX does not support complex dtype tensors. Thus, we convert complex dtype tensors - to real representation tensors (i.e., float dtype tensors with an extra dimension - representing the real and imaginary parts of the complex number). - - """ - - def apply( - self, - model_outputs: Any, - model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, - ) -> Any: - """Convert float tensors to complex tensors. - - Args: - model_output: The model output. - model: The PyTorch model. - - Returns: - A tuple of the model output. - """ - return [ - torch.view_as_real(output.resolve_conj()) - if isinstance(output, torch.Tensor) and torch.is_complex(output) - else output - for output in model_outputs - ] - - -class FlattenOutputWithTreeSpecValidationOutputStep(OutputAdaptStep): - """Same as ``FlattenOutputStep``, with additional `TreeSpec` validation. - - This class stores the `SpecTree` output produced when `adapt` was called the first - time. It then validates the `SpecTree` output produced from later `adapt` calls. - """ - - _spec: pytree.TreeSpec | None = None - - def apply( - self, - model_outputs: Any, - model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, - ) -> Sequence[Any]: - """Flatten the model outputs and validate the `SpecTree` output. - - Args: - model_outputs: The model outputs to flatten. - model: The PyTorch model. - - Returns: - flattened_outputs: The flattened model outputs. - - Raises: - ValueError: If the `SpecTree` output produced from the current `model_outputs` - is not identical to the `SpecTree` output produced from the first - `model_outputs` that was passed to this method. - """ - flattened_outputs, spec = pytree.tree_flatten(model_outputs) - if self._spec is None: - self._spec = spec - else: - _assert_identical_pytree_spec( - self._spec, - spec, - error_message="Model outputs incompatible with the format that was exported. ", - ) - return flattened_outputs - - -class PrependParamsBuffersConstantAotAutogradInputStep(InputAdaptStep): - """Prepend model parameters, buffers and constants to the user input. - - :func:`torch.export.export` lifts model parameters, buffers and constants as model input, thus, they - must be added to the user input before the model is executed. - - Args: - model: The PyTorch model with embedded parameters and buffers. - """ - - def apply( - self, - model_args: Sequence[Any], - model_kwargs: Mapping[str, Any], - model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, - ) -> tuple[Sequence[Any], Mapping[str, Any]]: - """Convert complex tensors to float tensors. - - Args: - model_args: The model args. - model_kwargs: The model kwargs. - model: The PyTorch model. - - Returns: - A tuple of the model args and kwargs. - """ - ordered_params = tuple( - model.state_dict[name] # type: ignore[union-attr,index] - for name in model.graph_signature.parameters # type: ignore[union-attr] - ) - non_persistent_buffers = set(model.graph_signature.non_persistent_buffers) # type: ignore[arg-type, union-attr] - ordered_buffers = [] - for name in model.graph_signature.buffers: # type: ignore[union-attr] - if name in non_persistent_buffers: - ordered_buffers.append(model.constants[name]) # type: ignore[index, union-attr] - else: - ordered_buffers.append(model.state_dict[name]) # type: ignore[union-attr,index] - ordered_constant_tensors = tuple( - model.constants[fqn] # type: ignore[union-attr,index] - for fqn in model.graph_signature.lifted_tensor_constants # type: ignore[union-attr] - ) - - # NOTE: calling convention is first params, then buffers, then args as user supplied them. - # See: torch/_functorch/aot_autograd.py#L1034 - updated_args = ( - *ordered_params, - *ordered_buffers, - *ordered_constant_tensors, - *model_args, - ) - if model_kwargs: - return MergeKwargsIntoArgsInputStep().apply( - updated_args, model_kwargs, model=model - ) - return updated_args, {} - - -class PrependParamsAndBuffersAotAutogradOutputStep(OutputAdaptStep): - """Prepend model's mutated buffers to the user output. - - :func:`torch.export.export` lifts model's mutated buffers as outputs, thus, they - must be added to the user output after the model is executed. - - Args: - model: The PyTorch model with mutated buffers. - """ - - def apply( - self, - model_outputs: Any, - model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, - ) -> Sequence[Any]: - """Flatten the model outputs and validate the `SpecTree` output. - - Args: - model_outputs: The model outputs to flatten. - model: The PyTorch model. - - Returns: - flattened_outputs: The flattened model outputs. - """ - - assert isinstance(model, torch_export.ExportedProgram), ( - "'model' must be torch_export.ExportedProgram" - ) - ordered_buffers = tuple( - model.state_dict[name] - if name in model.state_dict - else model.constants[name] - for name in model.graph_signature.buffers_to_mutate.values() - ) - - # NOTE: calling convention is first mutated buffers, then outputs args as model returned them. - updated_outputs = (*ordered_buffers, *model_outputs) - return updated_outputs From 336bff6d58ceb50b12d9d67764fd9f238bc0adb5 Mon Sep 17 00:00:00 2001 From: Ti-Tai Wang Date: Tue, 15 Jul 2025 00:00:06 +0000 Subject: [PATCH 031/457] [ONNX] Remove legacy graph passes (#158256) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158256 Approved by: https://github.com/justinchuby ghstack dependencies: #158255 --- torch/onnx/_internal/_exporter_legacy.py | 25 +-- torch/onnx/_internal/fx/passes/__init__.py | 10 -- torch/onnx/_internal/fx/passes/decomp.py | 87 ---------- .../_internal/fx/passes/functionalization.py | 152 ------------------ torch/onnx/_internal/fx/passes/readability.py | 130 --------------- .../_internal/fx/passes/virtualization.py | 96 ----------- torch/onnx/_internal/onnxruntime.py | 3 - 7 files changed, 1 insertion(+), 502 deletions(-) delete mode 100644 torch/onnx/_internal/fx/passes/decomp.py delete mode 100644 torch/onnx/_internal/fx/passes/functionalization.py delete mode 100644 torch/onnx/_internal/fx/passes/readability.py delete mode 100644 torch/onnx/_internal/fx/passes/virtualization.py diff --git a/torch/onnx/_internal/_exporter_legacy.py b/torch/onnx/_internal/_exporter_legacy.py index 1e6a9df9a9903..5447e503801d5 100644 --- a/torch/onnx/_internal/_exporter_legacy.py +++ b/torch/onnx/_internal/_exporter_legacy.py @@ -439,31 +439,8 @@ def common_pre_export_passes( # TODO: Import here to prevent circular dependency from torch.onnx._internal.fx import passes - # Apply decomposition table to the input graph. - module = passes.Decompose( - fx_module, - options.decomposition_table, # type: ignore[arg-type] - enable_dynamic_axes=options.dynamic_shapes, - allow_fake_constant=options.fake_context is not None, - ).run(*fx_module_args) - - # ONNX does not support views and mutations. - # Functionalize to get a semantically equivalent graph without mutations. - module = passes.Functionalize( - module, - enable_dynamic_axes=options.dynamic_shapes, - allow_fake_constant=options.fake_context is not None, - ).run(*fx_module_args) - - # Input mutations are detected and distilled after `Functionalize` pass. - # Remove them since ONNX inference does not need them. - module = passes.RemoveInputMutation(module).run(*fx_module_args) - # ONNX does not support concept of (implicit) type promotion. # Insert type casts explicitly where needed. - module = passes.InsertTypePromotion(module).run() - - if isinstance(original_model, torch.nn.Module): - module = passes.RestoreParameterAndBufferNames(module, original_model).run() + module = passes.InsertTypePromotion(fx_module).run() return module diff --git a/torch/onnx/_internal/fx/passes/__init__.py b/torch/onnx/_internal/fx/passes/__init__.py index aa04e6beb5f12..d6309b59da10f 100644 --- a/torch/onnx/_internal/fx/passes/__init__.py +++ b/torch/onnx/_internal/fx/passes/__init__.py @@ -1,18 +1,8 @@ -from .decomp import Decompose -from .functionalization import Functionalize, RemoveInputMutation from .modularization import Modularize -from .readability import RestoreParameterAndBufferNames from .type_promotion import InsertTypePromotion -from .virtualization import MovePlaceholderToFront, ReplaceGetAttrWithPlaceholder __all__ = [ - "Decompose", "InsertTypePromotion", - "Functionalize", "Modularize", - "MovePlaceholderToFront", - "RemoveInputMutation", - "RestoreParameterAndBufferNames", - "ReplaceGetAttrWithPlaceholder", ] diff --git a/torch/onnx/_internal/fx/passes/decomp.py b/torch/onnx/_internal/fx/passes/decomp.py deleted file mode 100644 index 1573264d6fc76..0000000000000 --- a/torch/onnx/_internal/fx/passes/decomp.py +++ /dev/null @@ -1,87 +0,0 @@ -# mypy: allow-untyped-defs -from __future__ import annotations - -import contextlib -from typing import Callable, TYPE_CHECKING - -import torch -import torch._ops -from torch._dispatch import python as python_dispatch -from torch._subclasses import fake_tensor -from torch.fx.experimental import proxy_tensor -from torch.onnx._internal.fx import _pass -from torch.onnx._internal.fx.passes import _utils - - -if TYPE_CHECKING: - from collections.abc import Mapping - - import torch.fx - - -class Decompose(_pass.Transform): - def __init__( - self, - module: torch.fx.GraphModule, - decomposition_table: Mapping[torch._ops.OpOverload, Callable], - enable_dynamic_axes: bool, - allow_fake_constant: bool | None = False, - ): - super().__init__(module) - self.decomposition_table = decomposition_table - self.enable_dynamic_axes = enable_dynamic_axes - self.allow_fake_constant = allow_fake_constant - - def _run(self, *args, **kwargs) -> torch.fx.GraphModule: - assert not kwargs, "kwargs is not supported in Decompose." - - # To preserve stack trace info after `make_fx`. - module = _utils.wrap_graph_module_for_node_meta_preservation(self.module) - - # fake mode use static size to trace the size of tensors. while symbolic - # mode generates aten::sym_size to dynamically trace the size of tensors. - - # e.g. fake mode: - # view: f32[3, 5, 20] = torch.ops.aten.view.default(x, [3, 5, 20]) - - # e.g. symbolic mode: - # sym_size = torch.ops.aten.sym_size(x, 0) - # sym_size_1 = torch.ops.aten.sym_size(x, 1) - # sym_size_2 = torch.ops.aten.sym_size(x, 2) - # sym_size_3 = torch.ops.aten.sym_size(x, 3) - # mul = sym_size_2 * sym_size_3; sym_size_2 = sym_size_3 = None - # view: f32[3, 5, 20] = torch.ops.aten.view.default(x, [sym_size, sym_size_1, mul]) - - # Mimic `torch._dynamo.export(aten_graph=True)` behavior in invoking `make_fx`. - # TODO: May need revisit for user fake mode export + dynamic shape scenario. - fake_mode: fake_tensor.FakeTensorMode | None = self.fake_mode - maybe_fake_args = self._maybe_fakefy_args(fake_mode, *args) - if fake_mode is not None: - # Using existing fake mode as context, signal `make_fx` that it does not need - # to create a new fake mode by passing tracing_mode as "real". - tracing_mode = "real" - else: - # Existing fake mode not found, signal `make_fx` to create one. - fake_mode = contextlib.nullcontext() # type: ignore[assignment] - tracing_mode = "symbolic" if self.enable_dynamic_axes else "fake" - - # Apply decomposition table to the input graph. - assert fake_mode is not None # for mypy - with ( - fake_tensor.unset_fake_temporarily(), - python_dispatch.enable_python_dispatcher(), - fake_mode, - ): - decomposed_module = proxy_tensor.make_fx( - module, - decomposition_table=self.decomposition_table, - tracing_mode=tracing_mode, - _allow_non_fake_inputs=True, - _allow_fake_constant=bool(self.allow_fake_constant), - )(*maybe_fake_args) - - # Rename placeholder targets to match the original module's signature since - # We don't want to map forward(x, y, z) to forward(arg0, arg1, arg2). - _utils.replace_placeholder_name_and_target(decomposed_module, self.module) - - return decomposed_module diff --git a/torch/onnx/_internal/fx/passes/functionalization.py b/torch/onnx/_internal/fx/passes/functionalization.py deleted file mode 100644 index fd8d3c7d48ac5..0000000000000 --- a/torch/onnx/_internal/fx/passes/functionalization.py +++ /dev/null @@ -1,152 +0,0 @@ -# mypy: allow-untyped-defs -from __future__ import annotations - -import contextlib -from typing import Callable - -import torch -import torch._ops -import torch.func -import torch.fx -from torch._subclasses import fake_tensor -from torch.fx.experimental import proxy_tensor -from torch.onnx._internal.fx import _pass -from torch.onnx._internal.fx.passes import _utils -from torch.utils import _pytree as pytree - - -class Functionalize(_pass.Transform): - """Functionalize a GraphModule. - - This pass utilizes ``functionalization`` utility of ``torch._functorch`` to convert - a GraphModule into a functional form. The two main functionalities are (copied from - its documentations): - - * ``functionalization`` removes (intermediate) mutations and aliasing from a - function, while preserving the function's semantics. - - * ``functionalization`` also removes mutations (and views) that were performed - on function inputs. However to preserve semantics, functionalize will "fix up" the - mutations after the transform has finished running, by detecting if any tensor inputs - "should have" been mutated, and copying the new data back to the inputs if necessary. - For example, consider:: - - def fn(a, b): - a.add_(b) - return a - - For a call like `fn(x, y)`, the variable `x` outside is also mutated. Hence just - functionalizing is not enough for preserving the original semantics. A "special" - input mutation step needs to be inserted at the end.:: - - # After functionalization, without input mutation "fix up". - # This is not semantically the same. The variable outside the function call that - # was passed in as `a` is not mutated. - def fn(a, b): - new_a = a + b - return new_a - - # Functionalization with input mutation "fix up" that preserves semantics. - def fn(a, b): - new_a = a + b - - # Copying the new data back to the inputs - a.copy_(new_a) - - return new_a - - For ONNX inference, it is recommended to run ``RemoveInputMutation`` after this pass. - ``RemoveInputMutation`` removes the "fix up" nodes that were added by ``Functionalize``, - which are not needed for ONNX inference. - """ - - def __init__( - self, - module: torch.fx.GraphModule, - enable_dynamic_axes: bool, - allow_fake_constant: bool | None = False, - ): - super().__init__(module) - self.enable_dynamic_axes = enable_dynamic_axes - self.allow_fake_constant = allow_fake_constant - - def _functionalize(self, function: Callable) -> Callable: - # Working around a dispatcher issue with `torch.func.functionalize` when used - # together with `make_fx`. - # Ref: https://github.com/pytorch/pytorch/issues/99774#issuecomment-1527949391 - def wrapped(*inputs): - inputs_functional = pytree.tree_map_only( - torch.Tensor, torch._to_functional_tensor, inputs - ) - torch._enable_functionalization(reapply_views=True) - try: - out = function(*inputs_functional) - finally: - torch._disable_functionalization() - - flat_inputs_functional = pytree.tree_leaves(inputs_functional) - for input_functional in flat_inputs_functional: - if isinstance(input_functional, torch.Tensor): - torch._sync(input_functional) - pytree.tree_map(torch._sync, out) - out_unwrapped = pytree.tree_map(torch._from_functional_tensor, out) - return out_unwrapped - - return wrapped - - def _run(self, *args) -> torch.fx.GraphModule: - # To preserve stack trace info after `make_fx`. - module = _utils.wrap_graph_module_for_node_meta_preservation(self.module) - - functionalized_callable = self._functionalize(module) - - # Mimic `torch._dynamo.export(aten_graph=True)` behavior in invoking `make_fx`. - # TODO: May need revisit for user fake mode export + dynamic shape scenario. - fake_mode: fake_tensor.FakeTensorMode | None = self.fake_mode - maybe_fake_args = self._maybe_fakefy_args(fake_mode, *args) - if fake_mode is not None: - # Using existing fake mode as context, signal `make_fx` that it does not need - # to create a new fake mode by passing tracing_mode as "real". - tracing_mode = "real" - else: - # Existing fake mode not found, signal `make_fx` to create one. - fake_mode = contextlib.nullcontext() # type: ignore[assignment] - tracing_mode = "symbolic" if self.enable_dynamic_axes else "fake" - - assert fake_mode is not None # for mypy - with fake_tensor.unset_fake_temporarily(), fake_mode: - graph_module = proxy_tensor.make_fx( - functionalized_callable, - decomposition_table={}, - tracing_mode=tracing_mode, - _allow_non_fake_inputs=True, - _allow_fake_constant=bool(self.allow_fake_constant), - )(*maybe_fake_args) - - # Rename placeholder targets to match the original module's signature since - # We don't want to map forward(x, y, z) to forward(arg0, arg1, arg2). - _utils.replace_placeholder_name_and_target(graph_module, self.module) - - return graph_module - - -class RemoveInputMutation(_pass.Transform): - """Remove `aten.copy_.default` nodes that mutate module inputs. - - This pass is recommended to be used after ``Functionalization`` pass. - ``Functionalization`` pass adds `aten.copy_.default` nodes to the graph - when it detects mutations to inputs. These nodes are not needed for ONNX export - for inference. They could be useful for training. - """ - - def _run(self, *args) -> torch.fx.GraphModule: - for node in reversed(self.module.graph.nodes): - if ( - node.op == "call_function" - and node.target == torch.ops.aten.copy_.default - and len(node.users) == 0 - and isinstance(node.args[0], torch.fx.Node) - and node.args[0].op == "placeholder" - ): - self.module.graph.erase_node(node) - return self.module diff --git a/torch/onnx/_internal/fx/passes/readability.py b/torch/onnx/_internal/fx/passes/readability.py deleted file mode 100644 index a14d07b9aa197..0000000000000 --- a/torch/onnx/_internal/fx/passes/readability.py +++ /dev/null @@ -1,130 +0,0 @@ -# mypy: allow-untyped-defs -from __future__ import annotations - -import logging -from typing import TYPE_CHECKING - -import torch -from torch.onnx._internal.fx import _pass - - -if TYPE_CHECKING: - from collections.abc import Sequence - - -logger = logging.getLogger(__name__) - - -class RestoreParameterAndBufferNames(_pass.Transform): - """Restore parameter and buffer names from original nn.module. - - This pass is useful for readability of the exported ONNX graph. It restores the - parameter and buffer names from the original nn.module. For example, if the original - nn.module has a parameter named `root.linear.0.weight`, and the parameter is renamed to - `_param_constant9` by FX, this pass will rename it back. - - This pass must be run after `Decompose` pass. Because this pass is expected to be called on - `fx.GraphModule` produced by `proxy_tensor.make_fx`, where all parameters and buffers - are registered at root level. - """ - - def __init__( - self, - fx_module: torch.fx.GraphModule, - original_nn_module: torch.nn.Module, - ): - super().__init__(fx_module) - self.original_nn_module = original_nn_module - - def _rename_param_and_buffer( - self, - nodes: Sequence[torch.fx.Node], - new_name: str, - ) -> None: - """Rename the parameter/buffer and replace corresponding nodes with new nodes of updated target.""" - assert len(nodes) > 0, "`nodes` cannot be empty" - assert len({node.target for node in nodes}) == 1, ( - "`nodes` must all have same `target`" - ) - old_name = nodes[0].target - assert isinstance(old_name, str), f"Expected str, got type({old_name})" - # Parameter/buffer name cannot contain "." - normalized_name = new_name.replace(".", "/") - attr_value = getattr(self.module, old_name) - setattr(self.module, normalized_name, attr_value) - delattr(self.module, old_name) - for node in nodes: - with self.module.graph.inserting_before(node): - new_node = self.module.graph.get_attr(normalized_name) - new_node.meta = node.meta - node.replace_all_uses_with(new_node) - self.module.graph.erase_node(node) - logger.info( - "Renamed 'self.%s' to 'self.%s', " - "normalized from original parameter name '%s'.", - old_name, - normalized_name, - new_name, - ) - - def _run(self, *args, **kwargs) -> torch.fx.GraphModule: - """Restore parameter and buffer names from original module. - - For each `get_attr` node, if the target is a str representing a parameter or buffer - under `self.module`, we rename the parameter or buffer to its original name. - The parameters and buffers between `self.module` and `self.original_nn_module` refer - to the same objects, allowing us to use it as key to retrieve the original name. - """ - assert len(args) == 0, "RestoreParameterAndBufferNames does not take any args" - assert len(kwargs) == 0, ( - "RestoreParameterAndBufferNames does not take any kwargs" - ) - # state_to_readable_name[parameter/buffer] returns the original readable name of - # the parameter/buffer. E.g., "self.linear.weight". - state_to_readable_name: dict[torch.nn.Parameter | torch.Tensor, str] = {} - state_to_readable_name.update( - {v: k for k, v in self.original_nn_module.named_parameters()} - ) - state_to_readable_name.update( - {v: k for k, v in self.original_nn_module.named_buffers()} - ) - - # old_name_to_nodes[old_name] returns a tuple of (nodes, new_name) - # where `nodes` is a list of `get_attr` nodes with `old_name` as `target` and - # `new_name` is the new readable name. - old_name_to_nodes: dict[str, tuple[list[torch.fx.Node], str]] = {} - - for node in self.module.graph.nodes: - if node.op == "get_attr": - assert isinstance(node.target, str), ( - f"Expected str, got type({node.target})" - ) - if node.target.find(".") != -1: - raise RuntimeError( - f"Unexpected target {node.target} in get_attr, found '.' in target. " - f"All parameters and buffers are expected to be registered at root level, " - f"i.e., self.module. " - ) - if node.target in old_name_to_nodes: - # We have already processed this parameter/buffer. - old_name_to_nodes[node.target][0].append(node) - continue - attr_value = getattr(self.module, node.target) - if ( - isinstance(attr_value, (torch.nn.Parameter, torch.Tensor)) - and attr_value in state_to_readable_name - ): - readable_name = state_to_readable_name[attr_value] - old_name_to_nodes[node.target] = ([node], readable_name) - continue - - logger.info( - "Cannot find readable name for self.%s: %s. The name is unchanged.", - node.target, - type(attr_value), - ) - - for nodes, new_name in old_name_to_nodes.values(): - self._rename_param_and_buffer(nodes, new_name) - - return self.module diff --git a/torch/onnx/_internal/fx/passes/virtualization.py b/torch/onnx/_internal/fx/passes/virtualization.py deleted file mode 100644 index 504dea1d84247..0000000000000 --- a/torch/onnx/_internal/fx/passes/virtualization.py +++ /dev/null @@ -1,96 +0,0 @@ -# mypy: allow-untyped-defs -from __future__ import annotations - -from typing import TYPE_CHECKING - -import torch -from torch.onnx._internal.fx import _pass - - -if TYPE_CHECKING: - import torch.fx - - -class MovePlaceholderToFront(_pass.Transform): - """This pass move all placeholder nodes to the front of the graph node list. - - In torch.fx.Graph, placeholder is a special assignment node. If it's not - executed in the beginning, it could overwrite values computed by upstream - nodes. - """ - - def _run(self, *args, **kwargs) -> torch.fx.GraphModule: - graph_module = self.module - graph = graph_module.graph - placeholders = [] - first_not_placeholder = None - for node in graph.nodes: - if node.op == "placeholder": - placeholders.append(node) - if first_not_placeholder is None and node.op != "placeholder": - first_not_placeholder = node - if first_not_placeholder is None: - return graph_module - for placeholder in placeholders: - first_not_placeholder.prepend(placeholder) - return graph_module - - -class ReplaceGetAttrWithPlaceholder(_pass.Transform): - """Replace get_attr with placeholder. - - The parameters and buffers accessed by the original get_attr are returned; - they are useful when creating random inputs for the modified graph_module. - """ - - _replaced_attrs: tuple[torch.Tensor, ...] | None - - @property - def replaced_attrs(self) -> tuple[torch.Tensor, ...]: - """The list of replaced weight tensors.""" - assert self._replaced_attrs is not None, ( - "Must run ReplaceGetAttrWithPlaceholder first" - ) - return self._replaced_attrs - - def _run(self, *args, **kwargs) -> torch.fx.GraphModule: - graph_module = self.module - graph = graph_module.graph - replaced_attrs: list[torch.Tensor] = [] - for node in graph.nodes: - if node.op == "get_attr": - replaced_attr: torch.Tensor | None = None - # get_attr could retrieve either parameter or buffer, so - # we need to try both. - try: - replaced_attr = graph_module.get_parameter(node.target) - except AttributeError: - # It's possible that model author use buffer instead of - # parameter to store trainable weights. In this case, - # 1. get_parameter will throw something like - # AttributeError: `bias` is not an nn.Parameter. - # 2. get_buffer should work. - replaced_attr = graph_module.get_buffer(node.target) - - # Reassign op type so that get_attr node becomes placeholder node. - node.op = "placeholder" - # The target name in placeholder must be a valid Python identifier. - # Thus, we replace, e.g., "module.submodule.weight" with - # "module_submodule_weight". - node.target = node.target.replace(".", "_") - # Default value is None. This is needed as long as the "graph_module" - # has optional inputs. Assume the original forward signature is - # def forward(self, x, y=None) - # and the replaced get_attr node has target "z". Then, the modified - # signature should be - # def forward(self, x, y=None, z=None) - # Without the following line, the signature will be - # def forward(self, x, y=None, z) - # , which is not valid Python code. - node.args = (None,) - - replaced_attrs.append(replaced_attr) - - self._replaced_attrs = tuple(replaced_attrs) - - return graph_module diff --git a/torch/onnx/_internal/onnxruntime.py b/torch/onnx/_internal/onnxruntime.py index b994328fcdd82..f9550d031fdc3 100644 --- a/torch/onnx/_internal/onnxruntime.py +++ b/torch/onnx/_internal/onnxruntime.py @@ -944,9 +944,6 @@ def _ort_acclerated_call(self, graph_module: torch.fx.GraphModule, *args, **kwar # It's first time seeing such as graph. Let's make a new session # (type: onnxruntime.InferenceSession) for it. - graph_module = passes.MovePlaceholderToFront( - graph_module, - ).run() # Generate reference outputs. They are used to indicate output # tensors' types and devices when calling ORT. # From 5fb07acbc32874a932cd26087cf752b2f4cc72df Mon Sep 17 00:00:00 2001 From: Ti-Tai Wang Date: Tue, 15 Jul 2025 00:00:06 +0000 Subject: [PATCH 032/457] [ONNX] Remove legacy modularization (#158257) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158257 Approved by: https://github.com/justinchuby ghstack dependencies: #158255, #158256 --- torch/onnx/_internal/fx/passes/__init__.py | 2 - .../_internal/fx/passes/modularization.py | 857 ------------------ 2 files changed, 859 deletions(-) delete mode 100644 torch/onnx/_internal/fx/passes/modularization.py diff --git a/torch/onnx/_internal/fx/passes/__init__.py b/torch/onnx/_internal/fx/passes/__init__.py index d6309b59da10f..eff83563a5a08 100644 --- a/torch/onnx/_internal/fx/passes/__init__.py +++ b/torch/onnx/_internal/fx/passes/__init__.py @@ -1,8 +1,6 @@ -from .modularization import Modularize from .type_promotion import InsertTypePromotion __all__ = [ "InsertTypePromotion", - "Modularize", ] diff --git a/torch/onnx/_internal/fx/passes/modularization.py b/torch/onnx/_internal/fx/passes/modularization.py deleted file mode 100644 index 18a424826bfef..0000000000000 --- a/torch/onnx/_internal/fx/passes/modularization.py +++ /dev/null @@ -1,857 +0,0 @@ -# mypy: allow-untyped-defs -from __future__ import annotations - -import abc -import collections -import copy -import operator -from typing import Any, Final, TYPE_CHECKING - -import torch -import torch.fx -from torch.onnx._internal.fx import _pass -from torch.utils import _pytree as pytree - - -if TYPE_CHECKING: - from collections.abc import Generator, Iterator, Sequence - - -_FX_TRACER_NN_MODULE_META_TYPE = tuple[str, type] -"""Legacy type of item from `node.meta["nn_module_stack"].items()` produced by FX symbolic tracer.""" -_FX_TRACER_NN_MODULE_STACK_META_TYPE = collections.OrderedDict -"""Legacy type of `node.meta["nn_module_stack"]` produced by FX symbolic tracer.""" - -_DYNAMO_NN_MODULE_META_TYPE = tuple[str, tuple[str, type]] -"""Type of item from `node.meta["nn_module_stack"].items()` produced by FX dynamo tracer.""" -_DYNAMO_NN_MODULE_STACK_META_TYPE = dict[str, _DYNAMO_NN_MODULE_META_TYPE] -"""Type of `node.meta["nn_module_stack"]` produced by FX dynamo tracer.""" - - -class _ModuleMeta: - """Meta information about a module. - - This class is used to represent the module information in a more structured way. - It parses raw module information from a single item from - `node.meta["nn_module_stack"].items()`. - - See the uses of `from_raw_meta`, `from_fx_tracer_produced_raw_meta`, and - `from_dynamo_produced_raw_meta` for how to create an instance. - - Attributes: - _module_class: The class of the module. E.g. `torch.nn.module.sparse.Embedding`. - _module_name: The name of the module. E.g. `L__self___h_1_mlp_c_proj`. - _raw_meta: The raw meta '(module_name, node.meta["nn_module_stack"][module_name])'. - """ - - _module_class: Final[type | str | None] # type: ignore[misc] - _module_name: Final[str] # type: ignore[misc] - _raw_meta: Final[tuple[Any, Any]] # type: ignore[misc] - - def __init__( - self, - module_name: str, - module_class: type | str | None, - raw_meta: tuple[Any, Any], - ): - self._module_name = module_name - self._module_class = module_class - self._raw_meta = raw_meta - - @property - def module_display_name(self) -> str: - """The display name of the module. - - E.g. `h_1_mlp_c_proj`. - """ - # E.g., from 'L__self___h_1_mlp_c_proj' to 'h_1_mlp_c_proj'. - name = self.module_name - name = name.removeprefix("L__self___") - return name - - @property - def qualified_module_class_name(self) -> str: - """Qualified name of the module class. - - E.g. `torch_nn_module_sparse_Embedding`. - """ - if self._module_class is None: - return "" - mod_cls = self._module_class - if isinstance(mod_cls, type): - mod_cls = mod_cls.__module__ + "." + mod_cls.__qualname__ - return mod_cls.replace(".", "_") - - @property - def module_class_name(self) -> str: - """Name of the module class. - - E.g. `Embedding`. - """ - if self._module_class is None: - return "" - if isinstance(self._module_class, type): - return self._module_class.__name__ - return self._module_class - - @property - def module_name(self) -> str: - """Name of the module. - - E.g. `L__self___h_1_mlp_c_proj`. - """ - return self._module_name - - @property - def raw_meta(self) -> tuple[Any, Any]: - """Returns the raw module meta data. - - I.e. (module_name, node.meta['nn_module_stack'][module_name]). - """ - return self._raw_meta - - def __eq__(self, other: object, /) -> bool: - if not isinstance(other, _ModuleMeta): - return False - return ( - self._module_name == other._module_name - and self._module_class == other._module_class - ) - - def __hash__(self) -> int: - return hash((self._module_name, self._module_class)) - - def __repr__(self) -> str: - return f"ModuleMeta(name={self._module_name}, class={self._module_class})" - - @classmethod - def create_root(cls) -> _ModuleMeta: - """Create an empty module meta representing root module.""" - return _ModuleMeta("", None, ("", None)) - - @classmethod - def from_fx_tracer_produced_raw_meta( - cls, raw_meta: _FX_TRACER_NN_MODULE_META_TYPE - ) -> _ModuleMeta: - """Create a module meta from raw meta produced by FX symbolic tracer.""" - module_name, module_class = raw_meta - return _ModuleMeta(module_name, module_class, raw_meta) - - @classmethod - def from_dynamo_produced_raw_meta( - cls, raw_meta: _DYNAMO_NN_MODULE_META_TYPE - ) -> _ModuleMeta: - """Create a module meta from raw meta produced by FX dynamo tracer.""" - module_name, (_qualified_name, module_class) = raw_meta - return _ModuleMeta(module_name.split("@")[0], module_class, raw_meta) - - @classmethod - def from_raw_meta( - cls, - raw_meta: _FX_TRACER_NN_MODULE_META_TYPE | _DYNAMO_NN_MODULE_META_TYPE, - ) -> _ModuleMeta: - if ( - isinstance(raw_meta, tuple) - and len(raw_meta) == 2 - and isinstance(raw_meta[1], type) - ): - # Trying to do `instance(raw_meta, _FX_TRACER_NN_MODULE_META_TYPE)` - return _ModuleMeta.from_fx_tracer_produced_raw_meta(raw_meta) - if ( - isinstance(raw_meta, tuple) - and len(raw_meta) == 2 - and isinstance(raw_meta[1], tuple) - ): - # Trying to do `instance(raw_meta, _DYNAMO_NN_MODULE_META_TYPE)` - return _ModuleMeta.from_dynamo_produced_raw_meta(raw_meta) - raise TypeError( - f"Unknown type of raw meta item from node.meta['nn_module_stack'].items(): {type(raw_meta)}" - ) - - -class _ModuleStackMeta: - """Meta information about the module call stack. - - This class is used to represent the module call stack information in a more - structured way. It parses raw module stack information from `node.meta["nn_module_stack"]`. - - Example of raw module stack information: - - If produced by dynamo: - - { - 'L__self___h_1': ( - "L['self'].h[1]", - - ), - 'L__self___h_1_attn': ( - "L['self'].h[1].attn", - - ) - } - - If produced by fx.symbolic_trace: - - { - 'h.1': , - 'h.1.attn': - } - """ - - _module_stack: Final[list[_ModuleMeta]] # type: ignore[misc] - - def __init__( - self, - nn_module_stack_meta: _FX_TRACER_NN_MODULE_STACK_META_TYPE - | _DYNAMO_NN_MODULE_STACK_META_TYPE - | None, - is_exported_program: bool = True, - ): - self._module_stack = [] - if nn_module_stack_meta is None: - return - raw_meta = copy.copy(nn_module_stack_meta) - for item in raw_meta.items(): - # If produced by torch.export.export, there is another call stack layer - # that we need to skip - if is_exported_program: - is_exported_program = False - continue - self.push(_ModuleMeta.from_raw_meta(item)) # type: ignore[arg-type] - - def __len__(self) -> int: - return len(self._module_stack) - - def __getitem__(self, index: int) -> _ModuleMeta: - return self._module_stack[index] - - def __iter__(self) -> Iterator[_ModuleMeta]: - return iter(self._module_stack) - - def is_empty_or_root(self) -> bool: - return len(self._module_stack) == 0 - - def top(self) -> _ModuleMeta: - """Returns the top module meta in the stack. I.e., the meta for leaf module. - - Example: - - Consider the following module stack: - - stack = [GPT, block1, Attention_1, MLP] - - stack.top() == MLP - """ - if self.is_empty_or_root(): - return _ModuleMeta.create_root() - return self._module_stack[-1] - - def is_superset_of( - self, - module_stack: _ModuleStackMeta, - ) -> bool: - """Determines if self is a superset of the provided module stack. - - I.e., If self includes all elements from the provided module stack, plus additional - elements on top. If self is empty or root, this method always return False. - - Example: - - Consider the following module stack: - - stack_1 = [GPT, block1, Attention_1, MLP] - stack_2 = [GPT, block1] - - stack_1.is_superset_of(stack_2) == True - stack_2.is_superset_of(stack_1) == False - - stack_3 = [GPT, block2, Attention_1] - - stack_1.is_superset_of(stack_3) == False - stack_3.is_superset_of(stack_1) == False - """ - if self.is_empty_or_root(): - return False - - if module_stack.is_empty_or_root() is None: - return True - - if len(self) <= len(module_stack): - return False - - for i, parent_key in enumerate(module_stack): - if self[i] != parent_key: - return False - - return True - - def push(self, module_meta: _ModuleMeta) -> None: - """Pushes a module meta to the stack.""" - self._module_stack.append(module_meta) - - def __eq__(self, other: object, /) -> bool: - if not isinstance(other, _ModuleStackMeta): - return False - return self._module_stack == other._module_stack - - @property - def raw_meta(self) -> dict[str, tuple[str, type]] | None: - """Returns the raw module stack meta data, i.e. node.meta['nn_module_stack'].""" - return { - module_meta.raw_meta[0]: module_meta.raw_meta[1] - for module_meta in self._module_stack - } - - def __repr__(self) -> str: - return f"ModuleStackMeta({self._module_stack})" - - @property - def module_display_name(self) -> str: - """Returns the module display name of the top module.""" - return self.top().module_display_name - - @property - def qualified_module_class_name(self) -> str: - """Returns the qualified module class name of the top module.""" - return self.top().qualified_module_class_name - - @property - def module_class(self) -> type | str | None: - """Returns the module class of the top module.""" - return self.top()._module_class - - -def _module_stack_meta_from_node( - node: torch.fx.Node, is_exported_program: bool = False -) -> _ModuleStackMeta: - return _ModuleStackMeta( - node.meta.get("nn_module_stack"), is_exported_program=is_exported_program - ) - - -def _get_unique_module_name(module_names: dict[str, int], module_name: str) -> str: - module_names.setdefault(module_name, 0) - module_names[module_name] += 1 - return f"{module_name}_{module_names[module_name]}" - - -class _IRNode(abc.ABC): - """Base class for IR nodes. - - IR nodes are used for Modularize pass only. They add a layer of abstraction on top of - torch.fx.Node. - - [NOTE: Modularize Pass Implementation] - The main job of the pass is to group `fx.Node`s that belong to the same `nn.Module` - forward call, and then create `call_module` node and sub `fx.GraphModule` from them. - Each `fx.Node` possesses an `nn_module_stack` meta data that contains information - about the module call stack. See `_ModuleStackMeta` for examples. - - Analysis step - ------------- - - Each module call is identified by a set of base stack layers. For each module call, - the pass creates a `_ModuleNode` and groups the sequence of nodes that shares the - same base stack layers. - - For example, - - stack_of_node_0 = [GPT, block0] - stack_of_node_1 = [GPT, block1] - stack_of_node_2 = [GPT, block1, Attention1, MLP] - stack_of_node_3 = [GPT, block1, Attention1] - stack_of_node_4 = [GPT, block2] - - All nodes belong to the `GPT` module call, since they share the base stack layers [GPT]. - [node_1, node_2, node_3] are grouped for `GPT.block1`, because they share the base - stack layers [GPT, block1]. And [node_2, node_3] for `GPT.block1.Attention1`, [node_0] - for `GPT.block0`, and [node_4] for `GPT.block2` respectfully. - - After the analysis step, a hierarchical representation is generated. - - For above example, the representation is: - - _ModuleNode(GPT) - _ModuleNode(block0) - _LeafNode(node_0) - _ModuleNode(block1) - _LeafNode(node_1) - _ModuleNode(Attention1) - _ModuleNode(MLP) - _LeafNode(node_2) - _LeafNode(node_3) - _ModuleNode(block2) - _LeafNode(node_4) - - Construction step - ----------------- - - The second step is to build the actual `call_module` node and the sub `fx.GraphModule`. - This is done recursively from the leaf `_ModuleNode` to the root. - - For example, the first submodule to be built is `GPT.block1.Attention1.MLP`. Below pair - is generated from `_ModuleNode(MLP)`. - - fx.GraphModule(GPT.block1.Attention1.MLP) - graph: - node_2 - - new_mlp_node = `call_module[GPT.block1.Attention1.MLP](...)` - - Next, the `GPT.block1.Attention1` submodule is built. Below is generated from - `_ModuleNode(Attention1)`. - - fx.GraphModule(GPT.block1.Attention1) - graph: - new_mlp_node - node_3 - - new_attention1_node = `call_module[GPT.block1.Attention1](...)` - - Until every submodule is built, the new modularized `fx.GraphModule` is generated. - - Alternatives - ------------ - - The current algorithm adopts a top down approach. A bottom up approach is similar. - In contrast to these two, an alternative flat order approach is also possible, where - each node is traversed and copied to the corresponding submodule. - - The advantage of the current approach lies in the encapsulation of the fx.GraphModule - construction for each individual submodule within a single `build_module` method, which - can be called separately once the analysis phase is completed, making debugging more - convenient. - - Regarding construction step, an alternative implementation is to utilize `fx.Interpreter` - for traversing all the nodes under the flattened root module and copying the nodes - into their respective submodule under construction. This approach is not adopted because - - 1. It uses the flat order approach discussed above. This means one cannot individually - construct a submodule and examine it while debugging. - - 2. The graph execution functionality of `fx.Interpreter` is not necessary for the - purpose of this pass. Ignoring that, `fx.Interpreter.run` achieves the same effect - as a for loop over all the nodes. - """ - - @property - @abc.abstractmethod - def stack_meta(self) -> _ModuleStackMeta: - """The module stack meta data associated with this node.""" - ... - - @property - @abc.abstractmethod - def stack_trace(self) -> str | None: - """The stack trace associated with this node.""" - ... - - -class _ModuleNode(_IRNode): - """Representing a sequence of fx.Nodes to be formed into a fx.GraphModule. - - This class encapsulates metadata and provides building block methods to construct this - layered abstraction from a sequence of flat fx.Nodes. - - Attributes: - - _stack_meta: Metadata of the module stack. - - _nodes: List of IR nodes in the module. - - _reference_root_module: Reference to the root flat fx.GraphModule instance. - """ - - def __init__( - self, reference_root_module: torch.fx.GraphModule, stack_meta: _ModuleStackMeta - ): - self._stack_meta = stack_meta - self._nodes: list[_IRNode] = [] - self._reference_module = reference_root_module - - @property - def stack_meta(self) -> _ModuleStackMeta: - return self._stack_meta - - @property - def stack_trace(self) -> str | None: - assert self._nodes - return self._nodes[0].stack_trace - - def __str__(self) -> str: - return f"ModuleNode({self._stack_meta})" - - def is_same_module_as(self, node: _IRNode) -> bool: - """Determines if the provided node pertains to the same module as this node.""" - return self.stack_meta == node.stack_meta - - def is_parent_module_of(self, node: _IRNode) -> bool: - """Determines if this node represents a parent module of the provided node.""" - return node.stack_meta.is_superset_of(self.stack_meta) - - def add_leaf_node(self, leaf_node: _LeafNode) -> None: - """Adds a leaf node to the module. - - The leaf node must belong to the same or a child module. This method will recursively - construct _ModuleNode instance based on the stack_meta information of the leaf node. - """ - if self.is_same_module_as(leaf_node) or leaf_node.fx_op == "call_module": - self._nodes.append(leaf_node) - elif leaf_node.fx_op == "placeholder": - # Although the original placeholder has empty nn_module_stack, the placeholder lifted - # from get_attr nodes by exported program has their original nn_module_stack. Here - # we need to avoid them building submodule. - self._nodes.append(leaf_node) - elif self.is_parent_module_of(leaf_node): - # This node belongs in a submodule. - # Check if the last node is a submodule and if it is the parent of this node. - last_node = self._nodes[-1] if self._nodes else None - if isinstance(last_node, _ModuleNode) and ( - last_node.is_parent_module_of(leaf_node) - or last_node.is_same_module_as(leaf_node) - ): - # This node belongs to the last_node. - last_node.add_leaf_node(leaf_node) - else: - # Create a new SubmoduleNode for the immediate child module of the current - # module. The leaf node may be a grandchild of the current module. - # Example: - # self.stack_meta = [A, B, C] - # leaf_node.stack_meta = [A, B, C, D, E, F] - # Create a new ModuleNode with stack_meta = [A, B, C, D] and add leaf_node to it. - stack_meta = copy.deepcopy(self.stack_meta) - stack_meta.push(leaf_node.stack_meta[len(self.stack_meta)]) - last_node = _ModuleNode( - self._reference_module, - stack_meta, - ) - self._nodes.append(last_node) - last_node.add_leaf_node(leaf_node) - else: - raise AssertionError( - f"Node {leaf_node} ({leaf_node.stack_meta}) does not belong to module " - f"{self._stack_meta}." - ) - - def fx_nodes(self) -> Generator[torch.fx.Node, None, None]: - """Returns an iterator for the sequence of fx nodes this instance holds.""" - for node in self._nodes: - if isinstance(node, _ModuleNode): - yield from node.fx_nodes() - else: - assert isinstance(node, _LeafNode) - yield node.fx_node - - def module_inputs(self) -> Sequence[torch.fx.Node]: - """Extract module inputs from the sequence of fx nodes this instance holds. - - All node args that are produced by nodes outside of the module are considered module - inputs. The order of returned module inputs is the same as the their use order. - - ### Known limitations - - The original ordering of module inputs is not preserved. There is no meta information - to be found from the `fx.GraphModule` that can be used to recover the original ordering. - - Returns: - Sequence of module inputs. - """ - nodes = list(self.fx_nodes()) - assert len(nodes) > 0, "Cannot extract module inputs from empty nodes." - module_inputs: dict[torch.fx.Node, None] = {} - node_set: set[torch.fx.Node] = set(nodes) - - def _extract_arg_if_node_outside_module(arg: Any): - if isinstance(arg, torch.fx.Node) and arg not in node_set: - module_inputs[arg] = None - - for node in nodes: - pytree.tree_map(_extract_arg_if_node_outside_module, node.args) - pytree.tree_map(_extract_arg_if_node_outside_module, node.kwargs) - return list(module_inputs.keys()) - - def module_outputs(self) -> Sequence[torch.fx.Node]: - """Extract module outputs from the sequence of fx nodes this instance holds. - - All nodes that are used by nodes outside of the module are considered module - outputs. The order of returned module outputs is the same as the their creation order. - - ### Known limitations - - The original ordering of module outputs is not preserved. There is no meta information - to be found from the `fx.GraphModule` that can be used to recover the original ordering. - - Returns: - Sequence of module outputs. - """ - nodes = list(self.fx_nodes()) - assert len(nodes) > 0, "Cannot extract module inputs from empty nodes." - # Need ordered set. Emulate with dict. - module_outputs: dict[torch.fx.Node, None] = {} - node_set: set[torch.fx.Node] = set(nodes) - - for node in nodes: - if any(user not in node_set for user in node.users): - module_outputs[node] = None - return list(module_outputs.keys()) - - def build_module(self, module_names: dict[str, int]) -> torch.fx.GraphModule: - """ - Constructs the fx.GraphModule for this node, registering submodules as necessary. - - Args: - module_names: A dictionary of module names and their counts. This is used to - generate unique module names for submodules. This should be an empty - dictionary when the method is called on a root module. - """ - module_class_name = self._stack_meta.qualified_module_class_name - fx_graph = torch.fx.Graph() - copy_env: dict[torch.fx.Node, torch.fx.Node] = {} - - def _arg_transform(node: torch.fx.Node) -> torch.fx.Node: - return copy_env[node] - - ref_inputs = self.module_inputs() - for node in ref_inputs: - copy_env[node] = fx_graph.placeholder(node.name, node.type) - copy_env[node].meta = copy.copy(node.meta) - - for ir_node in self._nodes: - if isinstance(ir_node, _LeafNode): - fx_node = ir_node.fx_node - copy_env[fx_node] = fx_graph.node_copy( - fx_node, arg_transform=_arg_transform - ) - continue - - assert isinstance(ir_node, _ModuleNode) - # Create fx.GraphModule for child submodule. - submodule = ir_node.build_module(module_names) - ref_submodule_inputs = ir_node.module_inputs() - ref_submodule_outputs = ir_node.module_outputs() - unique_submodule_name = _get_unique_module_name( - module_names, ir_node.stack_meta.module_display_name - ) - # Link the newly generated sub fx.GraphModule with the root reference module. - # This step is essential to meet the needs of the subsequent fx.GraphModule initialization - # for the fx.GraphModule being created by this method. - # The initialization of fx.GraphModule will replicate all necessary attributes from a reference - # fx.GraphModule for the fx.Graph. While the root reference module possesses all - # parameters and buffers, it does not include the newly created sub fx.GraphModule. - # Therefore, it's necessary to register it under the root reference at this stage. - self._reference_module.add_submodule(unique_submodule_name, submodule) - - # create call_module fx.Node - submodule_node = fx_graph.call_module( - unique_submodule_name, - tuple(_arg_transform(node) for node in ref_submodule_inputs), - ) - if len(ref_submodule_outputs) > 1: - # Module node has multiple output. Create 'getitem' node for each output. - submodule_node.meta["val"] = tuple( - ref_output.meta.get("val") for ref_output in ref_submodule_outputs - ) - for i, ref_output in enumerate(ref_submodule_outputs): - getitem_node = fx_graph.call_function( - operator.getitem, - args=(submodule_node, i), - type_expr=ref_output.type, - ) - getitem_node.meta = copy.copy(ref_output.meta) - # Make a copy for "nn_module_stack" since the current module will be - # popped from the stack for this 'getitem' node. - getitem_node.meta["nn_module_stack"] = copy.copy( - ref_output.meta["nn_module_stack"] - ) - # The node is associated with the parent module. - getitem_node.meta["nn_module_stack"].popitem() - copy_env[ref_output] = getitem_node - else: - # Module node has single output. Use module node directly. - copy_env[ref_submodule_outputs[0]] = submodule_node - submodule_node.meta = copy.copy(ref_submodule_outputs[0].meta) - - # Update meta for new call_module node. - if (stack_trace := ir_node.stack_trace) is not None: - submodule_node.meta["stack_trace"] = stack_trace - raw_module_stack_meta = ir_node.stack_meta.raw_meta - assert raw_module_stack_meta is not None - submodule_node.meta["nn_module_stack"] = copy.copy(raw_module_stack_meta) - # The node is associated with the parent module. - submodule_node.meta["nn_module_stack"].popitem() - - new_nodes = fx_graph.nodes - # Skip if the last node is already 'output'. This is the case for root module. - # Otherwise create an 'output' node for the inferred outputs. - if next(iter(reversed(new_nodes))).op != "output": - ref_submodule_outputs = self.module_outputs() - new_outputs = [copy_env[ref_output] for ref_output in self.module_outputs()] - node = fx_graph.output( - new_outputs[0] if len(new_outputs) == 1 else new_outputs - ) - - graph_module = torch.fx.GraphModule( - self._reference_module, fx_graph, module_class_name - ) - if (module_class := self._stack_meta.module_class) is not None: - graph_module.meta["onnx"] = _pass.GraphModuleOnnxMeta( - _pass.PackageInfo.from_python_class(module_class) - ) - return graph_module - - -class _LeafNode(_IRNode): - """Representing a single fx.Node.""" - - def __init__(self, node: torch.fx.Node, is_exported_program: bool = False): - self._node = node - self._stack_meta = _module_stack_meta_from_node( - node, is_exported_program=is_exported_program - ) - - @property - def fx_op(self) -> str: - """Syntax sugar for self.fx_node.op.""" - return self._node.op - - @property - def fx_node(self) -> torch.fx.Node: - """Returns the fx.Node this instance represents.""" - return self._node - - @property - def stack_meta(self) -> _ModuleStackMeta: - """Returns the module stack meta data associated with this node.""" - return self._stack_meta - - @property - def stack_trace(self) -> str | None: - """Returns the stack trace associated with this node.""" - return self.fx_node.meta.get("stack_trace") - - def __str__(self) -> str: - return f"LeafNode({self._node})" - - -class Modularize(_pass.Transform): - """Transforms a flattened `fx.GraphModule` into a modular structure. - - In the flattened `fx.GraphModule`, each `nn.Module` forward call has been traced as - a sequence of `fx.Node`s. All these `fx.Node`s are flattened and reside in the same - `fx.GraphModule`. `fx.GraphModule` could be from `torch.export.ExportedProgram` or - directly generated by `torch._dynamo.export` with torch.nn.Module. - - This pass generates a new `fx.GraphModule`. It groups the flattened `fx.Node`s that belong - to the same `nn.Module` forward call into a sub `fx.GraphModule`. It then replaces the - sequence of flattened `fx.Node`s with a single `call_module` node, which is linked with - the sub `fx.GraphModule` by `node.target`. The sub `fx.GraphModule` is registered as a - submodule of the new `fx.GraphModule`. - - The process is done based on information from the `nn_module_stack` metadata of each node, i.e. - `node.meta["nn_module_stack"]`. For more implementation details, see [NOTE: Modularize Pass Implementation]. - - An fx submodule under this context can typically be interpreted in three different ways: - - 1. As an embodiment of an nn.Module class, which is considered stateless. - Its execution path can vary depending on the configuration of module initialization, - which should also be part of the inputs. - - 2. As a representation of an nn.Module instance. It maintains the state initialized in the module. - The execution path can vary based on actual input data. - - 3. As a captured call of an nn.Module instance, where the execution path - is set. - - The generality decreases along this list. Within the scope of this function, the pass - creates fx submodules according to the third interpretation. - - The first interpretation is the most general case. It requires complex analysis and additional - metadata and code information to construct its general form. Consider an example nn.Module - that generates arbitrary submodules based on an initialization configuration file. It's impractical - to extract this logic for the generated fx submodule to function with arbitrary configuration. - - The second interpretation demands less analysis and is sturdier than the - first. In most use cases, it's equivalent to the third. It only differs in exceptional situations - where a complex nn.Module instance is called multiple times, each with a different set of inputs - leading to a unique execution branching path. - - The third interpretation is the most specific scenario. It necessitates the minimum - analysis and creates the most stable representation. The drawback is that it - generates more redundancy than the other two methods. If needed, a subsequent post-processing - pass can be applied to consolidate completely identical functions and reduce duplication. - - ### Known constraints - Two successive calls to the same module instance will be conflated. They are indistinguishable. - This is due to limitations of the current fx metadata "nn_module_stack". - - [NOTE: Modularize pass ordering] - This pass groups fx nodes into subgraphs that reside within the `call_module` fx node. - Other fx passes (including some outside the exporter) might not recognize `call_module`. - They may assume that all nodes are flattened. Hence it is recommended to invoke this pass - as the last pre onnx export fx pass. If not for this consideration, this operation could - potentially be relocated anywhere earlier in the pipeline. - - Example: - - >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX) - >>> import torch - >>> from torch.onnx._internal.fx import passes - >>> - >>> class CustomModule(torch.nn.Module): - >>> def __init__(self) -> None: - >>> super().__init__() - >>> self.embedding = torch.nn.Embedding(10, 32) - >>> self.relu = torch.nn.ReLU() - >>> - >>> def forward(self, x): - >>> out = self.embedding(x) - >>> out = self.relu(out) - >>> return out - >>> - >>> class TestModule(torch.nn.Module): - >>> def __init__(self) -> None: - >>> super().__init__() - >>> self.layer = CustomModule() - >>> self.linear = torch.nn.Linear(32, 10) - >>> - >>> def forward(self, x): - >>> out = self.layer(x) - >>> out = self.linear(out) - >>> return out - >>> - >>> gm, _ = torch._dynamo.export(TestModule(), aten_graph=True)( - ... torch.tensor([0, 1, 2]) - ... ) - >>> gm.print_readable() - - >>> gm = passes.Modularize( - ... gm, - ... ).run() - >>> gm.print_readable() - - """ - - def __init__( - self, - module: torch.fx.GraphModule, - is_exported_program: bool = False, - ): - super().__init__(module) - self.module = module - self.is_exported_program = is_exported_program - - def _run(self) -> torch.fx.GraphModule: - # DCE to remove unused nodes. - # If a submodule is unused, it is hard to analyze which nodes constitutes the submodule - # outputs. But since it is unused, we can just remove it. - self.module.graph.eliminate_dead_code() - - reference_module = torch.fx.GraphModule(self.module, self.module.graph) - root_module_node = _ModuleNode( - reference_module, - _ModuleStackMeta( - nn_module_stack_meta=None, is_exported_program=self.is_exported_program - ), - ) - for fx_node in self.module.graph.nodes: - root_module_node.add_leaf_node( - _LeafNode(fx_node, is_exported_program=self.is_exported_program) - ) - return root_module_node.build_module({}) From a2ad16be72bf989c01b96c4c56b1c108a71c087f Mon Sep 17 00:00:00 2001 From: Ti-Tai Wang Date: Tue, 15 Jul 2025 00:00:06 +0000 Subject: [PATCH 033/457] [ONNX] Remove legacy Dort tests (#158294) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158294 Approved by: https://github.com/justinchuby ghstack dependencies: #158255, #158256, #158257 --- .../test_dynamo_with_onnxruntime_backend.py | 849 ------------------ 1 file changed, 849 deletions(-) delete mode 100644 test/onnx/dynamo/test_dynamo_with_onnxruntime_backend.py diff --git a/test/onnx/dynamo/test_dynamo_with_onnxruntime_backend.py b/test/onnx/dynamo/test_dynamo_with_onnxruntime_backend.py deleted file mode 100644 index 2e47e48f140eb..0000000000000 --- a/test/onnx/dynamo/test_dynamo_with_onnxruntime_backend.py +++ /dev/null @@ -1,849 +0,0 @@ -# Owner(s): ["module: onnx"] -from __future__ import annotations - -import contextlib -import copy -import dataclasses -import os -import sys -import unittest -from pathlib import Path - -import onnxruntime -from parameterized import parameterized - -import torch -import torch._dynamo.backends.registry -from torch import nn -from torch.onnx import ( - _OrtBackend as OrtBackend, - _OrtBackendOptions as OrtBackendOptions, -) -from torch.testing._internal import common_utils -from torch.testing._internal.common_utils import skipIfNNModuleInlined - - -sys.path.append(str(Path(__file__).absolute().parents[1])) - -import onnx_test_common - - -def make_aot_ort(): - ort_backend = OrtBackend(options=OrtBackendOptions()) - return ort_backend, ort_backend - - -class TestDynamoWithONNXRuntime(onnx_test_common._TestONNXRuntime): - def setUp(self): - super().setUp() - torch._dynamo.reset() - OrtBackend.clear_cached_instances() - - def tearDown(self): - super().tearDown() - torch._dynamo.reset() - OrtBackend.clear_cached_instances() - - def test_get_ort_device_type(self): - from onnxruntime.capi import _pybind_state as ORTC - - self.assertEqual( - torch.onnx._internal.onnxruntime._get_ort_device_type("cuda"), - ORTC.OrtDevice.cuda(), - ) - self.assertEqual( - torch.onnx._internal.onnxruntime._get_ort_device_type("cpu"), - ORTC.OrtDevice.cpu(), - ) - self.assertEqual( - torch.onnx._internal.onnxruntime._get_ort_device_type("maia"), - ORTC.OrtDevice.npu(), - ) - - def test_torch_compile_backend_registration(self): - self.assertIn("onnxrt", torch._dynamo.backends.registry.list_backends()) - backend = torch._dynamo.backends.registry.lookup_backend("onnxrt") - self.assertEqual(backend.__module__, "torch.onnx._internal.onnxruntime") - - def _test_torch_compile_backend_caching_assert_reused( - self, options: OrtBackendOptions - ): - self.assertFalse(OrtBackend.get_cached_instances()) # assert setUp/tearDown - new_backend = OrtBackend.get_cached_instance_for_options(options) - reused_backend = OrtBackend.get_cached_instance_for_options(options) - self.assertEqual(len(OrtBackend.get_cached_instances()), 1) - self.assertIs(reused_backend, new_backend) - if options is None or options.ort_session_options is None: - # OrtBackendOptions.ort_session_options is a pybind11 object that - # cannot be pickled via dataclasses.asdict - self.assertEqual( - new_backend, - OrtBackend.get_cached_instance_for_options( - dataclasses.asdict(options) if options else None - ), - ) - - @parameterized.expand( - [ - (None,), - (OrtBackendOptions(),), - (OrtBackendOptions(use_aot_autograd=True),), - (OrtBackendOptions(use_aot_autograd=False),), - (OrtBackendOptions(preallocate_output=True),), - (OrtBackendOptions(preallocate_output=False),), - (OrtBackendOptions(infer_execution_providers=True),), - (OrtBackendOptions(infer_execution_providers=False),), - (OrtBackendOptions(preferred_execution_providers=["A", "B", "C"]),), - ( - OrtBackendOptions( - preferred_execution_providers=["A", "B", ("C", {"option": "value"})] - ), - ), - (OrtBackendOptions(default_execution_providers=["Something"]),), - (OrtBackendOptions(),), - ] - ) - def test_torch_compile_backend_caching_assert_reused( - self, options: OrtBackendOptions - ): - self._test_torch_compile_backend_caching_assert_reused(options) - - @parameterized.expand( - [ - (OrtBackendOptions(ort_session_options=onnxruntime.SessionOptions()),), - ] - ) - def test_torch_compile_backend_caching_assert_not_reused( - self, options: OrtBackendOptions - ): - with self.assertRaises(AssertionError): - self._test_torch_compile_backend_caching_assert_reused(options) - - def _test_model_numerically( - self, - model, - dynamo_backend, - example_args_collection, - fullgraph: bool = False, - test_backward: bool = False, - atol: float = 1e-5, - rtol: float = 1e-6, - ): - """Run original and compiled model and compare the results. - - Args: - model: The model to test. - dynamo_backend: The dynamo backend to use. Here we use string `onnxrt` or - the first returned value of `make_aot_ort()`. - example_args_collection: A tuple of example arguments to test. E.g., - ( - (torch.randn(2), torch.randn(2)), - (torch.randn(4), torch.randn(4)), - ) - if you want to test - model(torch.randn(2), torch.randn(2)) and - model(torch.randn(4), torch.randn(4)) - . - """ - compiled_model = torch.compile( - model if not isinstance(model, torch.nn.Module) else copy.deepcopy(model), - backend=dynamo_backend, - dynamic=True, - fullgraph=fullgraph, - ) - - for example_args in example_args_collection: - baseline_result = model(*example_args) - result = compiled_model(*example_args) - if isinstance(baseline_result, torch.Tensor): - torch.testing.assert_close( - baseline_result, result, atol=atol, rtol=rtol - ) - if test_backward: - baseline_result.sum().backward() - result.sum().backward() - for baseline_param, param in zip( - model.parameters(), compiled_model.parameters() - ): - torch.testing.assert_close( - baseline_param.grad, param.grad, atol=atol, rtol=rtol - ) - else: - assert test_backward is False, ( - "Calculating backward with multiple outputs is not supported yet." - ) - for baseline_elem, result_elem in zip(baseline_result, result): - torch.testing.assert_close( - baseline_elem, result_elem, atol=atol, rtol=rtol - ) - - def _assert_counting_information( - self, - ort_backend: OrtBackend, - # Number of session runs. - # If there is no graph break, this should be the same as - # total number of forward calls. - expected_execution_count: int, - # Number of GraphModule's cached. - # With one graph break, a model will be mapped - # to two GraphModule's. - number_of_cached_graph_modules: int, - # Number of ONNX models cached for each GraphModule, - # number_of_exported_onnx_models[i] contains # of ONNX models exported from - # the i-th element (type: torch.fx.GraphModule) in - # OrtBackend._all_ort_execution_info.execution_info_per_graph_module.values(). - number_of_exported_onnx_models_for_all_graph_modules: tuple[int, ...], - ): - self.assertEqual(expected_execution_count, ort_backend.execution_count) - self.assertEqual( - len(ort_backend._all_ort_execution_info.execution_info_per_graph_module), - number_of_cached_graph_modules, - ) - self.assertEqual( - len(ort_backend._all_ort_execution_info.execution_info_per_graph_module), - len(number_of_exported_onnx_models_for_all_graph_modules), - ) - for ( - onnx_info, - expected_number_of_onnx_models, - ) in zip( - ort_backend._all_ort_execution_info.execution_info_per_graph_module.values(), - number_of_exported_onnx_models_for_all_graph_modules, - ): - self.assertEqual(len(onnx_info), expected_number_of_onnx_models) - - def _assert_dynamic_input_and_output_shapes_in_all_onnx_models(self, backend): - for ( - onnx_session_infos - ) in backend._all_ort_execution_info.execution_info_per_graph_module.values(): - for onnx_session_info in onnx_session_infos: - inputs_have_dynamic_shapes = False - for input in onnx_session_info.input_value_infos: - if hasattr(input.type, "tensor_type") and hasattr( - input.type.tensor_type, "shape" - ): - for dim in input.type.tensor_type.shape.dim: - inputs_have_dynamic_shapes = ( - inputs_have_dynamic_shapes or hasattr(dim, "dim_param") - ) - output_have_dynamic_shapes = False - for output in onnx_session_info.output_value_infos: - if hasattr(output.type, "tensor_type") and hasattr( - output.type.tensor_type, "shape" - ): - for dim in output.type.tensor_type.shape.dim: - output_have_dynamic_shapes = ( - output_have_dynamic_shapes or hasattr(dim, "dim_param") - ) - self.assertTrue(inputs_have_dynamic_shapes) - self.assertTrue(output_have_dynamic_shapes) - - @parameterized.expand( - [ - (True,), - (False,), - ] - ) - def test_elementwise_function_single_output(self, test_local_backend: bool): - example_args_collection = tuple( - (torch.randn(batch, dtype=torch.float32),) for batch in (2, 4, 6, 8, 10) - ) - - def elementwise_model(x: torch.Tensor): - y = x.relu() - z = y.sigmoid() - return z - - if test_local_backend: - local_aot_ort, local_ort = make_aot_ort() - else: - # This will use the global ONNXRuntime backend registered - # in Dynamo to compile the tested model. - local_aot_ort, local_ort = "onnxrt", None - - self._test_model_numerically( - elementwise_model, - local_aot_ort, - example_args_collection, - ) - - # We can only check local backend's counting information - # since global backend's counting information comes from - # all compiled models. - if test_local_backend: - assert local_ort is not None - self._assert_counting_information( - local_ort, - # OrtBackend._ort_acclerated_call should have been called 5 times because - # we have 5 different batch sizes to test. - expected_execution_count=len(example_args_collection), - # Since this local_ort only compiled one function, - # there should be only one GraphModule in its cached. - number_of_cached_graph_modules=1, - # Since dynamic shape is enabled, we should only have one ONNX model - # to support different batch sizes. - number_of_exported_onnx_models_for_all_graph_modules=(1,), - ) - - @parameterized.expand( - [ - (True,), - (False,), - ] - ) - def test_elementwise_function_multiple_output(self, test_local_backend: bool): - example_args_collection = tuple( - (torch.randn(batch, dtype=torch.float32),) for batch in (2, 4, 8) - ) - - def elementwise_model_with_multiple_outputs(w: torch.Tensor): - x = w + w - y = x.relu() - z = y * y - return x, y, z - - if test_local_backend: - local_aot_ort, local_ort = make_aot_ort() - else: - local_aot_ort, local_ort = "onnxrt", None - - self._test_model_numerically( - elementwise_model_with_multiple_outputs, - local_aot_ort, - example_args_collection, - ) - - if test_local_backend: - assert local_ort is not None - self._assert_counting_information( - local_ort, - expected_execution_count=len(example_args_collection), - number_of_cached_graph_modules=1, - number_of_exported_onnx_models_for_all_graph_modules=(1,), - ) - - @parameterized.expand( - [ - (True,), - (False,), - ] - ) - def test_mlp_with_local_backend(self, test_local_backend: bool): - example_args_collection = tuple( - (torch.randn(batch, 2, dtype=torch.float32),) for batch in (1, 2, 4, 6, 8) - ) - - class MLP(nn.Module): - def __init__(self) -> None: - super().__init__() - self.fc1 = nn.Linear(2, 4, bias=True) - self.fc2 = nn.Linear(4, 2, bias=True) - - def forward(self, tensor_x: torch.Tensor): - tensor_x = self.fc1(tensor_x) - tensor_x = torch.sigmoid(tensor_x) - tensor_x = self.fc2(tensor_x) - tensor_x = torch.sigmoid(tensor_x) - return tensor_x - - if test_local_backend: - local_aot_ort, local_ort = make_aot_ort() - else: - local_aot_ort, local_ort = "onnxrt", None - - self._test_model_numerically( - MLP(), - local_aot_ort, - example_args_collection, - ) - - if test_local_backend: - assert local_ort is not None - self._assert_counting_information( - local_ort, - # OrtBackend._ort_acclerated_call should have been called 5 times because - # we have 5 different batch sizes to test. - expected_execution_count=len(example_args_collection), - # Since this local_ort only compiled one function, there should be only two - # GraphModule's in its cached. One for batch sizes 2, 4, 6, 8 and the other - # for batch size 1. - number_of_cached_graph_modules=2, - # Since dynamic shape is enabled, we should only have one ONNX model - # to support different batch sizes. - number_of_exported_onnx_models_for_all_graph_modules=(1, 1), - ) - - @parameterized.expand( - [ - (True, True), - (True, False), - ] - ) - @skipIfNNModuleInlined("https://github.com/pytorch/pytorch/issues/129456") - def test_llama_attention_with_local_backend( - self, test_local_backend: bool, test_backward: bool - ): - from transformers import LlamaConfig # noqa: F811 - from transformers.models.llama.modeling_llama import ( # noqa: F811 - LlamaAttention, - ) - - hidden_size = 16 - - config = LlamaConfig( - num_hidden_layers=1, - vocab_size=1024, - hidden_size=hidden_size, - intermediate_size=16, - max_position_embeddings=256, - num_attention_heads=2, - hidden_dropout_prob=0.0, - attention_dropout_prob=0.0, - ) - - class LlamaAttentionWrapper(torch.nn.Module): - def __init__(self, config): - super().__init__() - try: - # New version of LlamaAttention has layer_idx argument. - self.attention = LlamaAttention(config, layer_idx=0) - except TypeError: - # Fall back to old version of LlamaAttention. - self.attention = LlamaAttention(config) - - def forward(self, hidden_states, attention_mask, position_ids): - attn_output, _, _ = self.attention( - hidden_states, attention_mask, position_ids - ) - return attn_output - - def generate_example_inputs(batch: int, seq: int, hidden_size: int): - # shape: batch x seq x hidden_size - hidden_state = torch.randn(batch, seq, hidden_size) - # [0.0000e+00, ..., 0.0000e+00, -3.4028e+38, ...] - # shape: batch x 1 x seq x seq - attention_mask = torch.zeros(batch, 1, seq, seq, dtype=torch.float) - position_ids = torch.arange(0, seq, dtype=torch.int64) - position_ids = position_ids.unsqueeze(0).view(-1, seq) - - return hidden_state, attention_mask, position_ids - - # Reason for using multiple example argument groups: - # Export model to ONNX with one example argument group - # and test it with other example argument groups. - example_args_collection = ( - generate_example_inputs(2, 8, hidden_size), - generate_example_inputs(4, 7, hidden_size), - generate_example_inputs(9, 15, hidden_size), - ) - - if test_local_backend: - local_aot_ort, local_ort = make_aot_ort() - else: - local_aot_ort, local_ort = "onnxrt", None - - model = LlamaAttentionWrapper(config).eval() - - self._test_model_numerically( - model, - local_aot_ort, - example_args_collection, - fullgraph=True, - test_backward=test_backward, - ) - - if test_local_backend: - assert local_ort is not None - number_of_captured_graphs = 2 if test_backward else 1 - - execution_count = len(example_args_collection) * number_of_captured_graphs - self._assert_counting_information( - local_ort, - # Number of InferenceSession runs. - expected_execution_count=execution_count, - # Number of GraphModule's seen by ORT. - number_of_cached_graph_modules=number_of_captured_graphs, - # Number of InferenceSession's created per GraphModule. - number_of_exported_onnx_models_for_all_graph_modules=(1,) - * number_of_captured_graphs, - ) - self._assert_dynamic_input_and_output_shapes_in_all_onnx_models(local_ort) - - @parameterized.expand( - [ - (True, False), - (True, True), - ] - ) - @skipIfNNModuleInlined("https://github.com/pytorch/pytorch/issues/129456") - def test_llama_decoder_with_local_backend( - self, test_local_backend: bool, test_backward: bool - ): - from transformers import LlamaConfig # noqa: F811 - from transformers.models.llama.modeling_llama import ( # noqa: F811 - LlamaDecoderLayer, - ) - - hidden_size = 16 - - config = LlamaConfig( - num_hidden_layers=1, - vocab_size=1024, - hidden_size=hidden_size, - intermediate_size=16, - max_position_embeddings=256, - num_attention_heads=2, - hidden_dropout_prob=0.0, - attention_dropout_prob=0.0, - ) - - class LlamaDecoderWrapper(torch.nn.Module): - def __init__(self, config): - super().__init__() - try: - # New version of LlamaDecoderLayer has layer_idx argument. - self.decoder = LlamaDecoderLayer(config, layer_idx=0) - except TypeError: - # Fall back to old version of LlamaDecoderLayer. - self.decoder = LlamaDecoderLayer(config) - - def forward(self, hidden_states, attention_mask, position_ids): - (decoder_output,) = self.decoder( - hidden_states, attention_mask, position_ids - ) - return decoder_output - - def generate_example_inputs(batch: int, seq: int, hidden_size: int): - # shape: batch x seq x hidden_size - hidden_state = torch.randn(batch, seq, hidden_size) - # [0.0000e+00, ..., 0.0000e+00, -3.4028e+38, ...] - # shape: batch x 1 x seq x seq - attention_mask = torch.zeros(batch, 1, seq, seq, dtype=torch.float) - position_ids = torch.arange(0, seq, dtype=torch.int64) - position_ids = position_ids.unsqueeze(0).view(-1, seq) - return hidden_state, attention_mask, position_ids - - # Reason for using multiple example argument groups: - # Export model to ONNX with one example argument group - # and test it with other example argument groups. - example_args_collection = ( - generate_example_inputs(2, 8, hidden_size), - generate_example_inputs(4, 7, hidden_size), - generate_example_inputs(9, 15, hidden_size), - ) - - if test_local_backend: - local_aot_ort, local_ort = make_aot_ort() - else: - local_aot_ort, local_ort = "onnxrt", None - - model = LlamaDecoderWrapper(config).eval() - - self._test_model_numerically( - model, - local_aot_ort, - example_args_collection, - fullgraph=True, - test_backward=test_backward, - ) - - if test_local_backend: - assert local_ort is not None - number_of_captured_graphs = 2 if test_backward else 1 - - execution_count = len(example_args_collection) * number_of_captured_graphs - - self._assert_counting_information( - local_ort, - expected_execution_count=execution_count, - number_of_cached_graph_modules=number_of_captured_graphs, - number_of_exported_onnx_models_for_all_graph_modules=(1,) - * number_of_captured_graphs, - ) - self._assert_dynamic_input_and_output_shapes_in_all_onnx_models(local_ort) - - @parameterized.expand( - [ - (True, False), - (True, True), - ] - ) - @skipIfNNModuleInlined("https://github.com/pytorch/pytorch/issues/129456") - def test_llama_with_local_backend( - self, test_local_backend: bool, test_backward: bool - ): - from transformers import LlamaConfig # noqa: F811 - from transformers.models.llama.modeling_llama import LlamaModel # noqa: F811 - - config = LlamaConfig( - num_hidden_layers=1, - vocab_size=1024, - hidden_size=16, - intermediate_size=16, - max_position_embeddings=256, - num_attention_heads=2, - hidden_dropout_prob=0.0, - attention_dropout_prob=0.0, - ) - - config._attn_implementation = "eager" - - class LlamaModelWrapper(torch.nn.Module): - def __init__(self, config): - super().__init__() - self.llama = LlamaModel(config) - - def forward(self, input_ids, attention_mask, position_ids): - decoder_output = self.llama( - input_ids, attention_mask, position_ids, return_dict=False - ) - return decoder_output[0] - - def generate_example_inputs(batch: int, seq: int): - # shape: batch x seq x hidden_size - input_ids = torch.randint(0, 7, size=(batch, seq), dtype=torch.int64) - # Usually, its shape is a tensor with shape batch x seq x seq. - # However, to bypass some control flow in the model, we use None. - attention_mask = None - position_ids = torch.arange(0, seq, dtype=torch.int64) - position_ids = position_ids.unsqueeze(0).view(-1, seq) - return input_ids, attention_mask, position_ids - - # Reason for using multiple example argument groups: - # Export model to ONNX with one example argument group - # and test it with other example argument groups. - example_args_collection = ( - generate_example_inputs(2, 8), - generate_example_inputs(4, 7), - generate_example_inputs(9, 15), - ) - - if test_local_backend: - local_aot_ort, local_ort = make_aot_ort() - else: - local_aot_ort, local_ort = "onnxrt", None - - model = LlamaModelWrapper(config).eval() - - self._test_model_numerically( - model, - local_aot_ort, - example_args_collection, - fullgraph=True, - test_backward=test_backward, - atol=1e-4, - rtol=1e-4, - ) - - if test_local_backend: - assert local_ort is not None - number_of_captured_graphs = 2 if test_backward else 1 - execution_count = len(example_args_collection) * number_of_captured_graphs - self._assert_counting_information( - local_ort, - expected_execution_count=execution_count, - number_of_cached_graph_modules=number_of_captured_graphs, - number_of_exported_onnx_models_for_all_graph_modules=(1,) - * number_of_captured_graphs, - ) - self._assert_dynamic_input_and_output_shapes_in_all_onnx_models(local_ort) - - @parameterized.expand( - [ - (True,), - (False,), - ] - ) - def test_dump_model(self, test_local_backend: bool): - @contextlib.contextmanager - def onnxrt_dump_path(path): - key = "ONNXRT_DUMP_PATH" - before = os.environ.get(key, None) - os.environ[key] = path - yield - if before is None: - del os.environ[key] - else: - os.environ[key] = before - - example_args_collection = tuple( - (torch.randn(batch, 2, dtype=torch.float32),) for batch in (1, 2, 4, 6, 8) - ) - - class MLP(nn.Module): - def __init__(self) -> None: - super().__init__() - self.fc1 = nn.Linear(2, 4, bias=True) - self.fc2 = nn.Linear(4, 2, bias=True) - - def forward(self, tensor_x: torch.Tensor): - tensor_x = self.fc1(tensor_x) - tensor_x = torch.sigmoid(tensor_x) - tensor_x = self.fc2(tensor_x) - tensor_x = torch.sigmoid(tensor_x) - return tensor_x - - if test_local_backend: - local_aot_ort, _ = make_aot_ort() - else: - local_aot_ort, _ = "onnxrt", None - - prefix = f"test_dump_model_{'local' if test_local_backend else 'onnxrt'}_" - expected = f"{prefix}0.onnx" - expected_graph = f"{prefix}0.txt" - if os.path.exists(expected): - os.remove(expected) - if os.path.exists(expected_graph): - os.remove(expected_graph) - not_expected = f"{prefix}1.onnx" - self.assertFalse(os.path.exists(not_expected)) - - model = MLP() - compiled_model = torch.compile( - model if not isinstance(model, torch.nn.Module) else copy.deepcopy(model), - backend=local_aot_ort, - dynamic=True, - ) - - self.assertFalse(os.path.exists(expected)) - self.assertFalse(os.path.exists(not_expected)) - - with onnxrt_dump_path(prefix): - example_args = example_args_collection[0] - compiled_model(*example_args) - self.assertTrue(os.path.exists(expected)) - self.assertTrue(os.path.exists(expected_graph)) - self.assertFalse(os.path.exists(not_expected)) - - compiled_model(*example_args) - self.assertTrue(os.path.exists(expected)) - self.assertFalse(os.path.exists(not_expected)) - - @unittest.skipIf(not torch.cuda.is_available(), "No CUDA to run mix devicei nputs") - def test_mix_device_inputs(self): - data = torch.randn(4, 8, device="cuda") - ref_data = torch.randn(8, 4, device="cpu") - - def reshape_wrapper(data, ref_cpu_data): - # Dummy line to make sure ref_cpu_data - # is included in the captured graph. - ref_cpu_data += 1 - shape = ref_cpu_data.shape - # A call with GPU and CPU inputs. - return torch.reshape(data, shape) - - compiled_model = torch.compile( - reshape_wrapper, - backend="onnxrt", - dynamic=True, - ) - - result = compiled_model(data, ref_data) - - self.assertTrue(torch.allclose(result, data.view(ref_data.shape))) - - def test_no_input(self): - def reshape_wrapper(): - # A model without input. - ones = torch.ones(4, 8) - zeros = torch.zeros(4, 8) - return ones + zeros - - recorded_models = [] - - def record_onnx_model_transform(onnx_model): - # Record the ONNX model seen by the transform. - recorded_models.append(onnx_model) - - compiled_model = torch.compile( - reshape_wrapper, - backend="onnxrt", - dynamic=True, - options=torch.onnx._OrtBackendOptions( - pre_ort_model_transforms=[ - record_onnx_model_transform, - ] - ), - ) - - result = compiled_model() - - self.assertEqual(len(recorded_models), 1) - # NOTE: Constant folded by optimizer - self.assertTrue( - "Constant" in [node.op_type for node in recorded_models[0].graph.node] - ) - - self.assertEqual(result, torch.ones(4, 8)) - - def test_custom_onnx_transform(self): - # This test consists of 2 parts: - # 1. If a registered ONNX transform is called and recorded a model. - # 2. If a registered ONNX transform is called and changed the model - - # Part 1: Record the ONNX model seen by the transform. - # This list contains the models recorded by record_onnx_model_transform. - recorded_models = [] - - def record_onnx_model_transform(onnx_model): - # Record the ONNX model seen by the transform. - recorded_models.append(onnx_model) - - def example_model(x: torch.Tensor): - y = torch.sigmoid(x) - z = x + y - return z - - compiled_model = torch.compile( - example_model, - backend="onnxrt", - dynamic=True, - options=torch.onnx._OrtBackendOptions( - pre_ort_model_transforms=[record_onnx_model_transform] - ), - ) - - x = torch.randn(2) - assert len(recorded_models) == 0 - y = compiled_model(x) - assert len(recorded_models) == 1 - - # Part 2: Change the ONNX model seen by the transform so that - # ORT receives a different model. - # NOTE: the function is optimized away by optimizer - def replace_relu_with_sigmoid(onnx_model): - for node in onnx_model.graph.node: - if node.op_type == "Relu": - node.op_type = "Sigmoid" - - def another_example_model(x: torch.Tensor): - y = torch.relu(x) - z = x + y - return z - - another_compiled = torch.compile( - another_example_model, - backend="onnxrt", - dynamic=True, - options=torch.onnx._OrtBackendOptions( - pre_ort_model_transforms=[ - replace_relu_with_sigmoid, - record_onnx_model_transform, - ] - ), - ) - - another_y = another_compiled(x) - # We have 2 models recorded `record_onnx_model_transform` - # by the 2 torch.compile calls above. - assert len(recorded_models) == 2 - # Since we have changed "Relu" to "Sigmoid" in replace_sigmoid_with_relu, - # the result should be the same to previous y. - torch.testing.assert_close(y, another_y) - # another_example_model still uses "Relu", so the result should be different - # than y. - self.assertFalse(torch.allclose(y, another_example_model(x))) - - -if __name__ == "__main__": - common_utils.run_tests() From 4c1fabf2c9eb0c9773b09ff56761f8361fb60304 Mon Sep 17 00:00:00 2001 From: Xilun Wu <12968408+XilunWu@users.noreply.github.com> Date: Mon, 14 Jul 2025 15:58:20 -0700 Subject: [PATCH 034/457] [DTensor] have split_strategy return OpStrategy instead of TupleStrategy (#158051) **Summary** `split_strategy` used `TupleStrategy` as return type because DTensor sharding propagation's `OpStrategy` support on multi-returns only applies to `Tuple`. However, `TupleStrategy`'s not a good fit for `split` op. `TupleStrategy` was initially introduced to handle the sharding strategy of `foreach_*` ops where the input args can be split into independent subsets regarding sharding decisions, so are the outputs. To address the misuse, this PR adds `OpStrategy` propagation for `List[Tensor]` (note that this support is INCOMPLETE because it only checks the return type to be `torch.ListType`). Nevertheless, the logic for `Tuple` returns also made similar assumption so I think it's fine to unblock in such a way. Besides adding `OpStrategy` support to ops having `List[Tensor]` return type, this PR also changes `split_strategy`'s return from `TupleStrategy` to `OpStrategy`. **Test** `pytest test/distributed/tensor/test_tensor_ops.py -s -k test_split_on_partial` Pull Request resolved: https://github.com/pytorch/pytorch/pull/158051 Approved by: https://github.com/wconstab, https://github.com/zpcore --- torch/distributed/tensor/_op_schema.py | 7 ++++ torch/distributed/tensor/_ops/_tensor_ops.py | 42 +++++++++++--------- torch/distributed/tensor/_sharding_prop.py | 5 ++- 3 files changed, 34 insertions(+), 20 deletions(-) diff --git a/torch/distributed/tensor/_op_schema.py b/torch/distributed/tensor/_op_schema.py index acf15c6c0ea4b..54d85aa1b3abe 100644 --- a/torch/distributed/tensor/_op_schema.py +++ b/torch/distributed/tensor/_op_schema.py @@ -345,6 +345,13 @@ def return_type_tuple_tensor_like(self) -> bool: return_types[0].type, torch.TensorType ) + def return_type_list_tensor_like(self) -> bool: + # returns True if the return type is a List + return_types = self.op._schema.returns + return len(return_types) == 1 and isinstance( + return_types[0].type, torch.ListType + ) + def return_type_tensor(self) -> bool: return_types = self.op._schema.returns # all dispatch ops only return Tensor or Tuple[Tensor] for tensor like diff --git a/torch/distributed/tensor/_ops/_tensor_ops.py b/torch/distributed/tensor/_ops/_tensor_ops.py index a81db1a3b124e..9bdfc90d145d4 100644 --- a/torch/distributed/tensor/_ops/_tensor_ops.py +++ b/torch/distributed/tensor/_ops/_tensor_ops.py @@ -1074,7 +1074,7 @@ def place(vp: Placement, ip: Placement) -> Placement: ], RuntimeSchemaInfo(1), ) -def split_strategy(op_schema: OpSchema) -> TupleStrategy: +def split_strategy(op_schema: OpSchema) -> OpStrategy: input_strategy = op_schema.args_schema[0] split_size_or_sections = op_schema.args_schema[1] assert isinstance(input_strategy, OpStrategy) @@ -1097,23 +1097,27 @@ def size_split(N, i) -> list: ) assert isinstance(output_size_list, Sized) - split_strategies = [] - - for _ in range(len(output_size_list)): - op_strategy = OpStrategy([]) - - for strategy in input_strategy.strategies: - spec = strategy.output_spec - placements = spec.placements - if is_tensor_dim_sharded(spec, dim=dim): - # if the input is sharded on the split dim, we need to unshard it - placements = unshard_tensor_dim(spec.placements, dim=dim) - - spec = DTensorSpec(spec.mesh, placements) - - op_strategy.strategies.append( - OpSpec(output_specs=spec, input_specs=([spec])) + all_strategies = [] + for strategy in input_strategy.strategies: + spec = strategy.output_spec + placements = spec.placements + if is_tensor_dim_sharded(spec, dim=dim): + # if the input is sharded on the split dim, we need to unshard it + placements = unshard_tensor_dim(spec.placements, dim=dim) + + input_spec = DTensorSpec(spec.device_mesh, placements, spec.tensor_meta) + output_specs = tuple( + DTensorSpec(spec.device_mesh, placements) + for _ in range(len(output_size_list)) + ) + all_strategies.append( + OpSpec( + output_specs=output_specs, + input_specs=(input_spec,), + redistribute_cost=[ + generate_redistribute_costs(input_strategy, input_spec) + ], ) - split_strategies.append(op_strategy) + ) - return TupleStrategy(split_strategies) + return OpStrategy(all_strategies) diff --git a/torch/distributed/tensor/_sharding_prop.py b/torch/distributed/tensor/_sharding_prop.py index 4b1536644b877..69af19fea26ab 100644 --- a/torch/distributed/tensor/_sharding_prop.py +++ b/torch/distributed/tensor/_sharding_prop.py @@ -353,7 +353,10 @@ def propagate_op_sharding_non_cached(self, op_schema: OpSchema) -> OutputShardin for _ in range(len(op_schema.op._schema.returns)) ] ) - elif op_schema.return_type_tensor(): + elif ( + op_schema.return_type_tensor() + or op_schema.return_type_list_tensor_like() + ): output_specs = output_strategy.output_specs else: output_specs = None From add0b450bd9907ba9d089c79ca4af96c0590d8ff Mon Sep 17 00:00:00 2001 From: Xilun Wu <12968408+XilunWu@users.noreply.github.com> Date: Mon, 14 Jul 2025 15:58:20 -0700 Subject: [PATCH 035/457] [DTensor][BE] improve DTensor ops correctness check utils (#158112) **Summary** Implemented the test pattern described in https://github.com/pytorch/pytorch/pull/157991#discussion_r2196363170 as a util method in `DTensorTestBase`. The difference to `DTensorTestBase._test_op` is: 1. allowing users to specify the `Partial` placement. 2. supporting tree-like output structure. **Test** so far only adopt `DTensorTestBase._test_op_on_dtensor` in `DistTensorOpsTest.test_split_on_partial`. `pytest test/distributed/tensor/test_tensor_ops.py -s -k test_split_on_partial` Pull Request resolved: https://github.com/pytorch/pytorch/pull/158112 Approved by: https://github.com/Skylion007, https://github.com/zpcore ghstack dependencies: #158051 --- test/distributed/tensor/test_tensor_ops.py | 20 +++++--------- .../distributed/_tensor/common_dtensor.py | 27 +++++++++++++++++++ 2 files changed, 33 insertions(+), 14 deletions(-) diff --git a/test/distributed/tensor/test_tensor_ops.py b/test/distributed/tensor/test_tensor_ops.py index 3c0e65809c7c7..9be582952f367 100644 --- a/test/distributed/tensor/test_tensor_ops.py +++ b/test/distributed/tensor/test_tensor_ops.py @@ -2,7 +2,6 @@ # Owner(s): ["oncall: distributed"] import torch -import torch.distributed._functional_collectives as funcol from torch.distributed.tensor import ( DeviceMesh, distribute_tensor, @@ -725,24 +724,17 @@ def _test_split_on_partial(self, reduce_op: str, split_size: int, split_dim: int mesh = init_device_mesh(self.device_type, (self.world_size,)) partial_tensor = torch.randn(8, 8, device=self.device_type) - replicate_tensor = partial_tensor.detach().clone() - replicate_tensor = funcol.all_reduce( - replicate_tensor, reduce_op, mesh - ) # all reduce to full tensor - replicate_tensor_list = replicate_tensor.split(split_size, dim=split_dim) - partial_dt = DTensor.from_local( local_tensor=partial_tensor, device_mesh=mesh, placements=[Partial(reduce_op=reduce_op)], ) - partial_dt_list = partial_dt.split(split_size, dim=split_dim) - - replicate_dt_full_tensor_list = [dt.full_tensor() for dt in partial_dt_list] - for replicate_tensor, replicate_dt_full_tensor in zip( - replicate_tensor_list, replicate_dt_full_tensor_list - ): - self.assertEqual(replicate_tensor, replicate_dt_full_tensor) + self._test_op_on_dtensor( + torch.split, + partial_dt, + split_size, + dim=split_dim, + ) if __name__ == "__main__": diff --git a/torch/testing/_internal/distributed/_tensor/common_dtensor.py b/torch/testing/_internal/distributed/_tensor/common_dtensor.py index 8e9a9a55f6774..c922e6993af33 100644 --- a/torch/testing/_internal/distributed/_tensor/common_dtensor.py +++ b/torch/testing/_internal/distributed/_tensor/common_dtensor.py @@ -17,6 +17,7 @@ from torch.distributed.tensor import ( DeviceMesh, distribute_tensor, + DTensor, Placement, Replicate, Shard, @@ -403,6 +404,32 @@ def setUp(self) -> None: super().setUp() self._spawn_processes() + def _test_op_on_dtensor(self, op_call, *args, **kwargs) -> None: + """ + This function checks ``op_call(dtensor).full_tensor() == op_call(dtensor.full_tensor())``. + Unlike _test_op where the DTensor sharding is generated by DTensorConverter, + this function takes in DTensor object directly as argument and test the equality + of calling op on full_tensor() and DTensor. + """ + # call full_tensor() on DTensor args/kwargs + args_flattened, args_spec = tree_flatten(args) + full_tensor_args_flattened = tuple( + arg.full_tensor().detach().clone() if isinstance(arg, DTensor) else arg + for arg in args_flattened + ) + full_tensor_args = tree_unflatten(full_tensor_args_flattened, args_spec) + full_tensor_kwargs = { + k: v.full_tensor() if isinstance(v, DTensor) else v + for k, v in kwargs.items() + } + + out_flattened, _ = tree_flatten( + op_call(*full_tensor_args, **full_tensor_kwargs) + ) + d_out_flattened, _ = tree_flatten(op_call(*args, **kwargs)) + d_out_full_tensor_flattened = [dt.full_tensor() for dt in d_out_flattened] + self.assertEqual(out_flattened, d_out_full_tensor_flattened) + # pyre-ignore[2]: def _test_op(self, mesh: DeviceMesh, op_call, *args, **kwargs) -> None: out = op_call(*args, **kwargs) From 058fb1790f2c474cd4ecb5ec625eef896c554544 Mon Sep 17 00:00:00 2001 From: albanD Date: Tue, 15 Jul 2025 05:06:51 +0000 Subject: [PATCH 036/457] Fix compilation and "import torch" issues for cpython 3.14 (#158184) Beginning of process for 3.14 bringup. State of things from this PR: - Nothing too scary looking from the Dynamo CPython side, nothing we heavily rely on seems to be missing @williamwen42 - The existing check that makes torch.compile() nicely fail is working as expected. So all these empty functions shouldn't cause any weirdness. - The `__module__` update changes look suspicious, we should investigate what is the reason and impact of that, in particular for our public API checking @jbschlosser - Leaving the weakref.py thread safety change as a follow up to keep this a bit simpler. I vendored the whole struct in the meantime FYI @ezyang EDIT: The `__module__` change is even more cursed than I though due to changes to Union and Optional type where the `__module__` field cannot be changed anymore. See https://github.com/python/cpython/issues/132139 for details. For now, I'm just skipping the `__module__` setting for 3.14 which will trip the public API checks. Will revisit once I have a final answer on the cpython issue. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158184 Approved by: https://github.com/msaroufim --- torch/_dynamo/bytecode_analysis.py | 2 +- torch/ao/quantization/__init__.py | 5 +++- torch/ao/quantization/qconfig.py | 4 ++- torch/ao/quantization/utils.py | 7 +++-- torch/csrc/dynamo/cpython_defs.c | 16 +++++++++++ torch/csrc/dynamo/cpython_includes.h | 17 ++++++++++++ torch/csrc/dynamo/eval_frame.c | 34 +++++++++++++++-------- torch/csrc/dynamo/framelocals_mapping.cpp | 14 ++++++++++ torch/csrc/utils/python_compat.h | 1 + torch/onnx/__init__.py | 1 - torch/utils/weak.py | 29 +++++++++++++++++-- 11 files changed, 111 insertions(+), 19 deletions(-) diff --git a/torch/_dynamo/bytecode_analysis.py b/torch/_dynamo/bytecode_analysis.py index 3252ea91409f9..2de74ee5bf8d2 100644 --- a/torch/_dynamo/bytecode_analysis.py +++ b/torch/_dynamo/bytecode_analysis.py @@ -33,7 +33,7 @@ TERMINAL_OPCODES.add(dis.opmap["JUMP_FORWARD"]) else: TERMINAL_OPCODES.add(dis.opmap["JUMP_ABSOLUTE"]) -if sys.version_info >= (3, 12): +if (3, 12) <= sys.version_info < (3, 14): TERMINAL_OPCODES.add(dis.opmap["RETURN_CONST"]) if sys.version_info >= (3, 13): TERMINAL_OPCODES.add(dis.opmap["JUMP_BACKWARD_NO_INTERRUPT"]) diff --git a/torch/ao/quantization/__init__.py b/torch/ao/quantization/__init__.py index ffc1792fd23fa..cf5a8b99a8941 100644 --- a/torch/ao/quantization/__init__.py +++ b/torch/ao/quantization/__init__.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs +import sys from typing import Callable, Optional, Union import torch @@ -33,7 +34,9 @@ # ensure __module__ is set correctly for public APIs ObserverOrFakeQuantize = Union[ObserverBase, FakeQuantizeBase] -ObserverOrFakeQuantize.__module__ = "torch.ao.quantization" +if sys.version_info < (3, 14): + ObserverOrFakeQuantize.__module__ = "torch.ao.quantization" + for _f in [ compare_results, extract_results_from_loggers, diff --git a/torch/ao/quantization/qconfig.py b/torch/ao/quantization/qconfig.py index efee5302ad42a..d9a8fc78bab4a 100644 --- a/torch/ao/quantization/qconfig.py +++ b/torch/ao/quantization/qconfig.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs import copy +import sys import warnings from collections import namedtuple from typing import Any, Optional, Union @@ -568,7 +569,8 @@ def _assert_valid_qconfig(qconfig: Optional[QConfig], mod: torch.nn.Module) -> N QConfigAny = Optional[QConfig] -QConfigAny.__module__ = "torch.ao.quantization.qconfig" +if sys.version_info < (3, 14): + QConfigAny.__module__ = "torch.ao.quantization.qconfig" def _add_module_to_qconfig_obs_ctr( diff --git a/torch/ao/quantization/utils.py b/torch/ao/quantization/utils.py index feae45df3b863..a80ae1d8e3de1 100644 --- a/torch/ao/quantization/utils.py +++ b/torch/ao/quantization/utils.py @@ -4,6 +4,7 @@ """ import functools +import sys import warnings from collections import OrderedDict from inspect import getfullargspec, signature @@ -16,7 +17,8 @@ NodePattern = Union[tuple[Node, Node], tuple[Node, tuple[Node, Node]], Any] -NodePattern.__module__ = "torch.ao.quantization.utils" +if sys.version_info < (3, 14): + NodePattern.__module__ = "torch.ao.quantization.utils" # This is the Quantizer class instance from torch/quantization/fx/quantize.py. # Define separately to prevent circular imports. @@ -31,7 +33,8 @@ Pattern = Union[ Callable, tuple[Callable, Callable], tuple[Callable, tuple[Callable, Callable]], Any ] -Pattern.__module__ = "torch.ao.quantization.utils" +if sys.version_info < (3, 14): + Pattern.__module__ = "torch.ao.quantization.utils" # TODO: maybe rename this to MatchInputNode diff --git a/torch/csrc/dynamo/cpython_defs.c b/torch/csrc/dynamo/cpython_defs.c index b68ef894aeaa2..244d4165d5e87 100644 --- a/torch/csrc/dynamo/cpython_defs.c +++ b/torch/csrc/dynamo/cpython_defs.c @@ -2,6 +2,20 @@ #include #include +#if IS_PYTHON_3_14_PLUS + +const uint8_t* THP_PyOpcode_Caches = NULL; +const int THP_PyOpcode_Caches_size = 0; + +void +THP_PyThreadState_PopFrame(PyThreadState *tstate, _PyInterpreterFrame * frame) +{} +void +THP_PyFrame_Clear(_PyInterpreterFrame *frame) +{} + +#else + #if IS_PYTHON_3_11_PLUS #define Py_BUILD_CORE @@ -360,3 +374,5 @@ const uint8_t* THP_PyOpcode_Caches = NULL; const int THP_PyOpcode_Caches_size = 0; #endif + +#endif // IS_PYTHON_3_14_PLUS \ No newline at end of file diff --git a/torch/csrc/dynamo/cpython_includes.h b/torch/csrc/dynamo/cpython_includes.h index 6b99c1d5aec8e..616be16563cfa 100644 --- a/torch/csrc/dynamo/cpython_includes.h +++ b/torch/csrc/dynamo/cpython_includes.h @@ -21,6 +21,14 @@ #if IS_PYTHON_3_11_PLUS #include +#if IS_PYTHON_3_14_PLUS +#include +#include +#endif +#endif + +#if IS_PYTHON_3_14_PLUS +#include #endif #undef Py_BUILD_CORE @@ -30,6 +38,13 @@ extern "C" { #endif +#if IS_PYTHON_3_14_PLUS + +#define F_CODE(x) (PyCodeObject*)PyStackRef_AsPyObjectBorrow(x->f_executable) +#define PREV_INSTR(x) (x)->instr_ptr + +#else + #if IS_PYTHON_3_13_PLUS #define F_CODE(x) ((PyCodeObject*)(x)->f_executable) #define PREV_INSTR(x) (x)->instr_ptr @@ -38,6 +53,8 @@ extern "C" { #define PREV_INSTR(x) (x)->prev_instr #endif +#endif // IS_PYTHON_3_14_PLUS + #if IS_PYTHON_3_12_PLUS #define FUNC(x) ((x)->f_funcobj) #else diff --git a/torch/csrc/dynamo/eval_frame.c b/torch/csrc/dynamo/eval_frame.c index f413782b2d301..72bb8839bac35 100644 --- a/torch/csrc/dynamo/eval_frame.c +++ b/torch/csrc/dynamo/eval_frame.c @@ -224,17 +224,6 @@ const char* get_frame_name(THP_EVAL_API_FRAME_OBJECT* frame) { return PyUnicode_AsUTF8(F_CODE(frame)->co_name); } -void clear_old_frame_if_python_312_plus( - PyThreadState* tstate, - THP_EVAL_API_FRAME_OBJECT* frame) { -#if IS_PYTHON_3_12_PLUS - - THP_PyFrame_Clear(frame); - THP_PyThreadState_PopFrame(tstate, frame); - -#endif -} - static PyObject* dynamo_eval_custom_code_impl( PyThreadState* tstate, THP_EVAL_API_FRAME_OBJECT* frame, @@ -485,6 +474,18 @@ static PyObject* dynamo__custom_eval_frame_shim( static void enable_eval_frame_shim(PyThreadState* tstate) {} static void enable_eval_frame_default(PyThreadState* tstate) {} +PyObject* dynamo_eval_custom_code( + PyThreadState* tstate, + THP_EVAL_API_FRAME_OBJECT* frame, + PyCodeObject* code, + const char* trace_annotation, + int throw_flag) {} +THPPyInterpreterFrame* THPPyInterpreterFrame_New( + THP_EVAL_API_FRAME_OBJECT* frame) {} +PyObject* dynamo_eval_frame_default( + PyThreadState* tstate, + THP_EVAL_API_FRAME_OBJECT* frame, + int throw_flag) {} static struct PyGetSetDef THPPyInterpreterFrame_properties[] = {NULL}; @@ -498,6 +499,17 @@ static PyTypeObject THPPyInterpreterFrameType = { #endif // !(IS_PYTHON_3_14_PLUS) +void clear_old_frame_if_python_312_plus( + PyThreadState* tstate, + THP_EVAL_API_FRAME_OBJECT* frame) { +#if IS_PYTHON_3_12_PLUS + + THP_PyFrame_Clear(frame); + THP_PyThreadState_PopFrame(tstate, frame); + +#endif +} + static PyObject* increment_working_threads( PyThreadState* tstate, PyObject* module) { diff --git a/torch/csrc/dynamo/framelocals_mapping.cpp b/torch/csrc/dynamo/framelocals_mapping.cpp index b839fb26fc91a..c4ee36d87767b 100644 --- a/torch/csrc/dynamo/framelocals_mapping.cpp +++ b/torch/csrc/dynamo/framelocals_mapping.cpp @@ -26,9 +26,13 @@ FrameLocalsMapping::FrameLocalsMapping(FrameLocalsFrameType* frame) PyCodeObject* co = F_CODE(frame); _framelocals.resize(co->co_nlocalsplus, nullptr); +#if IS_PYTHON_3_14_PLUS + TORCH_CHECK(false, "Python 3.14+ not supported"); +#else if (!frame->stacktop) { return; } +#endif auto update_framelocals = [&](int i, PyObject* value) { _PyLocals_Kind kind = _PyLocals_GetKind(co->co_localspluskinds, i); @@ -53,11 +57,21 @@ FrameLocalsMapping::FrameLocalsMapping(FrameLocalsFrameType* frame) }; auto offset = co->co_nlocalsplus - co->co_nfreevars; +#if IS_PYTHON_3_14_PLUS + TORCH_CHECK(false, "Python 3.14+ not supported"); +#else for (int i = 0; i < offset; i++) { update_framelocals(i, frame->localsplus[i]); } +#endif + // Get references to closure variables +#if IS_PYTHON_3_14_PLUS + PyObject* closure; + TORCH_CHECK(false, "Python 3.14+ not supported"); +#else PyObject* closure = ((PyFunctionObject*)FUNC(frame))->func_closure; +#endif for (int i = 0; i < co->co_nfreevars; i++) { update_framelocals(offset + i, PyTuple_GET_ITEM(closure, i)); } diff --git a/torch/csrc/utils/python_compat.h b/torch/csrc/utils/python_compat.h index a1537611cc47f..16292e4fd0308 100644 --- a/torch/csrc/utils/python_compat.h +++ b/torch/csrc/utils/python_compat.h @@ -13,6 +13,7 @@ extern "C" { #define IS_PYTHON_3_12_PLUS PY_VERSION_HEX >= 0x030C0000 #define IS_PYTHON_3_13_PLUS PY_VERSION_HEX >= 0x030D0000 #define IS_PYTHON_3_14_PLUS PY_VERSION_HEX >= 0x030E0000 +#define IS_PYTHON_3_15_PLUS PY_VERSION_HEX >= 0x030F0000 static inline int PyCode_GetNCellvars(PyCodeObject* code) { // gh-26364 added co_ncellvars to Python 3.11.0rc1 diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py index 410b34b042cf0..7db778ef08e60 100644 --- a/torch/onnx/__init__.py +++ b/torch/onnx/__init__.py @@ -100,7 +100,6 @@ OnnxExporterError.__module__ = "torch.onnx" _OrtBackend.__module__ = "torch.onnx" _OrtBackendOptions.__module__ = "torch.onnx" -_OrtExecutionProvider.__module__ = "torch.onnx" enable_fake_mode.__module__ = "torch.onnx" is_onnxrt_backend_supported.__module__ = "torch.onnx" diff --git a/torch/utils/weak.py b/torch/utils/weak.py index 8bf2ba5ed02b4..9c7218cb2ad3b 100644 --- a/torch/utils/weak.py +++ b/torch/utils/weak.py @@ -3,8 +3,6 @@ import collections.abc as _collections_abc import weakref - -from _weakrefset import _IterationGuard # type: ignore[attr-defined] from collections.abc import Mapping, MutableMapping from weakref import ref @@ -22,6 +20,33 @@ ] +# TODO: make weakref properly thread safe following +# https://github.com/python/cpython/pull/125325 +class _IterationGuard: + # This context manager registers itself in the current iterators of the + # weak container, such as to delay all removals until the context manager + # exits. + # This technique should be relatively thread-safe (since sets are). + + def __init__(self, weakcontainer): + # Don't create cycles + self.weakcontainer = ref(weakcontainer) + + def __enter__(self): + w = self.weakcontainer() + if w is not None: + w._iterating.add(self) + return self + + def __exit__(self, e, t, b): + w = self.weakcontainer() + if w is not None: + s = w._iterating + s.remove(self) + if not s: + w._commit_removals() + + # This file defines a variant of WeakKeyDictionary that overrides the hashing # behavior of the key to use object identity, rather than the builtin # __eq__/__hash__ functions. This is useful for Tensor weak keys, as their From 9cd521de4dad5fc6bca94e253a9334b9a521acb0 Mon Sep 17 00:00:00 2001 From: Gabriel Ferns Date: Tue, 15 Jul 2025 05:44:33 +0000 Subject: [PATCH 037/457] Fix torchrec multiprocess tests (#158159) Summary: The new version of `get_device_tflops` imported something from testing, which imported common_utils.py, which disabled global flags. Test Plan: Fixing existing tests Rollback Plan: Differential Revision: D78192700 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158159 Approved by: https://github.com/nipung90, https://github.com/huydhn --- torch/_inductor/utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 10701d0d8b2d5..d22d67cecff21 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -2178,7 +2178,10 @@ def get_device_tflops(dtype: torch.dtype) -> float: from triton.testing import get_max_simd_tflops, get_max_tensorcore_tflops - from torch.testing._internal.common_cuda import SM80OrLater + SM80OrLater = torch.cuda.is_available() and torch.cuda.get_device_capability() >= ( + 8, + 0, + ) assert dtype in (torch.float16, torch.bfloat16, torch.float32) From 3341c131b767a4036c152624c1e43baaf24cadf9 Mon Sep 17 00:00:00 2001 From: codingwithsurya Date: Tue, 15 Jul 2025 05:57:23 +0000 Subject: [PATCH 038/457] [SymmMem] Fix NCCL Hang in NVSHMEM Triton Wait Until Test (#158167) The `test_triton_wait_until` test was hanging due to an NCCL synchronization issue stemming from mismatched NVSHMEM operations. Specifically, the flag variable was updated using `nvshmemx_signal_op` (a signaling operation), but waited on with `nvshmem_wait_until` (intended for put/get updates). Per NVSHMEM documentation (see documentation reference section below), signal-updated variables require `nvshmem_signal_wait_until` for proper completion guarantees, so the mismatch caused a deadlock and NCCL hang. **Fix:** - A simple fix was to replace the flag update with a regular `nvshmem_putmem_block` (via `put_kernel`) to match `nvshmem_wait_until`. I also added a fence (`nvshmem_fence`) between data and flag puts on the sender (Rank 1) for ordered delivery. - In a follow-up PR I will add a kernel/test to demonstrate usage of `nvshmemx_signal_op` **Testing:** - I ran `python test/distributed/test_nvshmem_triton.py` and `python test/distributed/test_nvshmem_triton.py -k test_triton_wait_until` - I also verified with debug prints (Sender completes puts/fence before receiver's wait returns, and assertions confirm correct state). Multiple runs show no hangs or failures. **Documentation Referenced:** - [NVSHMEM Point-To-Point Synchronization](https://docs.nvidia.com/nvshmem/api/gen/api/sync.html) explicitly states: *"the sig_addr object at the calling PE is expected only to be updated as a signal, through the signaling operations available in Section NVSHMEM_PUT_SIGNAL and Section NVSHMEM_PUT_SIGNAL_NBI"* - [NVIDIA's Official Ring Broadcast Example](https://docs.nvidia.com/nvshmem/api/examples.html) demonstrates the correct pairing: `nvshmemx_signal_op` with `nvshmem_signal_wait_until` (not `nvshmem_wait_until`) - [NVSHMEM Signaling Operations](https://docs.nvidia.com/nvshmem/api/gen/api/signal.html) documents that signal operations work on special "signal data objects" with specific atomicity guarantees distinct from regular RMA operations Pull Request resolved: https://github.com/pytorch/pytorch/pull/158167 Approved by: https://github.com/Skylion007, https://github.com/fduwjj --- test/distributed/test_nvshmem_triton.py | 58 +++++++++++++++---------- 1 file changed, 35 insertions(+), 23 deletions(-) diff --git a/test/distributed/test_nvshmem_triton.py b/test/distributed/test_nvshmem_triton.py index 2aabf92427841..8958ec4eb84e2 100644 --- a/test/distributed/test_nvshmem_triton.py +++ b/test/distributed/test_nvshmem_triton.py @@ -13,7 +13,6 @@ from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, run_tests, - skip_but_pass_in_sandcastle, skip_but_pass_in_sandcastle_if, skipIfRocm, ) @@ -187,7 +186,7 @@ def test_triton_put(self) -> None: inp_hdl = symm_mem.rendezvous(inp, group=group_name) out_hdl = symm_mem.rendezvous(out, group=group_name) - peer = 1 - rank + peer = (self.world_size - 1) - rank if rank == 0: dst_ptr = out_hdl.buffer_ptrs[rank] src_ptr = inp_hdl.buffer_ptrs[rank] @@ -226,7 +225,7 @@ def test_triton_get(self) -> None: inp_hdl = symm_mem.rendezvous(inp, group=group_name) out_hdl = symm_mem.rendezvous(out, group=group_name) dist.barrier() - peer = 1 - rank + peer = (self.world_size - 1) - rank if rank == 1: # Rank 1 gets data from rank 0 dst_ptr = out_hdl.buffer_ptrs[rank] @@ -312,7 +311,7 @@ def test_triton_put_signal_set(self) -> None: # as the flag buffer for signaling completion. flag = out_hdl.get_signal_pad(rank, (1,), dtype=torch.int64).fill_(0) - peer = 1 - rank + peer = (self.world_size - 1) - rank NVSHMEM_SIGNAL_SET = 0 # value defined by NVSHMEM for atomic set SIGNAL_VAL = 1 # Signal completion value NVSHMEM_CMP_EQ = 0 # compare equal for signal wait until @@ -377,7 +376,7 @@ def test_triton_put_signal_add(self) -> None: # as the flag buffer for signaling completion. flag = out_hdl.get_signal_pad(rank, (1,), dtype=torch.int64).fill_(0) - peer = 1 - rank + peer = (self.world_size - 1) - rank NVSHMEM_SIGNAL_ADD = 5 # atomic add operation SIGNAL_VAL = 16 # val + NVSHMEM_SIGNAL_ADD NVSHMEM_CMP_EQ = 0 @@ -413,50 +412,54 @@ def test_triton_put_signal_add(self) -> None: flag, torch.tensor([SIGNAL_VAL], dtype=torch.int64, device=self.device) ) - # This test hangs. TODO: investigate why. - @skip_but_pass_in_sandcastle("Hangs") @skipIfRocm @requires_triton() def test_triton_wait_until(self) -> None: torch.manual_seed(42 + self.rank) self._init_device() + nvshmem_lib = nvshmem.enable_triton() group_name = dist.group.WORLD.group_name symm_mem.enable_symm_mem_for_group(group_name) + rank = self.rank + peer = (self.world_size - 1) - rank + NVSHMEM_CMP_EQ = 0 # from nvshmem.h - # Data buffers + # Allocate symmetric buffers msg_size_bytes = 8 dtype = torch.int8 numel = msg_size_bytes // dtype.itemsize val = 13 flag_val = 21 + inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val) out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1) + inp_hdl = symm_mem.rendezvous(inp, group=group_name) out_hdl = symm_mem.rendezvous(out, group=group_name) - peer = 1 - rank - NVSHMEM_CMP_EQ = 0 # from nvshmem.h - NVSHMEM_SIGNAL_SET = 0 # atomic set operation - if rank == 0: # Rank 0 waits for the flag to be set by Rank 1, then checks the data ivar_ptr = out_hdl.signal_pad_ptrs[rank] + wait_until_kernel[(1, 1, 1)]( ivar_ptr, cmp_op=NVSHMEM_CMP_EQ, cmp_val=flag_val, extern_libs=nvshmem_lib, ) + torch.testing.assert_close( - out, val * torch.ones(numel, dtype=dtype, device=self.device) + out, + val * torch.ones(numel, dtype=dtype, device=self.device), ) if rank == 1: # Rank 1 puts data into Rank 0's output buffer - dst_ptr = out_hdl.buffer_ptrs[rank] + dst_ptr = out_hdl.buffer_ptrs[peer] src_ptr = inp_hdl.buffer_ptrs[rank] + put_kernel[(1, 1, 1)]( dst_ptr, src_ptr, @@ -465,12 +468,21 @@ def test_triton_wait_until(self) -> None: extern_libs=nvshmem_lib, ) - # Rank 1 sets the flag on Rank 0 using nvshmemx_signal_op - sig_addr = out_hdl.signal_pad_ptrs[rank] - signal_op_kernel[(1, 1, 1)]( - sig_addr, - signal=flag_val, - sig_op=NVSHMEM_SIGNAL_SET, + # Fence to order data put before flag put + @triton.jit + def fence_kernel(): + nvshmem.fence() + + fence_kernel[(1, 1, 1)](extern_libs=nvshmem_lib) + + # Put the flag value (do not use signal_op here) + flag_src = torch.tensor([flag_val], dtype=torch.int64, device=self.device) + flag_dst_ptr = out_hdl.signal_pad_ptrs[peer] + + put_kernel[(1, 1, 1)]( + flag_dst_ptr, + flag_src.data_ptr(), + numel=1, peer=peer, extern_libs=nvshmem_lib, ) @@ -484,7 +496,7 @@ def test_triton_signal_wait_until(self) -> None: group_name = dist.group.WORLD.group_name symm_mem.enable_symm_mem_for_group(group_name) rank = self.rank - peer = 1 - rank + peer = (self.world_size - 1) - rank # NVSHMEM constants from documentation NVSHMEM_CMP_EQ = 0 # equal comparison @@ -560,7 +572,7 @@ def test_triton_fence(self) -> None: group_name = dist.group.WORLD.group_name symm_mem.enable_symm_mem_for_group(group_name) rank = self.rank - peer = 1 - rank + peer = (self.world_size - 1) - rank # Message configuration msg_size_bytes = 8 dtype = torch.int8 @@ -646,7 +658,7 @@ def test_triton_quiet(self) -> None: out_hdl = symm_mem.rendezvous(out, group=group_name) # Use signal pad as completion flag flag_val = 42 - peer = 1 - rank + peer = (self.world_size - 1) - rank NVSHMEM_CMP_EQ = 0 if rank == 0: From c8c221c0b3abbb8b5e20138080644dd5f5cd0aa1 Mon Sep 17 00:00:00 2001 From: wengshiy Date: Tue, 15 Jul 2025 06:01:57 +0000 Subject: [PATCH 039/457] [Inductor][Float8] Add float8_e4m3fn into assertion dtype list. (#157684) Fix assert issue. Add float8_e4m3fn into dtype list. Pull Request resolved: https://github.com/pytorch/pytorch/pull/157684 Approved by: https://github.com/Xia-Weiwen, https://github.com/leslie-fang-intel, https://github.com/jansel --- .../src/ATen/native/quantized/cpu/qlinear.cpp | 3 +- .../native/quantized/cpu/qlinear_prepack.cpp | 5 +- test/inductor/test_mkldnn_pattern_matcher.py | 98 +++++++++++++++++++ torch/_inductor/fx_passes/quantization.py | 8 +- torch/_inductor/mkldnn_lowerings.py | 4 +- torch/_meta_registrations.py | 16 ++- 6 files changed, 126 insertions(+), 8 deletions(-) diff --git a/aten/src/ATen/native/quantized/cpu/qlinear.cpp b/aten/src/ATen/native/quantized/cpu/qlinear.cpp index 502839a7d909c..644ca6e67079e 100644 --- a/aten/src/ATen/native/quantized/cpu/qlinear.cpp +++ b/aten/src/ATen/native/quantized/cpu/qlinear.cpp @@ -1118,8 +1118,9 @@ static at::Tensor linear_int8_with_onednn_weight( if(is_fp8 && !cpuinfo_has_x86_amx_int8()) { #endif // Fall back to ref impl on old platforms because not supported + // Transpose weight to align with behavior in oneDNN return fp8_qlinear_onednn_ref( - input, input_scale, onednn_weight, weight_scales, bias, + input, input_scale, onednn_weight.t(), weight_scales, bias, output_scale, output_dtype, other, other_scale, binary_post_op, binary_alpha, unary_post_op, unary_post_op_args, unary_post_op_algorithm); diff --git a/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp index 55ec1d8148bde..3bd68feca1c2f 100644 --- a/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp @@ -305,11 +305,12 @@ static inline at::Tensor pack_weight_to_onednn_tensor( #if defined(__powerpc__) if (is_fp8){ #else - if(is_fp8 && !cpuinfo_has_x86_amx_int8()) { + if(is_fp8 && !cpuinfo_has_x86_amx_int8()) { #endif // oneDNN's fp8 requires AMX support // If AMX is not available, fall back to reference implementation - return weight; + // Transpose weight to align with behavior in oneDNN + return weight.t(); } std::vector w_dims = weight.sizes().vec(); auto w_data_type = is_fp8 diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py index 7760bfd834efd..bccc0e6e42fda 100644 --- a/test/inductor/test_mkldnn_pattern_matcher.py +++ b/test/inductor/test_mkldnn_pattern_matcher.py @@ -2952,6 +2952,104 @@ def test_qlinear_add_int8_mixed_bf16_xpu(self, use_relu, is_qat, is_dynamic): is_dynamic=is_dynamic, ) + def _test_qlinear_fp8_inductor_cpu_helper(self, qlinear_op, post_op="none"): + dtype = torch.float8_e4m3fn + qlinear_prepack = torch.ops.onednn.qlinear_prepack + post_op_algo = "none" + unary_post_op_args = () + batch_size = 1 + output_dtype = torch.float8_e4m3fn + y_scale, y_zp = 0.07, 0 + ic = 4 + oc = 16 + + torch._dynamo.reset() + used_y_scale = y_scale + used_y_zp = y_zp + x = torch.rand(batch_size, ic) + w = torch.rand(oc, ic) + qx = x.to(dtype) + qw = w.to(dtype) + x_scale = 0.5 + w_scales = torch.randn(oc) + b = torch.rand(oc) + + x_zp = 0 + w_zps = torch.zeros_like(w_scales, dtype=torch.int) + + if post_op == "none": + + class Mod(torch.nn.Module): + def __init__(self): + super().__init__() + self.qw_packed = qlinear_prepack(qw, x.shape) + + def forward(self, qx): + qy = qlinear_op( + qx, + x_scale, + x_zp, + self.qw_packed, + w_scales, + w_zps, + b, + used_y_scale, + used_y_zp, + output_dtype, + post_op, + unary_post_op_args, + post_op_algo, + ) + return qy + + elif post_op == "add": + x2 = torch.rand(batch_size, oc) + binary_alpha = 1.0 # we only support alpha=1.0 now + + class Mod(torch.nn.Module): + def __init__(self): + super().__init__() + self.qw_packed = qlinear_prepack(qw, x.shape) + + def forward(self, qx): + qy = qlinear_op( + qx, + x_scale, + x_zp, + self.qw_packed, + w_scales, + w_zps, + x2, + b, + used_y_scale, + used_y_zp, + output_dtype, + 1.0, + 0, + "add", + binary_alpha, + "none", + unary_post_op_args, + post_op_algo, + ) + return qy + + with torch.no_grad(): + model = Mod() + y_refe = model(qx) + y_test = torch.compile(model)(qx) + self.assertEqual(y_refe.float(), y_test.float()) + + @skipIfNoONEDNN + def test_qlinear_fp8_inductor_cpu(self): + qlinear_op = torch.ops.onednn.qlinear_pointwise.default + self._test_qlinear_fp8_inductor_cpu_helper(qlinear_op, "none") + + @skipIfNoONEDNN + def test_qlinear_add_fp8_inductor_cpu(self): + qlinear_op = torch.ops.onednn.qlinear_pointwise.binary + self._test_qlinear_fp8_inductor_cpu_helper(qlinear_op, "add") + def _qlinear_dequant_promotion_test_helper( self, inputs, diff --git a/torch/_inductor/fx_passes/quantization.py b/torch/_inductor/fx_passes/quantization.py index 862df99a41e50..70dfe9ae43b35 100644 --- a/torch/_inductor/fx_passes/quantization.py +++ b/torch/_inductor/fx_passes/quantization.py @@ -72,7 +72,13 @@ def _get_pattern_output_dtype(match: Match): output_node = pattern_output_nodes[0] assert isinstance(output_node, torch.fx.Node) output_dtype = output_node.meta["val"].dtype - assert output_dtype in [torch.int8, torch.uint8, torch.float32, torch.bfloat16] + assert output_dtype in [ + torch.int8, + torch.uint8, + torch.float32, + torch.bfloat16, + torch.float8_e4m3fn, + ] return output_dtype diff --git a/torch/_inductor/mkldnn_lowerings.py b/torch/_inductor/mkldnn_lowerings.py index e7981bc8746bc..3b3a7b072534a 100644 --- a/torch/_inductor/mkldnn_lowerings.py +++ b/torch/_inductor/mkldnn_lowerings.py @@ -675,8 +675,8 @@ def qlinear_unary( algorithm, layout=None, ): - assert packed_weight.get_dtype() is torch.int8, ( - "Only int8 weights are supported by oneDNN qlinear." + assert packed_weight.get_dtype() in [torch.int8, torch.float8_e4m3fn], ( + "Only int8 and e4m3fn weights are supported by oneDNN qlinear." ) x_size = x.get_size() if len(x_size) > 2: diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index acb7ab2e5a053..4d8079d9a7618 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -2789,7 +2789,13 @@ def meta_qlinear_pointwise( output_shape = list(x.shape) # The weight has been transposed during the qlinear weight prepack process. output_shape[-1] = w.shape[1] - assert output_dtype in [torch.float32, torch.bfloat16, torch.int8, torch.uint8] + assert output_dtype in [ + torch.float32, + torch.bfloat16, + torch.int8, + torch.uint8, + torch.float8_e4m3fn, + ] out = x.new_empty(output_shape, dtype=output_dtype) return out @@ -2820,7 +2826,13 @@ def meta_qlinear_pointwise_binary( output_shape = list(x.shape) # The weight has been transposed during the qlinear weight prepack process. output_shape[-1] = w.shape[1] - assert output_dtype in [torch.float32, torch.bfloat16, torch.uint8, torch.int8] + assert output_dtype in [ + torch.float32, + torch.bfloat16, + torch.uint8, + torch.int8, + torch.float8_e4m3fn, + ] out = x.new_empty(output_shape, dtype=output_dtype) return out From 6c5227ba00a2904365af566c24b4681cd01a041c Mon Sep 17 00:00:00 2001 From: AaronWang04 Date: Tue, 15 Jul 2025 07:04:54 +0000 Subject: [PATCH 040/457] [CI] Fixes CI for CUDA Version > 12.9 (#157385) Compute capabilities older than volta (inclusive) is no longer supported in CUDA Version > 12.9 Pull Request resolved: https://github.com/pytorch/pytorch/pull/157385 Approved by: https://github.com/huydhn --- test/test_cpp_extensions_jit.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/test/test_cpp_extensions_jit.py b/test/test_cpp_extensions_jit.py index d671e3f874c96..84f5923697a2d 100644 --- a/test/test_cpp_extensions_jit.py +++ b/test/test_cpp_extensions_jit.py @@ -322,12 +322,15 @@ def test_jit_cuda_archflags(self): [f"{capability[0]}{capability[1]}" for capability in capabilities], None, ), - "Maxwell+Tegra;6.1": (["53", "61"], None), - "Volta": (["70"], ["70"]), } archflags["7.5+PTX"] = (["75"], ["75"]) - archflags["5.0;6.0+PTX;7.0;7.5"] = (["50", "60", "70", "75"], ["60"]) - if int(torch.version.cuda.split(".")[0]) < 12: + major, minor = map(int, torch.version.cuda.split(".")[:2]) + if major < 12 or (major == 12 and minor <= 9): + # Compute capability <= 7.0 is only supported up to CUDA 12.9 + archflags["Maxwell+Tegra;6.1"] = (["53", "61"], None) + archflags["Volta"] = ((["70"], ["70"]),) + archflags["5.0;6.0+PTX;7.0;7.5"] = (["50", "60", "70", "75"], ["60"]) + if major < 12: # CUDA 12 drops compute capability < 5.0 archflags["Pascal 3.5"] = (["35", "60", "61"], None) From 1b389025ba0cc640e07991314bfba8b6ca385bd2 Mon Sep 17 00:00:00 2001 From: FFFrog Date: Tue, 15 Jul 2025 12:26:58 +0800 Subject: [PATCH 041/457] Refactor and Improve the OpenReg Module (#158090) ---- # Refactor and Improve the OpenReg Module ## Background Since PrivateUse1 has become the main path for integrating new devices with PyTorch, there have been some feature requests related to PrivateUse1 regarding interfaces, documentation, reference examples, etc., such as the following: - https://github.com/pytorch/pytorch/issues/155864 - https://github.com/pytorch/pytorch/issues/144955 - https://github.com/pytorch/pytorch/issues/144845 Taking these requests into consideration and combining them with the position of OpenReg, which is currently used as the test backend for PrivateUse1, I'm planning to make the following optimizations: - Optimize the implementation of OpenReg to make it align with the standard specifications for real backend (C++) access, serving as a reference for new device integration code. - Add comprehensive documentation to the [developer notes](https://docs.pytorch.org/docs/main/notes.html) to guide new accelerator integration, functioning as a reference manual. ## Design Principles: - Minimization Principle: Keep the code small and clear; only implement the minimum set of code required for verification and as an integration reference. - Authenticity Principle: Integrate OpenReg in the same way that real accelerators access PyTorch. ## More Infos: Pleaes refer to [this](https://github.com/pytorch/pytorch/blob/6b8020f1abc855203358a24e9f560810eb5b720e/test/cpp_extensions/open_registration_extension/torch_openreg/README.md) for more information about `OpenReg`. ## Current Progress: - Refer to the implementation of [torch_xla](https://github.com/pytorch/xla) to refactor all of OpenReg's code, making it easier to understand. - Ensure all tests in [test/test_openreg.py](https://github.com/FFFrog/pytorch/blob/openreg/test/test_openreg.py) pass after refactoring. ## Next Steps: - Add more features to cover all integration points. - Gradually add user guides and documentation to the [developer notes](https://docs.pytorch.org/docs/main/notes.html). Pull Request resolved: https://github.com/pytorch/pytorch/pull/158090 Approved by: https://github.com/seemethere, https://github.com/albanD --- .../open_registration_extension/README.md | 37 -- .../pytorch_openreg/__init__.py | 122 ----- .../pytorch_openreg/_aten_impl.py | 186 -------- .../pytorch_openreg/_device_daemon.py | 391 ---------------- .../pytorch_openreg/_meta_parser.py | 103 ----- .../pytorch_openreg/csrc/Module.cpp | 51 --- .../pytorch_openreg/csrc/OpenReg.h | 50 --- .../pytorch_openreg/csrc/OpenRegHooks.cpp | 350 --------------- .../pytorch_openreg/csrc/OpenRegMem.cpp | 418 ------------------ .../open_registration_extension/setup.py | 78 ---- .../torch_openreg/CMakeLists.txt | 38 ++ .../torch_openreg/README.md | 177 ++++++++ .../torch_openreg/csrc/CMakeLists.txt | 12 + .../torch_openreg/csrc/aten/OpenRegExtra.cpp | 138 ++++++ .../csrc/aten/OpenRegMinimal.cpp | 128 ++++++ .../torch_openreg/csrc/aten/native/Common.h | 106 +++++ .../torch_openreg/csrc/aten/native/Extra.cpp | 238 ++++++++++ .../torch_openreg/csrc/aten/native/Extra.h | 70 +++ .../csrc/aten/native/Minimal.cpp | 173 ++++++++ .../torch_openreg/csrc/aten/native/Minimal.h | 67 +++ .../csrc/runtime/OpenRegDeviceAllocator.cpp | 8 + .../csrc/runtime/OpenRegDeviceAllocator.h | 43 ++ .../csrc/runtime/OpenRegFunctions.cpp | 73 +++ .../csrc/runtime/OpenRegFunctions.h | 16 + .../csrc/runtime/OpenRegGenerator.cpp | 28 ++ .../csrc/runtime/OpenRegGenerator.h | 21 + .../csrc/runtime/OpenRegGuard.cpp | 7 + .../torch_openreg/csrc/runtime/OpenRegGuard.h | 197 +++++++++ .../csrc/runtime/OpenRegHooks.cpp | 11 + .../torch_openreg/csrc/runtime/OpenRegHooks.h | 41 ++ .../csrc/runtime/OpenRegHostAllocator.cpp | 8 + .../csrc/runtime/OpenRegHostAllocator.h | 48 ++ .../csrc/runtime/OpenRegSerialization.cpp | 48 ++ .../csrc/runtime/OpenRegSerialization.h | 10 + .../torch_openreg/requirements.txt | 2 + .../torch_openreg/setup.py | 102 +++++ .../third_party/openreg/CMakeLists.txt | 11 + .../third_party/openreg/README.md | 137 ++++++ .../third_party/openreg/csrc/device.cpp | 35 ++ .../third_party/openreg/csrc/memory.cpp | 249 +++++++++++ .../third_party/openreg/include/openreg.h | 49 ++ .../torch_openreg/torch_openreg/__init__.py | 8 + .../torch_openreg/csrc/CMakeLists.txt | 12 + .../torch_openreg/csrc/Module.cpp | 99 +++++ .../torch_openreg/torch_openreg/csrc/stub.c | 15 + .../torch_openreg/openreg/__init__.py | 72 +++ .../torch_openreg/openreg/random.py | 60 +++ test/run_test.py | 7 +- ...cpp_extensions_open_device_registration.py | 2 +- test/test_openreg.py | 6 +- test/test_transformers_privateuse1.py | 2 +- 51 files changed, 2568 insertions(+), 1792 deletions(-) delete mode 100644 test/cpp_extensions/open_registration_extension/README.md delete mode 100644 test/cpp_extensions/open_registration_extension/pytorch_openreg/__init__.py delete mode 100644 test/cpp_extensions/open_registration_extension/pytorch_openreg/_aten_impl.py delete mode 100644 test/cpp_extensions/open_registration_extension/pytorch_openreg/_device_daemon.py delete mode 100644 test/cpp_extensions/open_registration_extension/pytorch_openreg/_meta_parser.py delete mode 100644 test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/Module.cpp delete mode 100644 test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenReg.h delete mode 100644 test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenRegHooks.cpp delete mode 100644 test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenRegMem.cpp delete mode 100644 test/cpp_extensions/open_registration_extension/setup.py create mode 100644 test/cpp_extensions/open_registration_extension/torch_openreg/CMakeLists.txt create mode 100644 test/cpp_extensions/open_registration_extension/torch_openreg/README.md create mode 100644 test/cpp_extensions/open_registration_extension/torch_openreg/csrc/CMakeLists.txt create mode 100644 test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/OpenRegExtra.cpp create mode 100644 test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/OpenRegMinimal.cpp create mode 100644 test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/native/Common.h create mode 100644 test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/native/Extra.cpp create mode 100644 test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/native/Extra.h create mode 100644 test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/native/Minimal.cpp create mode 100644 test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/native/Minimal.h create mode 100644 test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegDeviceAllocator.cpp create mode 100644 test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegDeviceAllocator.h create mode 100644 test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp create mode 100644 test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.h create mode 100644 test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGenerator.cpp create mode 100644 test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGenerator.h create mode 100644 test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGuard.cpp create mode 100644 test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGuard.h create mode 100644 test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegHooks.cpp create mode 100644 test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegHooks.h create mode 100644 test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegHostAllocator.cpp create mode 100644 test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegHostAllocator.h create mode 100644 test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegSerialization.cpp create mode 100644 test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegSerialization.h create mode 100644 test/cpp_extensions/open_registration_extension/torch_openreg/requirements.txt create mode 100644 test/cpp_extensions/open_registration_extension/torch_openreg/setup.py create mode 100644 test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/CMakeLists.txt create mode 100644 test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/README.md create mode 100644 test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/csrc/device.cpp create mode 100644 test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/csrc/memory.cpp create mode 100644 test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/include/openreg.h create mode 100644 test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/__init__.py create mode 100644 test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/CMakeLists.txt create mode 100644 test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/Module.cpp create mode 100644 test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/stub.c create mode 100644 test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/__init__.py create mode 100644 test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/random.py diff --git a/test/cpp_extensions/open_registration_extension/README.md b/test/cpp_extensions/open_registration_extension/README.md deleted file mode 100644 index cf32c3afbb06e..0000000000000 --- a/test/cpp_extensions/open_registration_extension/README.md +++ /dev/null @@ -1,37 +0,0 @@ -# PyTorch OpenReg - -This folder contains a self-contained example of a PyTorch out-of-tree backend leveraging the "PrivateUse1" backend from core. - -## How to use - -Install as standalone with `python -m pip install -e .` (or `python -m pip install .`) -from this folder. You can run test via `python {PYTORCH_ROOT_PATH}/test/test_openreg.py`. - -## Design principles - -For simplicity anything that can be implemented from python is done so. -A real implementation will most likely want to call these different APIs from c++ directly. - -The current version sends everything back to python and contains enough implementation to run basic model, transfer host/device and printing. - -The codebase is split as follows: - -- `pytorch_openreg/__init__.py` - - imports torch to get core state initialized. - - imports `._aten_impl` to register our aten op implementations to torch. - - imports `.C` to load our c++ extension that registers more ops, allocator and hooks. - - renames the PrivateUse1 backend and register our python-side module. -- `pytorch_openreg/_aten_impl.py` - - Define a new `torch.Library` that registers a fallback that will be called whenever a backend kernel for PrivateUse1 is called. It contains the logic to handle all kind of native functions, computing the output metadata, allocating it and only calling into the device daemon to perform computation. -- `pytorch_openreg/_device_daemon.py` - - contains the Allocator (responsible for allocating memory on the device side and host side, as int8 buffers). - - contains `Driver`, which as user-process driver to deal with some information needed to be done in driver. - - contains `Executor`, which as device-process exector to do something related device logic. -- `pytorch_openreg/_meta_parser.py` mainly contain utilities to send objects over the wire from the user process to the device process. - - The main class there is `OpenRegTensorMeta` that contains all the metadata sent to the device which should be enough for it to populate the output Tensor. - -## Next steps - -The main next step would be to: - -- Replace the current `open_registration_extension.cpp` test in PyTorch CI with this. diff --git a/test/cpp_extensions/open_registration_extension/pytorch_openreg/__init__.py b/test/cpp_extensions/open_registration_extension/pytorch_openreg/__init__.py deleted file mode 100644 index 05b8955b6557b..0000000000000 --- a/test/cpp_extensions/open_registration_extension/pytorch_openreg/__init__.py +++ /dev/null @@ -1,122 +0,0 @@ -import types - -import torch - -# Create our python implementation dict so that the C++ module -# can access it during its initialization and also register aten impls. -from ._aten_impl import impl_factory as impl_factory # noqa: F401 -from ._device_daemon import driver - - -# Load the C++ Module -import pytorch_openreg._C # isort:skip # type: ignore[import] # noqa: F401 - - -def _create_module(): - module = types.ModuleType("_OpenRegMod") - - class device: - r"""Context-manager that changes the selected device. - - Args: - device (torch.device or int): device index to select. It's a no-op if - this argument is a negative integer or ``None``. - """ - - def __init__(self, device): - self.idx = torch.accelerator._get_device_index(device, optional=True) - self.prev_idx = -1 - - def __enter__(self): - self.prev_idx = driver.exec("exchangeDevice", self.idx) - - def __exit__(self, type, value, traceback): - self.idx = driver.exec("uncheckedSetDevice", self.prev_idx) - return False - - def device_count() -> int: - return driver.exec("deviceCount") - - def is_available(): - return True - - def current_device(): - return torch.accelerator.current_device_index() - - def get_rng_state(device="openreg"): - if isinstance(device, str): - device = torch.device(device) - elif isinstance(device, int): - device = torch.device("openreg", device) - idx = device.index - if idx is None: - idx = current_device() - default_generator = pytorch_openreg._C._get_default_generator(idx) - return default_generator.get_state() - - def set_rng_state(new_state, device="openreg"): - if isinstance(device, str): - device = torch.device(device) - elif isinstance(device, int): - device = torch.device("openreg", device) - idx = device.index - if idx is None: - idx = current_device() - default_generator = pytorch_openreg._C._get_default_generator(idx) - default_generator.set_state(new_state) - - def initial_seed() -> int: - _lazy_init() - idx = current_device() - default_generator = pytorch_openreg._C._get_default_generator(idx) - return default_generator.initial_seed() - - def manual_seed(seed: int) -> None: - seed = int(seed) - - idx = current_device() - default_generator = pytorch_openreg._C._get_default_generator(idx) - default_generator.manual_seed(seed) - - def manual_seed_all(seed: int) -> None: - seed = int(seed) - - for idx in range(device_count()): - default_generator = pytorch_openreg._C._get_default_generator(idx) - default_generator.manual_seed(seed) - - def is_initialized(): - return module._initialized - - def _is_in_bad_fork(): - return False - - def _lazy_init(): - if is_initialized(): - return - pytorch_openreg._C._init() - module._initialized = True - - module.is_available = is_available # type: ignore[assignment] - - module._initialized = False # type: ignore[assignment] - module._lazy_init = _lazy_init # type: ignore[assignment] - module.is_initialized = is_initialized # type: ignore[assignment] - - module.device = device # type: ignore[assignment] - module.device_count = device_count # type: ignore[assignment] - module.current_device = current_device # type: ignore[assignment] - module.get_rng_state = get_rng_state # type: ignore[assignment] - module.set_rng_state = set_rng_state # type: ignore[assignment] - module._is_in_bad_fork = _is_in_bad_fork # type: ignore[assignment] - module.initial_seed = initial_seed # type: ignore[assignment] - module.manual_seed = manual_seed # type: ignore[assignment] - module.manual_seed_all = manual_seed_all # type: ignore[assignment] - - return module - - -# Set all the appropriate state on PyTorch -torch.utils.rename_privateuse1_backend("openreg") -torch._register_device_module("openreg", _create_module()) -torch.utils.generate_methods_for_privateuse1_backend(for_storage=True) diff --git a/test/cpp_extensions/open_registration_extension/pytorch_openreg/_aten_impl.py b/test/cpp_extensions/open_registration_extension/pytorch_openreg/_aten_impl.py deleted file mode 100644 index d4c49bd28d458..0000000000000 --- a/test/cpp_extensions/open_registration_extension/pytorch_openreg/_aten_impl.py +++ /dev/null @@ -1,186 +0,0 @@ -import logging - -import torch -from torch.utils._pytree import tree_any - - -log = logging.getLogger(__name__) - -from ._device_daemon import driver -from ._meta_parser import prepare_for_sending, to_device_no_copy - - -_IMPL_REGISTRY = {} - - -def impl_factory(name): - if name in _IMPL_REGISTRY: - return _IMPL_REGISTRY[name] - - def _(*args, **kwargs): - log.info("Calling hook %s", name) - return driver.exec(name, *args, **kwargs) - - _IMPL_REGISTRY[name] = _ - return _ - - -def _openreg_kernel_fallback(op, *args, **kwargs): - def get_tensor_device(*args): - for arg in args: - if isinstance(arg, torch.Tensor) and arg.device.type == "openreg": - return arg.device - - device = get_tensor_device(*args) - if device is None: - return _kernel_fallback(op, *args, **kwargs) - - # Mimicks the DeviceGuard system we have in aten - with torch.openreg.device(device): # type: ignore[misc] - return _kernel_fallback(op, *args, **kwargs) - - -def _kernel_fallback(op, *args, **kwargs): - log.info("Calling kernel %s", op) - - op_name = None - post_process = None - if "out" in op._overloadname: - # Note that all structured native op will call here - if isinstance(kwargs["out"], tuple): - raise RuntimeError(f"out= variant {op} with tuple out= not supported") - if kwargs["out"].nelement() == 0: - # Out variant that needs a resize, convert to an out of place - # and handle generically below - orig_out = kwargs["out"] - del kwargs["out"] - if op._overloadname != "out": - raise RuntimeError( - "Cannot retranslate non-default out= variant form 0 size" - ) - op = op.overloadpacket.default - - def _post_process(): - nonlocal real_res - orig_out.set_(real_res) - real_res = orig_out - - post_process = _post_process - - else: - # No metadata update to do, just run the op on the device - op_name = op.overloadpacket._qualified_op_name - real_res = kwargs["out"] - elif not tree_any(lambda obj: isinstance(obj, torch.Tensor), (args, kwargs)): - # No Tensor argument means factory function - # They should decompose and be handled in our c++ side directly - raise RuntimeError(f"{op} not handled yet.") - elif op._schema.is_mutable or op is torch.ops.aten._copy_from.default: - # Only handle inplace ops returning their first arg - assert len(args) >= 1, f"Inplace {op} needs at least one arg" - assert len(op._schema.returns) == 1, ( - f"NYI Inplace {op} with more than one return" - ) - op_name = op.overloadpacket._qualified_op_name - real_res = args[0] - elif any(r.alias_info is not None for r in op._schema.returns): - # View ops - if op is torch.ops.aten.view.default: - return torch.ops.aten._unsafe_view(*args, **kwargs) - raise RuntimeError(f"{op} view op is not handled yet") - - if op_name is None: - # 1. Compute updated metadata - if torch.Tag.dynamic_output_shape not in op.tags: - # Usual case: run the meta op to see the output metadata - meta_args, meta_kwargs = to_device_no_copy("meta", args, kwargs) - meta_res = op(*meta_args, **meta_kwargs) - - # 2. Allocate the output - real_res, _ = to_device_no_copy("openreg", meta_res, {}) - else: - # Slow version for data-dependent functions: - # Run the op on the device just to get the output shape - args_, kwargs_ = prepare_for_sending(args, kwargs) - shape = driver.exec( - "get_op_output_shape", - op.overloadpacket._qualified_op_name, - args_, - kwargs_, - ) - - # 2. Allocate the output - real_res = args[0].new(shape) - - # 3. Move to out variant - kwargs["out"] = real_res - # Let overload resolution find the out= overload - op_name = op.overloadpacket._qualified_op_name - - # 4. Run the compute and populate the output on the device - args, kwargs = prepare_for_sending(args, kwargs) - driver.exec("run_op", op_name, args, kwargs) - - if post_process is not None: - post_process() - - return real_res - - -def copy_from_device(from_): - with torch.openreg.device(from_.device): # type: ignore[misc] - args, _ = prepare_for_sending((from_,), {}) - return driver.exec("send_data", *args) - - -def copy_from_host_to_device(from_, to_): - with torch.openreg.device(to_.device): # type: ignore[misc] - args, _ = prepare_for_sending((to_,), {}) - driver.exec("recv_data", from_, *args) - return to_ - - -def _copy_from(from_, to_): - if from_.device.type == to_.device.type: - assert from_.device.type == "openreg" - if from_.device.index == to_.device.index: - op = torch.ops.aten.copy_.default - return _openreg_kernel_fallback(op, to_, from_) - else: - host_mem = copy_from_device(from_) - return copy_from_host_to_device(host_mem, to_) - elif from_.device.type == "openreg": - host_mem = copy_from_device(from_) - return to_.copy_(host_mem) - elif to_.device.type == "openreg": - return copy_from_host_to_device(from_, to_) - else: - raise RuntimeError("Should not happen") - - -def _set_source_tensor(ten1, ten2): - return torch.ops.aten.set_.source_Storage_storage_offset( - ten1, - ten2.untyped_storage(), - ten2.storage_offset(), - ten2.size(), - ten2.stride(), - ) - - -def _local_scalar_dense(ten): - host_mem = copy_from_device(ten) - return host_mem.item() - - -_openreg_lib = torch.library.Library("_", "IMPL") -_openreg_lib.fallback(_openreg_kernel_fallback, dispatch_key="PrivateUse1") - -_openreg_lib_aten = torch.library.Library("aten", "IMPL") -_openreg_lib_aten.impl("_copy_from", _copy_from, dispatch_key="PrivateUse1") -_openreg_lib_aten.impl( - "set_.source_Tensor", _set_source_tensor, dispatch_key="PrivateUse1" -) -_openreg_lib_aten.impl( - "_local_scalar_dense", _local_scalar_dense, dispatch_key="PrivateUse1" -) diff --git a/test/cpp_extensions/open_registration_extension/pytorch_openreg/_device_daemon.py b/test/cpp_extensions/open_registration_extension/pytorch_openreg/_device_daemon.py deleted file mode 100644 index d339869635001..0000000000000 --- a/test/cpp_extensions/open_registration_extension/pytorch_openreg/_device_daemon.py +++ /dev/null @@ -1,391 +0,0 @@ -import ctypes -import logging -import threading -import time - -import torch - -from ._meta_parser import ( - OpenRegTensorData, - receive_after_sending, - safe_str, - validate_send_queue_args, -) - - -log = logging.getLogger(__name__) -mp_context = torch.multiprocessing.get_context("spawn") - -# Constant properties of our device -NUM_DEVICES = 2 - - -# Our allocator -class Allocator: - def __init__(self): - self.allocated = {} - - def malloc(self, size): - mem = ctypes.create_string_buffer(size) - ptr = ctypes.addressof(mem) - self.allocated[ptr] = (size, mem) - return ptr - - def free(self, ptr): - if ptr not in self.allocated: - return False - else: - del self.allocated[ptr] - return True - - -class HostAllocator(Allocator): - def is_pinned_ptr(self, ptr): - return ptr in self.allocated or any( - ptr_ <= ptr and ptr < ptr_ + size - for ptr_, (size, _) in self.allocated.items() - ) - - -class DeviceAllocator(Allocator): - def tensor_from_meta(self, meta): - def create_tensor_from_data_ptr(ptr, size): - storage = torch._C._construct_storage_from_data_pointer( - ptr, torch.device("cpu"), size - ) - return torch.Tensor(storage) - - found_base = None - # Usual case, we're receiving a known Tensor - if meta.data_ptr in self.allocated: - found_base = create_tensor_from_data_ptr( - meta.data_ptr, self.allocated[meta.data_ptr][0] - ) - - # Might be a rewrap of another storage at a different offset - # Slow path to try and find the corresponding storage - if found_base is None: - for tag, (size, _) in self.allocated.items(): - # t is always a 1D uint8 storage! - if meta.data_ptr > tag and meta.data_ptr < tag + size: - # Blame @ngimel for this - slice_size = size - (meta.data_ptr - tag) - found_base = create_tensor_from_data_ptr(meta.data_ptr, slice_size) - - # Might be an empty tensor - if found_base is None and meta.nelem_in_bytes == 0: - found_base = torch.tensor((), dtype=torch.uint8) - - # This pointer is not allocated here, segfault ! - if found_base is None: - log.info("Currently allocated blocks:\n %s", safe_str(self.allocated)) - log.info("Trying to access %s", meta) - raise RuntimeError("SEGFAULT!") - - # Raw 1d uint8 data - raw = found_base - # Reinterpret cast in the right dtype - as_dtype = raw.view(dtype=meta.dtype) - # View to the right shape/stride/offset - view = as_dtype.as_strided(meta.size, meta.stride, meta.storage_offset) - return view - - -def register(registry): - def func(fn): - registry[fn.__name__] = fn - return fn - - return func - - -class Driver: - def __init__(self, num_devices): - super().__init__() - self.num_devices = num_devices - self.is_initialized = False - - # State of our driver - self.curr_device_idx = 0 - self.curr_streams = {} - - # Allocated memory belongs to which device - self.memory_belong = {} - self.host_allocator = HostAllocator() - self.event_belong = {} - - self.rlock = threading.RLock() - - def _lazy_init(self): - if self.is_initialized: - return - self.devices = [] - - for i in range(self.num_devices): - req_queue = mp_context.Queue() - ans_queue = mp_context.Queue() - runner = mp_context.Process( - target=_Executor(i).run_forever, - args=(req_queue, ans_queue), - daemon=True, - ) - runner.start() - self.devices.append((req_queue, ans_queue, runner)) - - self.is_initialized = True - - def exec(self, cmd, *args): - with self.rlock: - log.info("Main process launched: %s(*%s)", cmd, safe_str(args)) - - if cmd in Driver.registry: - res = Driver.registry[cmd](self, *args) - else: - res = self.run_on_executor(self.curr_device_idx, cmd, *args) - - log.info("Main process result for %s received: %s", cmd, safe_str(res)) - if res == "ERROR": - raise RuntimeError(f"Error in daemon while executing {cmd}, see logs") - else: - return res - - def run_on_executor(self, device_idx, cmd, *args): - self._lazy_init() - req_queue, ans_queue, _ = self.devices[device_idx] - stream = self.getStream(device_idx) - validate_send_queue_args(cmd, args) - req_queue.put((stream, cmd) + args) - return ans_queue.get() - - registry = {} - - @register(registry) - def hasPrimaryContext(self, device_idx): - return device_idx >= 0 and device_idx < self.num_devices - - @register(registry) - def deviceCount(self, *args): - assert len(args) == 0 - return self.num_devices - - @register(registry) - def getDevice(self): - return self.curr_device_idx - - @register(registry) - def setDevice(self, device_idx): - assert device_idx >= 0 and device_idx < self.num_devices - self.curr_device_idx = device_idx - - @register(registry) - def uncheckedSetDevice(self, *args): - assert len(args) == 1 - self.curr_device_idx = int(args[0]) - - @register(registry) - def exchangeDevice(self, *args): - assert len(args) == 1 - res = self.curr_device_idx - self.curr_device_idx = int(args[0]) - return res - - @register(registry) - def malloc(self, size): - ptr = self.run_on_executor(self.curr_device_idx, "malloc", size) - self.memory_belong[ptr] = self.curr_device_idx - return ptr - - @register(registry) - def free(self, ptr): - device_idx = self.memory_belong.pop(ptr, None) - if device_idx is None: - return False - return self.run_on_executor(device_idx, "free", ptr) - - @register(registry) - def isPinnedPtr(self, ptr): - return self.host_allocator.is_pinned_ptr(ptr) - - @register(registry) - def hostMalloc(self, size): - return self.host_allocator.malloc(size) - - @register(registry) - def hostFree(self, ptr): - return self.host_allocator.free(ptr) - - @register(registry) - def getNewStream(self, device_idx, priority): - return self.run_on_executor(device_idx, "getNewStream", priority) - - @register(registry) - def queryStream(self, stream): - return self.run_on_executor( - stream.device_index, "queryStream", stream.stream_id - ) - - @register(registry) - def getStream(self, device_idx): - return self.curr_streams.get(device_idx, 0) - - @register(registry) - def exchangeStream(self, stream): - stream_id = self.curr_streams.get(stream.device_index, 0) - self.curr_streams[stream.device_index] = stream.stream_id - return stream_id - - @register(registry) - def synchronizeStream(self, stream): - self.run_on_executor(stream.device_index, "synchronizeStream", stream.stream_id) - - @register(registry) - def record(self, event, stream, device_index, flags): - event_ptr = ctypes.cast(event, ctypes.POINTER(ctypes.c_int64)) - # Create event if needed - if event_ptr.contents.value == 0: - event_ptr.contents.value = self.run_on_executor( - stream.device_index, "eventCreateWithFlags", flags - ) - self.event_belong[event_ptr.contents.value] = stream.device_index - - # Record event - self.run_on_executor( - stream.device_index, - "eventRecord", - event_ptr.contents.value, - stream.stream_id, - ) - - @register(registry) - def destroyEvent(self, event, device_index): - self.run_on_executor(device_index, "eventDestroy", event) - self.event_belong.pop(event) - - @register(registry) - def synchronizeEvent(self, event): - self.run_on_executor(self.event_belong[event], "eventSynchronize", event) - - @register(registry) - def queryEvent(self, event): - return self.run_on_executor(self.event_belong[event], "eventQuery", event) - - @register(registry) - def elapsedTime(self, e1, e2, device_index): - return self.run_on_executor(device_index, "eventElapsedTime", e1, e2) - - @register(registry) - def block(self, event, stream): - self.run_on_executor(stream.device_index, "block", event, stream.stream_id) - - -class _Executor: - def __init__(self, id): - self.id = id - self.allocator = DeviceAllocator() - self.stream = 0 - self.event_incr_id = 0 - self.events = {} - - def run_forever(self, req_queue, ans_queue): - # Serve all requests - while True: - # Ignore stream since cpu backend doesn't support asynchronous execution - _, cmd, *args = req_queue.get() - log.info("Worker executing: %s", cmd) - if cmd in _Executor.registry: - res = _Executor.registry[cmd](self, *args) - else: - log.warning("Bad command in worker") - res = "ERROR" - - log.info("Worker answering to: %s", cmd) - ans_queue.put(res) - - registry = {} - - @register(registry) - def malloc(self, size): - return self.allocator.malloc(size) - - @register(registry) - def free(self, ptr): - return self.allocator.free(ptr) - - def _run_op(self, op_name, args, kwargs): - op, _ = torch._C._jit_get_operation(op_name) - args, kwargs = receive_after_sending(self.allocator, args, kwargs) - return op(*args, **kwargs) - - @register(registry) - def run_op(self, op_name, args, kwargs): - self._run_op(op_name, args, kwargs) - - @register(registry) - def get_op_output_shape(self, op_name, args, kwargs): - return self._run_op(op_name, args, kwargs).size() - - @register(registry) - def send_data(self, *args): - assert len(args) == 1 - return OpenRegTensorData.from_meta(self.allocator, args[0]) - - @register(registry) - def recv_data(self, host_tensor, dev_mem): - dev_tensor = OpenRegTensorData.from_meta(self.allocator, dev_mem) - dev_tensor.copy_(host_tensor) - - @register(registry) - def getNewStream(self, priority): - self.stream += 1 - return self.stream - - @register(registry) - def queryStream(self, stream): - return True - - @register(registry) - def synchronizeStream(self, stream): - # no-op - pass - - @register(registry) - def eventCreateWithFlags(self, flags): - self.event_incr_id += 1 - self.events[self.event_incr_id] = [flags, None] - return self.event_incr_id - - @register(registry) - def eventRecord(self, event, stream): - # Only flags == 1 enables timing - if self.events[event][0] == 1: - self.events[event][1] = time.time() * 1000 - return 0 - - @register(registry) - def eventDestroy(self, event): - self.events.pop(event) - - @register(registry) - def eventSynchronize(self, event): - assert self.events.get(event) is not None - return 0 - - @register(registry) - def eventQuery(self, event): - assert self.events.get(event) is not None - return True - - @register(registry) - def eventElapsedTime(self, e1, e2): - time_1 = self.events[e1][1] - time_2 = self.events[e2][1] - assert time_1 is not None and time_2 is not None - return time_2 - time_1 - - @register(registry) - def block(self, event, stream): - # no-op - pass - - -driver = Driver(NUM_DEVICES) diff --git a/test/cpp_extensions/open_registration_extension/pytorch_openreg/_meta_parser.py b/test/cpp_extensions/open_registration_extension/pytorch_openreg/_meta_parser.py deleted file mode 100644 index 0f54f2ec4df00..0000000000000 --- a/test/cpp_extensions/open_registration_extension/pytorch_openreg/_meta_parser.py +++ /dev/null @@ -1,103 +0,0 @@ -import pprint - -import torch -from torch.utils._pytree import tree_map, tree_map_only - - -class OpenRegTensorMeta: - def __init__(self, tensor, checked=True): - if checked and not tensor.device.type == "openreg": - raise RuntimeError( - "Creating OpenRegTensorMeta is only for Tensors on openreg device" - ) - self.data_ptr = tensor.untyped_storage().data_ptr() - self.size = tensor.size() - self.stride = tensor.stride() - self.storage_offset = tensor.storage_offset() - self.dtype = tensor.dtype - self.nelem_in_bytes = tensor.nelement() * tensor.element_size() - - def __repr__(self): - return ( - f"OpenRegTensorMeta({self.data_ptr=}, {self.size=}, {self.stride=}, " - f"{self.storage_offset=}, {self.dtype=}, {self.nelem_in_bytes=})" - ) - - -class OpenRegTensorData(torch.Tensor): - @staticmethod - def from_meta(allocator, tensor_meta): - return OpenRegTensorData(allocator.tensor_from_meta(tensor_meta)) - - -VALID_QUEUE_TYPES_IN = {torch.Tensor, int, float} - -VALID_QUEUE_TYPES_OUT = {OpenRegTensorMeta, int, float, str} - - -def safe_str(args): - def convert(obj): - if isinstance(obj, torch.Tensor): - return str(OpenRegTensorMeta(obj, checked=False)) - else: - return obj - - new_args = tree_map(convert, args) - return pprint.pformat(new_args) - - -def validate_send_queue_args(cmd, args): - def check(obj): - if type(obj) not in VALID_QUEUE_TYPES_OUT: - if ( - cmd == "recv_data" - and type(obj) in [torch.Tensor, OpenRegTensorData] - and obj.device.type == "cpu" - ): - # Only HtoD copy command can send cpu Tensors over - return - raise RuntimeError( - f"Trying to send invalid object through queue: {type(obj)}" - ) - - tree_map(check, args) - - -def prepare_for_sending(args, kwargs): - def convert(obj): - if type(obj) not in VALID_QUEUE_TYPES_IN: - raise RuntimeError( - f"Cannot send object of type {type(obj)} over openreg device pipe." - ) - - if isinstance(obj, torch.Tensor): - return OpenRegTensorMeta(obj) - else: - return obj - - return tree_map(convert, (args, kwargs)) - - -def receive_after_sending(allocator, args, kwargs): - def convert(obj): - if type(obj) not in VALID_QUEUE_TYPES_OUT: - raise RuntimeError( - f"Received invalid object of type {type(obj)} over openreg device pipe." - ) - - if isinstance(obj, OpenRegTensorMeta): - return allocator.tensor_from_meta(obj) - else: - return obj - - return tree_map(convert, (args, kwargs)) - - -def to_device_no_copy(device, args, kwargs): - def safe_to(t): - if device == "meta": - return t.to(device=device) - else: - return torch.empty_like(t, device=device) - - return tree_map_only(torch.Tensor, safe_to, (args, kwargs)) diff --git a/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/Module.cpp b/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/Module.cpp deleted file mode 100644 index 4580629454b76..0000000000000 --- a/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/Module.cpp +++ /dev/null @@ -1,51 +0,0 @@ -#include "OpenReg.h" - -#include - -#include -#include -#include -#include - -static PyObject* _initExtension(PyObject* self, PyObject* noargs) { - HANDLE_TH_ERRORS - - at::globalContext().lazyInitDevice(c10::DeviceType::PrivateUse1); - - Py_RETURN_NONE; - END_HANDLE_TH_ERRORS -} - -static PyObject* _getDefaultGenerator(PyObject* self, PyObject* arg) { - HANDLE_TH_ERRORS - TORCH_CHECK( - THPUtils_checkLong(arg), - "_get_default_generator expects an int, but got ", - THPUtils_typename(arg)); - auto idx = static_cast(THPUtils_unpackLong(arg)); - - return THPGenerator_initDefaultGenerator( - at::globalContext().defaultGenerator( - c10::Device(c10::DeviceType::PrivateUse1, idx))); - - END_HANDLE_TH_ERRORS -} - -static PyMethodDef methods[] = { - {"_init", _initExtension, METH_NOARGS, nullptr}, - {"_get_default_generator", _getDefaultGenerator, METH_O, nullptr}, - {nullptr, nullptr, 0, nullptr} -}; - -static struct PyModuleDef openreg_C_module = - {PyModuleDef_HEAD_INIT, "pytorch_openreg._C", nullptr, -1, methods}; - -PyMODINIT_FUNC PyInit__C(void) { - PyObject* mod = PyModule_Create(&openreg_C_module); - - py::object openreg_mod = py::module_::import("pytorch_openreg"); - // Only borrowed from the python side! - openreg::set_impl_factory(openreg_mod.attr("impl_factory").ptr()); - - return mod; -} diff --git a/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenReg.h b/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenReg.h deleted file mode 100644 index a04248f2e5029..0000000000000 --- a/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenReg.h +++ /dev/null @@ -1,50 +0,0 @@ -#pragma once - -#include - -namespace openreg { - -using openreg_ptr_t = uint64_t; - -void set_impl_factory(PyObject* factory); -py::function get_method(const char* name); - -static constexpr char kFreeMethod[] = "free"; -static constexpr char kHostFreeMethod[] = "hostFree"; - -template -static void ReportAndDelete(void* ptr) { - if (!ptr || !Py_IsInitialized()) { - return; - } - - py::gil_scoped_acquire acquire; - - PyObject *type = nullptr, *value = nullptr, *traceback = nullptr; - // Always stash, this will be a no-op if there is no error - PyErr_Fetch(&type, &value, &traceback); - - TORCH_CHECK( - get_method(name)(reinterpret_cast(ptr)).cast(), - "Failed to free memory pointer at ", - ptr); - - // If that user code raised an error, just print it without raising it - if (PyErr_Occurred()) { - PyErr_Print(); - } - - // Restore the original error - PyErr_Restore(type, value, traceback); -} - -#define REGISTER_PRIVATEUSE1_SERIALIZATION( \ - FOR_SERIALIZATION, FOR_DESERIALIZATION) \ - static int register_serialization() { \ - torch::jit::TensorBackendMetaRegistry( \ - c10::DeviceType::PrivateUse1, FOR_SERIALIZATION, FOR_DESERIALIZATION); \ - return 0; \ - } \ - static const int _temp = register_serialization(); - -} // namespace openreg diff --git a/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenRegHooks.cpp b/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenRegHooks.cpp deleted file mode 100644 index a87b378fb95c8..0000000000000 --- a/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenRegHooks.cpp +++ /dev/null @@ -1,350 +0,0 @@ -#include "OpenReg.h" - -#include -#include -#include - -#include -#include -#include - -namespace openreg { -namespace { - -// Python factory function where real implementations can be found -PyObject* py_factory; - -struct HostAllocator final : at::Allocator { - HostAllocator() = default; - - at::DataPtr allocate(size_t nbytes) override { - py::gil_scoped_acquire acquire; - void* data = nullptr; - if (nbytes > 0) { - data = reinterpret_cast( - get_method("hostMalloc")(nbytes).cast()); - TORCH_CHECK(data, "Failed to allocator ", nbytes, " bytes on host."); - } - return {data, data, &ReportAndDelete, at::Device(at::kCPU)}; - } - - at::DeleterFnPtr raw_deleter() const override { - return &ReportAndDelete; - } - - void copy_data(void* dest, const void* src, std::size_t count) const final { - py::gil_scoped_acquire acquire; - get_method("hostCopyData")( - reinterpret_cast(dest), - reinterpret_cast(src), - count); - } -}; - -static HostAllocator global_host_alloc; - -static c10::DeviceIndex device_count() { - py::gil_scoped_acquire acquire; - return get_method("deviceCount")().cast(); -} - -static c10::DeviceIndex current_device_idx() { - py::gil_scoped_acquire acquire; - return get_method("getDevice")().cast(); -} - -class OpenRegGeneratorImpl : public at::CPUGeneratorImpl { - public: - OpenRegGeneratorImpl(c10::DeviceIndex device_index) { - device_ = c10::Device(c10::DeviceType::PrivateUse1, device_index); - key_set_ = c10::DispatchKeySet(c10::DispatchKey::PrivateUse1); - } - ~OpenRegGeneratorImpl() override = default; -}; - -static at::Generator make_openreg_generator(c10::DeviceIndex device_index) { - return at::make_generator(device_index); -} - -// Default, global generators, one per device. -static std::vector default_generators; - -struct OpenRegHooksInterface : public at::PrivateUse1HooksInterface { - OpenRegHooksInterface() {}; - ~OpenRegHooksInterface() override = default; - - bool hasPrimaryContext(c10::DeviceIndex device_index) const override { - py::gil_scoped_acquire acquire; - return get_method("hasPrimaryContext")(device_index).cast(); - } - - at::Allocator* getPinnedMemoryAllocator() const override { - return &global_host_alloc; - } - - bool isPinnedPtr(const void* data) const override { - py::gil_scoped_acquire acquire; - return get_method("isPinnedPtr")(reinterpret_cast(data)) - .cast(); - } - - const at::Generator& getDefaultGenerator( - c10::DeviceIndex device_index) const override { - static bool flag [[maybe_unused]] = []() { - auto deivce_nums = device_count(); - default_generators.resize(deivce_nums); - for (auto i = 0; i < deivce_nums; i++) { - default_generators[i] = make_openreg_generator(i); - default_generators[i].seed(); - } - return true; - }(); - - c10::DeviceIndex idx = device_index; - if (idx == -1) { - idx = current_device_idx(); - } else { - TORCH_CHECK(idx >= 0 && idx < device_count()); - } - return default_generators[idx]; - } - - at::Generator getNewGenerator(c10::DeviceIndex device_index) const override { - return make_openreg_generator(device_index); - } -}; - -static bool register_hook_flag [[maybe_unused]] = []() { - at::RegisterPrivateUse1HooksInterface(new OpenRegHooksInterface()); - - return true; -}(); - -// Device guard registration -struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface { - static constexpr c10::DeviceType static_type = c10::DeviceType::PrivateUse1; - - OpenRegGuardImpl() = default; - explicit OpenRegGuardImpl(c10::DeviceType t) { - TORCH_INTERNAL_ASSERT(t == static_type); - } - - /** - * Return the type of device managed by this guard implementation. - */ - c10::DeviceType type() const override { - return static_type; - } - - /** - * Set the current device to Device, and return the previous c10::Device. - */ - c10::Device exchangeDevice(c10::Device d) const override { - TORCH_INTERNAL_ASSERT(d.is_privateuseone()); - py::gil_scoped_acquire acquire; - auto old_device_index = - get_method("exchangeDevice")(d.index()).cast(); - return c10::Device(static_type, old_device_index); - } - - /** - * Get the current device. - */ - c10::Device getDevice() const override { - return c10::Device(static_type, current_device_idx()); - } - - /** - * Set the current device to c10::Device. - */ - void setDevice(c10::Device d) const override { - TORCH_INTERNAL_ASSERT(d.is_privateuseone()); - py::gil_scoped_acquire acquire; - auto device = get_method("setDevice")(d.index()); - } - - /** - * Set the current device to c10::Device, without checking for errors - * (so, e.g., this can be called from a destructor). - */ - void uncheckedSetDevice(c10::Device d) const noexcept override { - py::gil_scoped_acquire acquire; - auto device = get_method("uncheckedSetDevice")(d.index()); - } - - /** - * Get the current stream for a given device. - */ - c10::Stream getStream(c10::Device d) const noexcept override { - py::gil_scoped_acquire acquire; - auto stream_id = get_method("getStream")(d.index()).cast(); - return c10::Stream(c10::Stream::UNSAFE, d, stream_id); - } - - /** - * Get the default stream for a given device. - */ - c10::Stream getDefaultStream(c10::Device d) const override { - py::gil_scoped_acquire acquire; - return get_method("getDefaultStream")(d.index()).cast(); - } - - /** - * Get a stream from the global pool for a given device. - */ - c10::Stream getStreamFromGlobalPool( - c10::Device d, - bool isHighPriority = false) const override { - py::gil_scoped_acquire acquire; - return get_method("getStreamFromGlobalPool")(d.index(), isHighPriority) - .cast(); - } - - /** - * Return a new stream for a given device and priority. The stream will be - * copied and shared around, device backend should be able to correctly handle - * the lifetime of the stream. - */ - c10::Stream getNewStream(c10::Device d, int priority = 0) const override { - py::gil_scoped_acquire acquire; - auto stream_id = - get_method("getNewStream")(d.index(), priority).cast(); - return c10::Stream(c10::Stream::UNSAFE, d, stream_id); - } - - /** - * Set a stream to be the thread local current stream for its device. - * Return the previous stream for that device. You are NOT required - * to set the current device to match the device of this stream. - */ - c10::Stream exchangeStream(c10::Stream s) const noexcept override { - py::gil_scoped_acquire acquire; - auto stream_id = get_method("exchangeStream")(s).cast(); - return c10::Stream(c10::Stream::UNSAFE, s.device(), stream_id); - } - - /** - * Destroys the given event. - */ - void destroyEvent(void* event, const c10::DeviceIndex device_index) - const noexcept override { - py::gil_scoped_acquire acquire; - get_method("destroyEvent")((int64_t)event, device_index); - } - - /** - * Increments the event's version and enqueues a job with this version - * in the stream's work queue. When the stream process that job - * it notifies all streams waiting on / blocked by that version of the - * event to continue and marks that version as recorded. - * */ - void record( - void** event, - const c10::Stream& stream, - const c10::DeviceIndex device_index, - const c10::EventFlag flag) const override { - py::gil_scoped_acquire acquire; - get_method("record")((int64_t)event, stream, device_index, (int64_t)flag); - } - - /** - * Does nothing if the event has not been scheduled to be recorded. - * If the event was previously enqueued to be recorded, a command - * to wait for the version of the event that exists at the time of this call - * is inserted in the stream's work queue. - * When the stream reaches this command it will stop processing - * additional commands until that version of the event is marked as recorded. - */ - void block(void* event, const c10::Stream& stream) const override { - py::gil_scoped_acquire acquire; - get_method("block")((int64_t)event, stream); - } - - /** - * Returns true if (and only if) - * (1) the event has never been scheduled to be recorded - * (2) the current version is marked as recorded. - * Returns false otherwise. - */ - bool queryEvent(void* event) const override { - py::gil_scoped_acquire acquire; - return get_method("queryEvent")((int64_t)event).cast(); - } - - /** - * Get the number of devices. WARNING: This is REQUIRED to not raise - * an exception. If there is some sort of problem, e.g., driver error, - * you should report that there are zero available devices. - */ - c10::DeviceIndex deviceCount() const noexcept override { - return device_count(); - } - /** - * Return true if all the work previously enqueued on the stream for - * asynchronous execution has completed running on the device. - */ - bool queryStream(const c10::Stream& stream) const override { - py::gil_scoped_acquire acquire; - return get_method("queryStream")(stream).cast(); - } - - /** - * Wait (by blocking the calling thread) until all the work previously - * enqueued on the stream has completed running on the device. - */ - virtual void synchronizeStream(const c10::Stream& stream) const override { - py::gil_scoped_acquire acquire; - get_method("synchronizeStream")(stream); - } - - /** - * Wait (by blocking the calling thread) until all the work previously - * recorded on the event has completed running on the device. - */ - void synchronizeEvent(void* event) const override { - py::gil_scoped_acquire acquire; - get_method("synchronizeEvent")((int64_t)event); - } - - /** - * Ensure the caching allocator (if any) is aware that the given DataPtr is - * being used on the given stream, and that it should thus avoid recycling the - * DataPtr until all work on that stream is done. - */ - void recordDataPtrOnStream( - const c10::DataPtr& data_ptr, - const c10::Stream& stream) const override { - py::gil_scoped_acquire acquire; - get_method("recordDataPtrOnStream")(data_ptr, stream); - } - - /** - * Fetch the elapsed time between two recorded events. - */ - double elapsedTime( - void* event1, - void* event2, - const c10::DeviceIndex device_index) const override { - py::gil_scoped_acquire acquire; - return get_method("elapsedTime")( - (int64_t)event1, (int64_t)event2, device_index) - .cast(); - } -}; - -// Register our device guard -C10_REGISTER_GUARD_IMPL(PrivateUse1, OpenRegGuardImpl); - -} // namespace - -// Setter for the python dictionary with implementations -void set_impl_factory(PyObject* factory) { - py_factory = factory; -} - -py::function get_method(const char* name) { - auto factory = py::cast(py_factory); - return factory(name); -} - -} // namespace openreg diff --git a/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenRegMem.cpp b/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenRegMem.cpp deleted file mode 100644 index 4d9bde0601183..0000000000000 --- a/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenRegMem.cpp +++ /dev/null @@ -1,418 +0,0 @@ -#include "OpenReg.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -#include -#include -#include - -#include - -namespace openreg { -namespace { - -struct OpenRegAllocator final : at::Allocator { - OpenRegAllocator() = default; - - at::DataPtr allocate(size_t nbytes) override { - py::gil_scoped_acquire acquire; - auto curr_device_idx = get_method("getDevice")().cast(); - auto curr_device = - c10::Device(c10::DeviceType::PrivateUse1, curr_device_idx); - void* data = nullptr; - if (nbytes > 0) { - data = reinterpret_cast( - get_method("malloc")(nbytes).cast()); - TORCH_CHECK( - data, "Failed to allocator ", nbytes, " bytes on openreg device."); - } - return {data, data, &ReportAndDelete, curr_device}; - } - - at::DeleterFnPtr raw_deleter() const override { - return &ReportAndDelete; - } - - void copy_data(void* dest, const void* src, std::size_t count) const final { - py::gil_scoped_acquire acquire; - get_method("copy_data")( - reinterpret_cast(dest), - reinterpret_cast(src), - count); - } -}; - -static OpenRegAllocator global_openreg_alloc; -REGISTER_ALLOCATOR(c10::DeviceType::PrivateUse1, &global_openreg_alloc); - -// Empty op needs C++ code and cannot be handled by python side fallback -at::Tensor empty_openreg( - c10::IntArrayRef size, - std::optional dtype_opt, - std::optional layout_opt, - std::optional device_opt, - std::optional pin_memory_opt, - std::optional memory_format_opt) { - const auto device = c10::device_or_default(device_opt); - const auto dtype = c10::dtype_or_default(dtype_opt); - TORCH_CHECK(device.is_privateuseone()); - TORCH_CHECK( - c10::layout_or_default(layout_opt) == c10::Layout::Strided, - "Non strided layout not supported"); - TORCH_CHECK( - !c10::pinned_memory_or_default(pin_memory_opt), - "Pin memory can only be on CPU"); - const c10::DeviceGuard device_guard(device); - constexpr c10::DispatchKeySet pu1_dks(c10::DispatchKey::PrivateUse1); - return at::detail::empty_generic( - size, &global_openreg_alloc, pu1_dks, dtype, memory_format_opt); -} - -at::Tensor empty_strided_openreg( - c10::IntArrayRef size, - c10::IntArrayRef stride, - std::optional dtype_opt, - std::optional layout_opt, - std::optional device_opt, - std::optional pin_memory_opt) { - const auto device = c10::device_or_default(device_opt); - const auto dtype = c10::dtype_or_default(dtype_opt); - TORCH_CHECK(device.is_privateuseone()); - TORCH_CHECK( - c10::layout_or_default(layout_opt) == c10::Layout::Strided, - "Non strided layout not supported"); - TORCH_CHECK( - !c10::pinned_memory_or_default(pin_memory_opt), - "Pin memory can only be on CPU"); - const c10::DeviceGuard device_guard(device); - constexpr c10::DispatchKeySet pu1_dks(c10::DispatchKey::PrivateUse1); - return at::detail::empty_strided_generic( - size, stride, &global_openreg_alloc, pu1_dks, dtype); -} - -at::Tensor as_strided_openreg( - const at::Tensor& self, - c10::IntArrayRef size, - c10::IntArrayRef stride, - std::optional storage_offset_) { - // Metadata-only change so we re-use the cpu impl - return at::cpu::as_strided(self, size, stride, storage_offset_); -} - -const at::Tensor& resize__openreg( - const at::Tensor& self, - c10::SymIntArrayRef size, - ::std::optional memory_format) { - return at::native::resize_( - self, C10_AS_INTARRAYREF_SLOW(size), memory_format); -} - -at::Tensor& set_source_Storage_storage_offsetset_openreg( - at::Tensor& result, - at::Storage storage, - int64_t storage_offset, - c10::IntArrayRef size, - c10::IntArrayRef stride) { - return at::cpu::set_(result, storage, storage_offset, size, stride); -} - -std::tuple -custom_scaled_dot_product_fused_attention_overrideable( - const at::Tensor & query, - const at::Tensor & key, - const at::Tensor & value, - const std::optional & attn_bias, - double dropout_p, - bool is_causal, - bool return_debug_mask, - std::optional scale) { - const int64_t batch_size = query.size(0); - const int64_t num_heads = query.size(1); - const int64_t head_dim_v = value.size(3); - const int64_t max_seqlen_q = query.size(2); - const int64_t max_seqlen_kv = key.size(2); - - auto opts = query.options(); - auto output = at::empty({batch_size, num_heads, max_seqlen_q, head_dim_v}, opts); - auto logsumexp = at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); - auto debug_attn_mask = at::empty({batch_size, num_heads, max_seqlen_q, max_seqlen_kv}, - opts.dtype(at::kFloat)); - auto philox_seed = at::empty({}, at::dtype(at::kLong)); - auto philox_offset = at::empty({}, at::dtype(at::kLong)); - - return std::make_tuple(output, logsumexp, at::Tensor(), at::Tensor(), max_seqlen_q, max_seqlen_kv, philox_seed, philox_offset, debug_attn_mask); -} - -std::tuple -custom_scaled_dot_product_fused_attention_overrideable_backward( - const at::Tensor & grad_out, - const at::Tensor & query, - const at::Tensor & key, - const at::Tensor & value, - const at::Tensor & attn_bias, - std::array grad_input_mask, - const at::Tensor & out, - const at::Tensor & logsumexp, - const at::Tensor & cum_seq_q, - const at::Tensor & cum_seq_k, - int64_t max_q, - int64_t max_k, - double dropout_p, - bool is_causal, - const at::Tensor & philox_seed, - const at::Tensor & philox_offset, - std::optional scale) { - return std::tuple( - at::empty_like(query), - at::empty_like(key), - at::empty_like(value), - at::empty_like(attn_bias)); -} -} - -// Using the simplest way to obtain continuous Tensor data and process it. -// This is a demo for using operand API, and you can add more complex logic -// for input and output tensor based on your custom device kernel. -void abs_kernel(at::TensorIteratorBase& iter) { - // Abs only have a input tensor and a output tensor. - auto& output_operand = iter.operand(0); - auto& input_operand = iter.operand(1); - auto& output_tensor_base = output_operand.tensor_base(); - auto& input_tensor_base = input_operand.tensor_base(); - TORCH_CHECK(!input_operand.original_tensor_base().defined(), - "input original tensor is defined."); - TORCH_CHECK(!output_operand.original_tensor_base().defined(), - "output original tensor is defined."); - // For easy test, only accept contiguous input tensor for calculate. - auto memory_format = input_tensor_base.suggest_memory_format(); - TORCH_CHECK(input_tensor_base.is_contiguous(memory_format), - "Input tensor need be contiguous."); - // Add necessary restrictions to ensure the security of the demo. - TORCH_CHECK(input_tensor_base.sizes() == output_tensor_base.sizes(), - "Intput and output tensor size are not equal."); - // Common dtype is calculate in TensorIteratorBase. - TORCH_CHECK(iter.common_dtype() == at::ScalarType::Float, - "Only support float type.") - // Using for loop for abs calculate. - auto abs_function = [](float* output_ptr, const float* input_ptr, - const int64_t NUM) { - for (int64_t i = 0; i < NUM; ++i) { - *(output_ptr + i) = std::abs(*(input_ptr + i)); - } - }; - // To simplify the logic of the test demo code, - // we only use contiguous tensor to calculate on device side. - // And using input tensor memory format. - if (iter.is_contiguous()) { - // Add for will_resize flag check. You can convert to differernt - // tensor memory format when will_resize is True. - // If TensorIteratorConfig resize_outputs_ flag is true, and there are two - // situations: - // 1) Out tensor is undefined, and TensorIterator set will_resize to true; - // 2) Out tensor is defined and tensor size is not equal to input tensor size; - // TensorIterator set will_resize to true, and call set_output_raw_strided - // to resize output tensor. - // When output operand will_resize flag is ture, dummy - // device can convert tensor to dummy device preferred memory format. - // Here we don't convert tensor memory format, because it will become complex - // when dummy device want keep same memory format for training network. - TORCH_CHECK(output_operand.will_resize, - "output operand will_resize flag need be True."); - abs_function((float*)iter.data_ptr(0), (float*)iter.data_ptr(1), iter.numel()); - } else { - // Stride copy is not support for foo device, using cpu device instead. - // For abs op, the last situation is: output tensor is not contiguous with - // operand will_resize is False. - TORCH_CHECK(!output_operand.will_resize, "output operand will_resize is True."); - // Get a contiguous tensor with input memory format. - at::Tensor output = at::empty(output_tensor_base.sizes(), - input_tensor_base.options() - .memory_format(memory_format)); - // For structured op which inheried from TensorIteratorBase, maybe you need to - // call set_output_raw_strided function to update output stored in op sturctured. - // abs op is no need to do this. - output_operand.exchange_tensor(c10::MaybeOwned::owned(std::in_place, output)); - abs_function((float*)output_operand.tensor_base().mutable_data_ptr(), - (float*)iter.data_ptr(1), iter.numel()); - // Copy tensor base to original tensor base, and keep same scalar type and - // stride with cpu and gpu. - if (output_operand.original_tensor_base().defined() && - !output_operand.original_tensor_base().is_same(output_operand.tensor_base())) { - output_operand.original_tensor().copy_(output_operand.tensor()); - output_operand.restore_original_tensor(); - } - } -} - -int64_t _fused_sdp_choice_privateuse1( - const at::Tensor& query, - const at::Tensor& key, - const at::Tensor& value, - const std::optional& attn_mask, - double dropout_p, - bool is_causal, - std::optional scale, - bool enable_gqa) { - auto backend = sdp::SDPBackend::overrideable; - return static_cast(backend); -} - -void quantize_tensor_per_tensor_affine_privateuse1( - const at::Tensor& rtensor, - at::Tensor& qtensor, - double scale, - int64_t zero_point) { - // Just test the process, so do nothing -} - -struct CustomAutogradFnReturnsSelf - : public torch::autograd::Function { - static at::Tensor forward( - torch::autograd::AutogradContext* ctx, - at::Tensor self) { - return self; - } - - static torch::autograd::variable_list backward( - torch::autograd::AutogradContext* ctx, - torch::autograd::variable_list grad_output) { - return {grad_output[0] * 0.5}; - } -}; - -struct CustomAutogradFnAliasing - : public torch::autograd::Function { - static at::Tensor forward( - torch::autograd::AutogradContext* ctx, - at::Tensor self) { - return self.view_symint(self.sym_sizes()); - } - - static torch::autograd::variable_list backward( - torch::autograd::AutogradContext* ctx, - torch::autograd::variable_list grad_output) { - return {grad_output[0] * 0.5}; - } -}; - -at::Tensor custom_autograd_fn_returns_self(at::Tensor x) { - return CustomAutogradFnReturnsSelf::apply(x); -} - -at::Tensor custom_autograd_fn_aliasing(at::Tensor x) { - return CustomAutogradFnAliasing::apply(x); -} - -/* Notes: - * - * OpenReg is currently designed to simulate device memory through multiple - * subprocesses on purpose to ensure we don't mistakenly poke at the "device's - * memory" from the main process. And be able to simulate the same thing that - * happens with other accelerators: any metadata-only change is cpu-only - * (main process), any data change must go through to the device (other process) - * and any data transfer between the two is expensive (serializing the whole - * Tensor). - * - * Currently, for the efficiency of IPC, most operations are to pass the Tensor - * metadata, and only a small number of operations involving copy will serialize - * and pass the Tensor body by custom pickler provided by torch.multiprocess. - * - * Therefore, in principle, only operations related to Metadata modification can - * be directly implemented at the C++ level and registered in PrivateUse1; but - * if memory access is involved, the relevant operations must be implemented at - * the Python level, otherwise invalid memory access will result. - */ - -TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) { - m.impl("empty.memory_format", empty_openreg); - m.impl("empty_strided", empty_strided_openreg); - m.impl("as_strided", as_strided_openreg); - m.impl("resize_", resize__openreg); - m.impl("set_.source_Storage", at::native::set_); - m.impl("set_.source_Storage_storage_offset", set_source_Storage_storage_offsetset_openreg); - m.impl("quantize_per_tensor", at::native::quantize_per_tensor); - m.impl("_fused_sdp_choice", &_fused_sdp_choice_privateuse1); - m.impl("_scaled_dot_product_fused_attention_overrideable", &custom_scaled_dot_product_fused_attention_overrideable); - m.impl("_scaled_dot_product_fused_attention_overrideable_backward", &custom_scaled_dot_product_fused_attention_overrideable_backward); -} - -struct OpenRegBackendMeta : public c10::BackendMeta { - OpenRegBackendMeta(int version_number, int format_number) - : version_number_(version_number), format_number_(format_number) {} - - int version_number_{-1}; - int format_number_{-1}; -}; - -void for_serialization( - const at::Tensor& t, - std::unordered_map& m) { - auto meta_ptr = t.unsafeGetTensorImpl()->get_backend_meta(); - - if (meta_ptr != nullptr) { - auto o_meta_ptr = dynamic_cast(meta_ptr); - if (o_meta_ptr->version_number_ == 1) { - m["version_number"] = true; - } - if (o_meta_ptr->format_number_ == 29) { - m["format_number"] = true; - } - } -} - -void for_deserialization( - const at::Tensor& t, - std::unordered_map& m) { - int version_number{-1}; - int format_number{-1}; - - if (m.find("version_number") != m.end()) { - version_number = 1; - } - if (m.find("format_number") != m.end()) { - format_number = 29; - } - - c10::intrusive_ptr meta{std::unique_ptr( - new OpenRegBackendMeta(version_number, format_number))}; - t.unsafeGetTensorImpl()->set_backend_meta(meta); -} - -REGISTER_PRIVATEUSE1_SERIALIZATION(&for_serialization, &for_deserialization) -} // namespace openreg - -namespace at::native { -REGISTER_PRIVATEUSE1_DISPATCH(abs_stub, &openreg::abs_kernel); -REGISTER_PRIVATEUSE1_DISPATCH( - quantize_tensor_per_tensor_affine_stub, - &openreg::quantize_tensor_per_tensor_affine_privateuse1); -REGISTER_PRIVATEUSE1_DISPATCH( - _fused_sdp_choice_stub, - &openreg::_fused_sdp_choice_privateuse1); -} // namespace at::native - -TORCH_LIBRARY(openreg, m) { - m.def("custom_autograd_fn_returns_self(Tensor input)-> Tensor"); - m.def("custom_autograd_fn_aliasing(Tensor(a) input)-> Tensor(a)"); -} - -TORCH_LIBRARY_IMPL(openreg, AutogradPrivateUse1, m) { - m.impl("custom_autograd_fn_aliasing", &openreg::custom_autograd_fn_aliasing); - m.impl( - "custom_autograd_fn_returns_self", - &openreg::custom_autograd_fn_returns_self); -} diff --git a/test/cpp_extensions/open_registration_extension/setup.py b/test/cpp_extensions/open_registration_extension/setup.py deleted file mode 100644 index fa8c1308c6c52..0000000000000 --- a/test/cpp_extensions/open_registration_extension/setup.py +++ /dev/null @@ -1,78 +0,0 @@ -import distutils.command.clean -import os -import platform -import shutil -import sys -from pathlib import Path - -from setuptools import find_packages, setup - -from torch.utils.cpp_extension import BuildExtension, CppExtension - - -PACKAGE_NAME = "pytorch_openreg" -version = 1.0 - -ROOT_DIR = Path(__file__).absolute().parent -CSRS_DIR = ROOT_DIR / "pytorch_openreg/csrc" - - -class clean(distutils.command.clean.clean): - def run(self): - # Run default behavior first - distutils.command.clean.clean.run(self) - - # Remove pytorch_openreg extension - for path in (ROOT_DIR / "pytorch_openreg").glob("**/*.so"): - path.unlink() - # Remove build directory - build_dirs = [ - ROOT_DIR / "build", - ] - for path in build_dirs: - if path.exists(): - shutil.rmtree(str(path), ignore_errors=True) - - -if __name__ == "__main__": - if sys.platform == "win32": - vc_version = os.getenv("VCToolsVersion", "") - if vc_version.startswith("14.16."): - CXX_FLAGS = ["/sdl"] - else: - CXX_FLAGS = ["/sdl", "/permissive-"] - elif platform.machine() == "s390x": - # no -Werror on s390x due to newer compiler - CXX_FLAGS = {"cxx": ["-g", "-Wall"]} - else: - CXX_FLAGS = {"cxx": ["-g", "-Wall", "-Werror"]} - - sources = list(CSRS_DIR.glob("*.cpp")) - - # Note that we always compile with debug info - ext_modules = [ - CppExtension( - name="pytorch_openreg._C", - sources=sorted(str(s) for s in sources), - include_dirs=[CSRS_DIR], - extra_compile_args=CXX_FLAGS, - ) - ] - - setup( - name=PACKAGE_NAME, - version=version, - author="PyTorch Core Team", - description="Example for PyTorch out of tree registration", - packages=find_packages(exclude=("test",)), - package_data={PACKAGE_NAME: ["*.dll", "*.dylib", "*.so"]}, - install_requires=[ - "torch", - ], - ext_modules=ext_modules, - python_requires=">=3.8", - cmdclass={ - "build_ext": BuildExtension.with_options(no_python_abi_suffix=True), - "clean": clean, - }, - ) diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/CMakeLists.txt b/test/cpp_extensions/open_registration_extension/torch_openreg/CMakeLists.txt new file mode 100644 index 0000000000000..73163b8cb1ae8 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/CMakeLists.txt @@ -0,0 +1,38 @@ +cmake_minimum_required(VERSION 3.18 FATAL_ERROR) + +project(TORCH_OPENREG CXX C) + +include(GNUInstallDirs) +include(CheckCXXCompilerFlag) +include(CMakeDependentOption) + +set(CMAKE_SKIP_BUILD_RPATH FALSE) +set(CMAKE_BUILD_WITH_INSTALL_RPATH TRUE) +set(CMAKE_INSTALL_RPATH_USE_LINK_PATH FALSE) +set(CMAKE_INSTALL_RPATH "$ORIGIN/lib/:$ORIGIN/") + +set(LINUX TRUE) +set(CMAKE_INSTALL_MESSAGE NEVER) +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_C_STANDARD 11) +set(CMAKE_CXX_EXTENSIONS OFF) + +set(CMAKE_INSTALL_LIBDIR lib) + +add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=1) + +set(Torch_DIR ${PYTORCH_INSTALL_DIR}/share/cmake/Torch) +find_package(Torch REQUIRED) +include_directories(${PYTORCH_INSTALL_DIR}/include) + +if(DEFINED PYTHON_INCLUDE_DIR) + include_directories(${PYTHON_INCLUDE_DIR}) +else() + message(FATAL_ERROR "Cannot find Python directory") +endif() + +add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/openreg) +add_subdirectory(${PROJECT_SOURCE_DIR}/csrc) +add_subdirectory(${PROJECT_SOURCE_DIR}/torch_openreg/csrc) diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/README.md b/test/cpp_extensions/open_registration_extension/torch_openreg/README.md new file mode 100644 index 0000000000000..e59013cea4407 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/README.md @@ -0,0 +1,177 @@ +# PyTorch OpenReg + +## Background + +The third-party device integration mechanism based on PrivateUse1 has become the official mainstream method for new backends to integrate with PyTorch. Ensuring the availability of this mechanism is crucial for enriching PyTorch's hardware ecosystem. + +**Note:** + +The goal of `torch_openreg` is **not to implement a fully functional, high-performance PyTorch backend**, but to serve as a **minimalist reference implementation for mechanism verification**. + +### Purpose + +- **Test Backend**: To serve as an in-tree test backend for PrivateUse1, ensuring quality stability through CI/CD. +- **Integration Example**: To serve as a reference example for new backend integration. +- **Integration Documentation**: To provide module-level integration documentation that corresponds with the code. + +### Design Principles + +- **Minimality Principle**: The fundamental goal is to enable/verify all integration paths/mechanisms for a new backend to integrate to PyTorch. All functions follow a "just right" strategy to ensure the correctness of relevant integration capabilities. +- **Authenticity Principle**: To complete the OpenReg integration in the same way a real accelerator backend would integrate with PyTorch. + +## Directory Structure + +```shell +torch_openreg/ +├── CMakeLists.txt +├── csrc +│ ├── aten +│ │ ├── native +│ │ │ ├── Extra.cpp +│ │ │ ├── Minimal.cpp +│ │ │ └── ... +│ │ ├── OpenRegExtra.cpp +│ │ └── OpenRegMinimal.cpp +│ ├── CMakeLists.txt +│ └── runtime +│ ├── OpenRegDeviceAllocator.cpp +│ ├── OpenRegDeviceAllocator.h +│ ├── OpenRegFunctions.cpp +│ ├── OpenRegFunctions.h +│ ├── OpenRegGenerator.cpp +│ ├── OpenRegGenerator.h +│ ├── OpenRegGuard.cpp +│ ├── OpenRegGuard.h +│ ├── OpenRegHooks.cpp +│ ├── OpenRegHooks.h +│ ├── OpenRegHostAllocator.cpp +│ ├── OpenRegHostAllocator.h +│ └── ... +├── README.md +├── setup.py +├── third_party +│ └── openreg +└── torch_openreg + ├── csrc + │ ├── CMakeLists.txt + │ ├── Module.cpp + │ └── stub.c + ├── __init__.py + └── openreg + ├── __init__.py + └── random.py +``` + +**Dependencies**: + +```mermaid +graph LR + A[Python] + B[_C.so] + C[libtorch_bindings.so] + D[libtorch_openreg.so] + E[libopenreg.so] + + A --> B --> C --> D --> E +``` + +- `_C.so`: torch\_openreg/csrc/stub.c +- `libtorch_bindings.so`: torch\_openreg/csrc/\*.cpp +- `libtorch_openreg.so`: csrc +- `libopenreg.so`: third\_party/openreg + +**Key Directories**: + +- `csrc/`: Core device implementation, including operator registration, runtime, etc. + - `csrc/aten/`: Operator registration + - `csrc/aten/native/`: Specific operator implementations for the OpenReg device. + - `csrc/aten/OpenRegMinimal.cpp`: The most minimal set of operator implementations (allowing for the creation of Tensors and related operations upon completion). + - `csrc/aten/OpenRegExtra.cpp`: Implementations for other types of operators. + - `csrc/runtime/`: Implementations for Host memory, device memory, Guard, Hooks, etc. +- `third_party/`: A C++ library that simulates a CUDA-like device using the CPU. +- `torch_openreg/`: Python interface implementation (Python code and C++ Bindings). + - `torch_openreg/csrc/`: Python C++ binding code. + - `torch_openreg/openreg/`: Python API. + +## Currently Implemented Features + +### Operator Registration + +- Operator Implementation + + - `TORCH_LIBRARY` form + - Registering a specific operator for an existing schema: See `empty.memory_format` + - Registering an operator with a custom schema + - Extending an existing namespace: (TODO) + - Custom namespace: See `custom_autograd_fn_returns_self` + - Autograd: See `custom_autograd_fn_returns_self` + - STUB form: See `abs_stub` + + - Fallback + - Global Fallback: See `wrapper_cpu_fallback` + - Per-operator Fallback: (TODO) + + - AMP (TODO) + +### Memory Management + +- Device Memory Management (TODO) +- Host Memory Management (TODO) + +### Custom Storage + +- Adding custom device descriptions (TODO) +- Serialization support (TODO) + +### Autoload + +- (TODO) + +... + +## Installation and Usage + +### Installation + +```python +pip3 install -r requirements.txt + +python setup.py develop/install +``` + +### Usage Example + +After installation, you can use the `openreg` device in Python just like any other regular device. + +```python +import torch +import torch_openreg + +if not torch.openreg.is_available(): + print("OpenReg backend is not available in this build.") + exit() + +print("OpenReg backend is available!") + +device = torch.device("openreg") + +try: + x = torch.tensor([[1., 2.], [3., 4.]], device=device) + y = x + 2 + print("Result y:\n", y) + print(f"Device of y: {y.device}") + + z = y.cpu() + print("Result z:\n", z) + print(f"Device of z: {z.device}") + +except Exception as e: + print(f"\nAn error occurred: {e}") +``` + +## Future Plans + +- **Enhance Features**: AMP, memory management, generators, distributed computing, etc. (to reiterate, the fundamental goal is to verify the integration mechanism). +- **Improve Tests**: Add more test cases related to the integration mechanism. +- **Improve Documentation**: Add a new chapter on third-party device integration in the `Developer Notes` section of the PyTorch documentation. +- **Real-time Synchronization**: Keep the code and documentation updated iteratively and in sync. diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/CMakeLists.txt b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/CMakeLists.txt new file mode 100644 index 0000000000000..077f4cf3b6404 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/CMakeLists.txt @@ -0,0 +1,12 @@ +set(LIBRARY_NAME torch_openreg) + +file(GLOB_RECURSE SOURCE_FILES + "${CMAKE_CURRENT_SOURCE_DIR}/*.cpp" +) + +add_library(${LIBRARY_NAME} SHARED ${SOURCE_FILES}) + +target_link_libraries(${LIBRARY_NAME} PRIVATE openreg torch_cpu) +target_include_directories(${LIBRARY_NAME} PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) + +install(TARGETS ${LIBRARY_NAME} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}) diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/OpenRegExtra.cpp b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/OpenRegExtra.cpp new file mode 100644 index 0000000000000..3d8525697cc8c --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/OpenRegExtra.cpp @@ -0,0 +1,138 @@ +#include "native/Extra.h" + +#include +#include + +#include + +namespace at::openreg { + +at::Tensor wrapper_quantize_per_tensor( + const at::Tensor& self, + double scale, + int64_t zero_point, + at::ScalarType dtype) { + return at::native::quantize_per_tensor_openreg( + self, scale, zero_point, dtype); +} + +int64_t wrapper__fused_sdp_choice( + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const std::optional& attn_mask, + double dropout_p, + bool is_causal, + std::optional scale, + bool enable_gqa) { + return at::native::_fused_sdp_choice_openreg( + query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa); +} + +std::tuple< + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + c10::SymInt, + c10::SymInt, + at::Tensor, + at::Tensor, + at::Tensor> +wrapper__scaled_dot_product_fused_attention_overrideable( + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const std::optional& attn_bias, + double dropout_p, + bool is_causal, + bool return_debug_mask, + std::optional scale) { + return at::native::_scaled_dot_product_fused_attention_overrideable_openreg( + query, + key, + value, + attn_bias, + dropout_p, + is_causal, + return_debug_mask, + scale); +} + +std::tuple +wrapper_scaled_dot_product_fused_attention_overrideable_backward( + const at::Tensor& grad_out, + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const at::Tensor& attn_bias, + std::array grad_input_mask, + const at::Tensor& out, + const at::Tensor& logsumexp, + const at::Tensor& cum_seq_q, + const at::Tensor& cum_seq_k, + int64_t max_q, + int64_t max_k, + double dropout_p, + bool is_causal, + const at::Tensor& philox_seed, + const at::Tensor& philox_offset, + std::optional scale) { + return at::native:: + _scaled_dot_product_fused_attention_overrideable_backward_openreg( + grad_out, + query, + key, + value, + attn_bias, + grad_input_mask, + out, + logsumexp, + cum_seq_q, + cum_seq_k, + max_q, + max_k, + dropout_p, + is_causal, + philox_seed, + philox_offset, + scale); +} + +TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) { + m.impl("quantize_per_tensor", &wrapper_quantize_per_tensor); + m.impl("_fused_sdp_choice", &wrapper__fused_sdp_choice); + m.impl( + "_scaled_dot_product_fused_attention_overrideable", + &wrapper__scaled_dot_product_fused_attention_overrideable); + m.impl( + "_scaled_dot_product_fused_attention_overrideable_backward", + &wrapper_scaled_dot_product_fused_attention_overrideable_backward); +} + +} // namespace at::openreg + +namespace at::openreg { +TORCH_LIBRARY(openreg, m) { + m.def("custom_autograd_fn_returns_self(Tensor input)-> Tensor"); + m.def("custom_autograd_fn_aliasing(Tensor(a) input)-> Tensor(a)"); +} + +TORCH_LIBRARY_IMPL(openreg, AutogradPrivateUse1, m) { + m.impl( + "custom_autograd_fn_returns_self", + &at::native::custom_autograd_fn_returns_self); + m.impl( + "custom_autograd_fn_aliasing", &at::native::custom_autograd_fn_aliasing); +} +} // namespace at::openreg + +namespace at::native { +REGISTER_PRIVATEUSE1_DISPATCH(abs_stub, &abs_kernel_openreg); +REGISTER_PRIVATEUSE1_DISPATCH( + quantize_tensor_per_tensor_affine_stub, + &quantize_tensor_per_tensor_affine_stub_openreg); +REGISTER_PRIVATEUSE1_DISPATCH( + _fused_sdp_choice_stub, + &_fused_sdp_choice_openreg); +} // namespace at::native diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/OpenRegMinimal.cpp b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/OpenRegMinimal.cpp new file mode 100644 index 0000000000000..fe75cdaea8b2a --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/OpenRegMinimal.cpp @@ -0,0 +1,128 @@ +#include "native/Minimal.h" + +#include +#include + +#include + +namespace at::openreg { + +at::Tensor wrapper_empty_memory_format( + c10::IntArrayRef size, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt, + std::optional pin_memory_opt, + std::optional memory_format_opt) { + return at::native::empty_memory_format_openreg( + size, + dtype_opt, + layout_opt, + device_opt, + pin_memory_opt, + memory_format_opt); +} + +at::Tensor wrapper_empty_strided( + c10::IntArrayRef size, + c10::IntArrayRef stride, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt, + std::optional pin_memory_opt) { + return at::native::empty_strided_openreg( + size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt); +} + +at::Tensor wrapper_as_strided( + const at::Tensor& self, + c10::SymIntArrayRef size, + c10::SymIntArrayRef stride, + std::optional storage_offset) { + return at::native::as_strided_openreg(self, size, stride, storage_offset); +} + +const at::Tensor& wrapper_resize_( + const at::Tensor& self, + c10::SymIntArrayRef size, + ::std::optional memory_format) { + return at::native::resize_openreg_(self, size, memory_format); +} + +at::Tensor wrapper__reshape_alias( + const at::Tensor& self, + c10::SymIntArrayRef size, + c10::SymIntArrayRef stride) { + return at::native::_reshape_alias_openreg(self, size, stride); +} + +at::Tensor wrapper__copy_from( + const at::Tensor& self, + const at::Tensor& dst, + bool non_blocking) { + return at::native::_copy_from_openreg(self, dst, non_blocking); +} + +at::Tensor wrapper__copy_from_and_resize( + const at::Tensor& self, + const at::Tensor& dst) { + return at::native::_copy_from_and_resize_openreg(self, dst); +} + +at::Scalar wrapper__local_scalar_densor(const at::Tensor& self) { + return at::native::_local_scalar_dense_openreg(self); +} + +at::Tensor& wrapper_set_source_Tensor_( + at::Tensor& self, + const at::Tensor& source) { + return at::native::set_source_Tensor_openreg_(self, source); +} + +at::Tensor& wrapper_set_source_Storage_(at::Tensor& self, at::Storage source) { + return at::native::set_source_Storage_openreg_(self, source); +} + +at::Tensor& wrapper_set_source_Storage_storage_offsetset_( + at::Tensor& result, + at::Storage storage, + int64_t storage_offset, + c10::IntArrayRef size, + c10::IntArrayRef stride) { + return at::native::set_source_Storage_storage_offset_openreg_( + result, storage, storage_offset, size, stride); +} + +at::Tensor wrapper_view(const at::Tensor& self, c10::SymIntArrayRef size) { + return at::native::view_openreg(self, size); +} + +TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) { + m.impl("empty.memory_format", wrapper_empty_memory_format); + m.impl("empty_strided", wrapper_empty_strided); + m.impl("as_strided", wrapper_as_strided); + m.impl("resize_", wrapper_resize_); + m.impl("_reshape_alias", wrapper__reshape_alias); + m.impl("_copy_from", wrapper__copy_from); + m.impl("_copy_from_and_resize", wrapper__copy_from_and_resize); + m.impl("_local_scalar_dense", wrapper__local_scalar_densor); + m.impl("set_.source_Tensor", wrapper_set_source_Tensor_); + m.impl("set_.source_Storage", wrapper_set_source_Storage_); + m.impl( + "set_.source_Storage_storage_offset", + wrapper_set_source_Storage_storage_offsetset_); + m.impl("view", wrapper_view); +} + +void wrapper_cpu_fallback( + const c10::OperatorHandle& op, + torch::jit::Stack* stack) { + at::native::cpu_fallback_openreg(op, stack); +} + +TORCH_LIBRARY_IMPL(_, PrivateUse1, m) { + m.fallback( + torch::CppFunction::makeFromBoxedFunction<&wrapper_cpu_fallback>()); +} + +} // namespace at::openreg diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/native/Common.h b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/native/Common.h new file mode 100644 index 0000000000000..a706137fe852d --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/native/Common.h @@ -0,0 +1,106 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +#include + +#include + +namespace at::native { + +class MemoryGuard { + public: + explicit MemoryGuard(const torch::jit::Stack& stack) { + for (const c10::IValue& ivalue : stack) { + find_and_unprotect_tensors(ivalue); + } + } + + template + explicit MemoryGuard(const Args&... args) { + (handler(args), ...); + } + + ~MemoryGuard() { + for (void* ptr : unprotected_pointers_) { + orMemoryProtect(ptr); + } + } + + MemoryGuard(const MemoryGuard&) = delete; + MemoryGuard& operator=(const MemoryGuard&) = delete; + MemoryGuard(MemoryGuard&&) = delete; + MemoryGuard& operator=(MemoryGuard&&) = delete; + + private: + void find_and_unprotect_tensors(const c10::IValue& ivalue) { + if (ivalue.isTensor()) { + unprotect_if_needed(ivalue.toTensor()); + } else if (ivalue.isTensorList()) { + for (const at::Tensor& tensor : ivalue.toTensorList()) { + unprotect_if_needed(tensor); + } + } else if (ivalue.isList()) { + for (const c10::IValue& element : ivalue.toListRef()) { + find_and_unprotect_tensors(element); + } + } else if (ivalue.isGenericDict()) { + for (const auto& pair : ivalue.toGenericDict()) { + find_and_unprotect_tensors(pair.key()); + find_and_unprotect_tensors(pair.value()); + } + } + } + + void unprotect_if_needed(const at::Tensor& tensor) { + if (!tensor.defined() || !tensor.has_storage()) { + return; + } + + void* ptr = tensor.data_ptr(); + orPointerAttributes attr; + + if (orPointerGetAttributes(&attr, ptr) == orSuccess) { + if (attr.type == orMemoryTypeDevice) { + if (unprotected_pointers_.find(attr.pointer) == + unprotected_pointers_.end()) { + orMemoryUnprotect(attr.pointer); + unprotected_pointers_.insert(attr.pointer); + } + } + } + } + + template + void handler(const T& x) { + if constexpr (std::is_same_v, at::Tensor>) { + unprotect_if_needed(x); + } + } + + std::set unprotected_pointers_; +}; + +} // namespace at::native diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/native/Extra.cpp b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/native/Extra.cpp new file mode 100644 index 0000000000000..741d148035393 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/native/Extra.cpp @@ -0,0 +1,238 @@ +#include "Extra.h" + +namespace at::native { + +at::Tensor quantize_per_tensor_openreg( + const at::Tensor& self, + double scale, + int64_t zero_point, + at::ScalarType dtype) { + return at::native::quantize_per_tensor(self, scale, zero_point, dtype); +} + +int64_t _fused_sdp_choice_openreg( + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const std::optional& attn_mask, + double dropout_p, + bool is_causal, + std::optional scale, + bool enable_gqa) { + auto backend = sdp::SDPBackend::overrideable; + return static_cast(backend); +} + +std::tuple< + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + c10::SymInt, + c10::SymInt, + at::Tensor, + at::Tensor, + at::Tensor> +_scaled_dot_product_fused_attention_overrideable_openreg( + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const std::optional& attn_bias, + double dropout_p, + bool is_causal, + bool return_debug_mask, + std::optional scale) { + const int64_t batch_size = query.size(0); + const int64_t num_heads = query.size(1); + const int64_t head_dim_v = value.size(3); + const int64_t max_seqlen_q = query.size(2); + const int64_t max_seqlen_kv = key.size(2); + + auto opts = query.options(); + auto output = + at::empty({batch_size, num_heads, max_seqlen_q, head_dim_v}, opts); + auto logsumexp = + at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); + auto debug_attn_mask = at::empty( + {batch_size, num_heads, max_seqlen_q, max_seqlen_kv}, + opts.dtype(at::kFloat)); + auto philox_seed = at::empty({}, at::dtype(at::kLong)); + auto philox_offset = at::empty({}, at::dtype(at::kLong)); + + return std::make_tuple( + output, + logsumexp, + at::Tensor(), + at::Tensor(), + max_seqlen_q, + max_seqlen_kv, + philox_seed, + philox_offset, + debug_attn_mask); +} + +std::tuple +_scaled_dot_product_fused_attention_overrideable_backward_openreg( + const at::Tensor& grad_out, + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const at::Tensor& attn_bias, + std::array grad_input_mask, + const at::Tensor& out, + const at::Tensor& logsumexp, + const at::Tensor& cum_seq_q, + const at::Tensor& cum_seq_k, + int64_t max_q, + int64_t max_k, + double dropout_p, + bool is_causal, + const at::Tensor& philox_seed, + const at::Tensor& philox_offset, + std::optional scale) { + return std::tuple( + at::empty_like(query), + at::empty_like(key), + at::empty_like(value), + at::empty_like(attn_bias)); +} + +} // namespace at::native + +namespace at::native { + +void abs_kernel_openreg(at::TensorIteratorBase& iter) { + // Abs only have a input tensor and a output tensor. + auto& output_operand = iter.operand(0); + auto& input_operand = iter.operand(1); + auto& output_tensor_base = output_operand.tensor_base(); + auto& input_tensor_base = input_operand.tensor_base(); + TORCH_CHECK( + !input_operand.original_tensor_base().defined(), + "input original tensor is defined."); + TORCH_CHECK( + !output_operand.original_tensor_base().defined(), + "output original tensor is defined."); + // For easy test, only accept contiguous input tensor for calculate. + auto memory_format = input_tensor_base.suggest_memory_format(); + TORCH_CHECK( + input_tensor_base.is_contiguous(memory_format), + "Input tensor need be contiguous."); + // Add necessary restrictions to ensure the security of the demo. + TORCH_CHECK( + input_tensor_base.sizes() == output_tensor_base.sizes(), + "Intput and output tensor size are not equal."); + // Common dtype is calculate in TensorIteratorBase. + TORCH_CHECK( + iter.common_dtype() == at::ScalarType::Float, "Only support float type.") + // Using for loop for abs calculate. + auto abs_function = + [](float* output_ptr, const float* input_ptr, const int64_t NUM) { + for (int64_t i = 0; i < NUM; ++i) { + *(output_ptr + i) = std::abs(*(input_ptr + i)); + } + }; + // To simplify the logic of the test demo code, + // we only use contiguous tensor to calculate on device side. + // And using input tensor memory format. + if (iter.is_contiguous()) { + // Add for will_resize flag check. You can convert to differernt + // tensor memory format when will_resize is True. + // If TensorIteratorConfig resize_outputs_ flag is true, and there are two + // situations: + // 1) Out tensor is undefined, and TensorIterator set will_resize to true; + // 2) Out tensor is defined and tensor size is not equal to input tensor + // size; + // TensorIterator set will_resize to true, and call + // set_output_raw_strided to resize output tensor. + // When output operand will_resize flag is ture, dummy + // device can convert tensor to dummy device preferred memory format. + // Here we don't convert tensor memory format, because it will become + // complex when dummy device want keep same memory format for training + // network. + TORCH_CHECK( + output_operand.will_resize, + "output operand will_resize flag need be True."); + abs_function( + (float*)iter.data_ptr(0), (float*)iter.data_ptr(1), iter.numel()); + } else { + // Stride copy is not support for foo device, using cpu device instead. + // For abs op, the last situation is: output tensor is not contiguous with + // operand will_resize is False. + TORCH_CHECK( + !output_operand.will_resize, "output operand will_resize is True."); + // Get a contiguous tensor with input memory format. + at::Tensor output = at::empty( + output_tensor_base.sizes(), + input_tensor_base.options().memory_format(memory_format)); + // For structured op which inheried from TensorIteratorBase, maybe you need + // to call set_output_raw_strided function to update output stored in op + // sturctured. abs op is no need to do this. + output_operand.exchange_tensor( + c10::MaybeOwned::owned(std::in_place, output)); + abs_function( + (float*)output_operand.tensor_base().mutable_data_ptr(), + (float*)iter.data_ptr(1), + iter.numel()); + // Copy tensor base to original tensor base, and keep same scalar type and + // stride with cpu and gpu. + if (output_operand.original_tensor_base().defined() && + !output_operand.original_tensor_base().is_same( + output_operand.tensor_base())) { + output_operand.original_tensor().copy_(output_operand.tensor()); + output_operand.restore_original_tensor(); + } + } +} + +void quantize_tensor_per_tensor_affine_stub_openreg( + const at::Tensor& rtensor, + at::Tensor& qtensor, + double scale, + int64_t zero_point) {} + +} // namespace at::native + +namespace at::native { + +namespace { +struct CustomAutogradFnReturnsSelf + : public torch::autograd::Function { + static at::Tensor forward( + torch::autograd::AutogradContext* ctx, + at::Tensor self) { + return self; + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + torch::autograd::variable_list grad_output) { + return {grad_output[0] * 0.5}; + } +}; + +struct CustomAutogradFnAliasing + : public torch::autograd::Function { + static at::Tensor forward( + torch::autograd::AutogradContext* ctx, + at::Tensor self) { + return self.view_symint(self.sym_sizes()); + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + torch::autograd::variable_list grad_output) { + return {grad_output[0] * 0.5}; + } +}; +} // namespace + +at::Tensor custom_autograd_fn_returns_self(at::Tensor x) { + return CustomAutogradFnReturnsSelf::apply(x); +} + +at::Tensor custom_autograd_fn_aliasing(at::Tensor x) { + return CustomAutogradFnAliasing::apply(x); +} + +} // namespace at::native diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/native/Extra.h b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/native/Extra.h new file mode 100644 index 0000000000000..95109cd3fa331 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/native/Extra.h @@ -0,0 +1,70 @@ +#include "Common.h" + +namespace at::native { +at::Tensor quantize_per_tensor_openreg( + const at::Tensor& self, + double scale, + int64_t zero_point, + at::ScalarType dtype); +int64_t _fused_sdp_choice_openreg( + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const std::optional& attn_mask, + double dropout_p, + bool is_causal, + std::optional scale, + bool enable_gqa); +std::tuple< + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + c10::SymInt, + c10::SymInt, + at::Tensor, + at::Tensor, + at::Tensor> +_scaled_dot_product_fused_attention_overrideable_openreg( + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const std::optional& attn_bias, + double dropout_p, + bool is_causal, + bool return_debug_mask, + std::optional scale); +std::tuple +_scaled_dot_product_fused_attention_overrideable_backward_openreg( + const at::Tensor& grad_out, + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const at::Tensor& attn_bias, + std::array grad_input_mask, + const at::Tensor& out, + const at::Tensor& logsumexp, + const at::Tensor& cum_seq_q, + const at::Tensor& cum_seq_k, + int64_t max_q, + int64_t max_k, + double dropout_p, + bool is_causal, + const at::Tensor& philox_seed, + const at::Tensor& philox_offset, + std::optional scale); +} // namespace at::native + +namespace at::native { +void abs_kernel_openreg(at::TensorIteratorBase& iter); +void quantize_tensor_per_tensor_affine_stub_openreg( + const at::Tensor& rtensor, + at::Tensor& qtensor, + double scale, + int64_t zero_point); +} // namespace at::native + +namespace at::native { +at::Tensor custom_autograd_fn_returns_self(at::Tensor x); +at::Tensor custom_autograd_fn_aliasing(at::Tensor x); +} // namespace at::native diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/native/Minimal.cpp b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/native/Minimal.cpp new file mode 100644 index 0000000000000..973869087a2e2 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/native/Minimal.cpp @@ -0,0 +1,173 @@ +#include "Minimal.h" + +namespace at::native { + +at::Tensor empty_memory_format_openreg( + c10::IntArrayRef size, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt, + std::optional pin_memory_opt, + std::optional memory_format_opt) { + const auto device = c10::device_or_default(device_opt); + const auto dtype = c10::dtype_or_default(dtype_opt); + TORCH_CHECK(device.is_privateuseone()); + TORCH_CHECK( + c10::layout_or_default(layout_opt) == c10::Layout::Strided, + "Non strided layout not supported"); + TORCH_CHECK( + !c10::pinned_memory_or_default(pin_memory_opt), + "Pin memory can only be on CPU"); + const c10::DeviceGuard device_guard(device); + constexpr c10::DispatchKeySet pu1_dks(c10::DispatchKey::PrivateUse1); + auto allocator = at::GetAllocator(at::kPrivateUse1); + return at::detail::empty_generic( + size, allocator, pu1_dks, dtype, memory_format_opt); +} + +at::Tensor empty_strided_openreg( + c10::IntArrayRef size, + c10::IntArrayRef stride, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt, + std::optional pin_memory_opt) { + const auto device = c10::device_or_default(device_opt); + const auto dtype = c10::dtype_or_default(dtype_opt); + TORCH_CHECK(device.is_privateuseone()); + TORCH_CHECK( + c10::layout_or_default(layout_opt) == c10::Layout::Strided, + "Non strided layout not supported"); + TORCH_CHECK( + !c10::pinned_memory_or_default(pin_memory_opt), + "Pin memory can only be on CPU"); + const c10::DeviceGuard device_guard(device); + constexpr c10::DispatchKeySet pu1_dks(c10::DispatchKey::PrivateUse1); + auto allocator = at::GetAllocator(at::kPrivateUse1); + return at::detail::empty_strided_generic( + size, stride, allocator, pu1_dks, dtype); +} + +at::Tensor as_strided_openreg( + const at::Tensor& self, + c10::SymIntArrayRef size, + c10::SymIntArrayRef stride, + std::optional storage_offset) { + MemoryGuard guard(self); + + return at::cpu::as_strided_symint(self, size, stride, storage_offset); +} + +const at::Tensor& resize_openreg_( + const at::Tensor& self, + c10::SymIntArrayRef size, + ::std::optional memory_format) { + return at::native::resize_( + self, C10_AS_INTARRAYREF_SLOW(size), memory_format); +} + +at::Tensor _reshape_alias_openreg( + const at::Tensor& self, + c10::SymIntArrayRef size, + c10::SymIntArrayRef stride) { + return at::native::_reshape_alias( + self, C10_AS_INTARRAYREF_SLOW(size), C10_AS_INTARRAYREF_SLOW(stride)); +} + +at::Tensor _copy_from_openreg( + const at::Tensor& self, + const at::Tensor& dst, + bool non_blocking) { + TORCH_CHECK(self.defined(), "Source tensor (self) is not defined."); + TORCH_CHECK(dst.defined(), "Destination tensor (dst) is not defined."); + + MemoryGuard guard(self, dst); + + if (self.device() == dst.device()) { + at::Tensor dst_as_cpu = at::from_blob( + dst.data_ptr(), + dst.sizes(), + dst.strides(), + dst.options().device(at::kCPU)); + const at::Tensor self_as_cpu = at::from_blob( + self.data_ptr(), + self.sizes(), + self.strides(), + self.options().device(at::kCPU)); + + at::native::copy_( + const_cast(dst_as_cpu), self_as_cpu, non_blocking); + + } else { + if (self.is_cpu()) { + at::Tensor dst_as_cpu = at::from_blob( + dst.data_ptr(), + dst.sizes(), + dst.strides(), + dst.options().device(at::kCPU)); + + at::native::copy_( + const_cast(dst_as_cpu), self, non_blocking); + + } else { + at::Tensor self_as_cpu = at::from_blob( + self.data_ptr(), + self.sizes(), + self.strides(), + self.options().device(at::kCPU)); + + at::native::copy_( + const_cast(dst), self_as_cpu, non_blocking); + } + } + + return dst; +} + +at::Tensor _copy_from_and_resize_openreg( + const at::Tensor& self, + const at::Tensor& dst) { + at::native::resize_(dst, self.sizes(), std::nullopt); + + MemoryGuard guard(self, dst); + + return at::native::copy_(const_cast(dst), self, false); +} + +at::Scalar _local_scalar_dense_openreg(const at::Tensor& self) { + MemoryGuard guard(self); + return at::native::_local_scalar_dense_cpu(self); +} + +at::Tensor& set_source_Tensor_openreg_( + at::Tensor& self, + const at::Tensor& source) { + return at::native::set_tensor_(self, source); +} + +at::Tensor& set_source_Storage_openreg_(at::Tensor& self, at::Storage source) { + return at::native::set_(self, source); +} + +at::Tensor& set_source_Storage_storage_offset_openreg_( + at::Tensor& result, + at::Storage storage, + int64_t storage_offset, + c10::IntArrayRef size, + c10::IntArrayRef stride) { + // call native:: + return at::cpu::set_(result, storage, storage_offset, size, stride); +} + +at::Tensor view_openreg(const at::Tensor& self, c10::SymIntArrayRef size) { + MemoryGuard guard(self); + return at::native::view(self, C10_AS_INTARRAYREF_SLOW(size)); +} + +void cpu_fallback_openreg( + const c10::OperatorHandle& op, + torch::jit::Stack* stack) { + at::native::cpu_fallback(op, stack); +} + +} // namespace at::native diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/native/Minimal.h b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/native/Minimal.h new file mode 100644 index 0000000000000..3d144f2debea5 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/native/Minimal.h @@ -0,0 +1,67 @@ +#include "Common.h" + +namespace at::native { + +at::Tensor empty_memory_format_openreg( + c10::IntArrayRef size, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt, + std::optional pin_memory_opt, + std::optional memory_format_opt); + +at::Tensor empty_strided_openreg( + c10::IntArrayRef size, + c10::IntArrayRef stride, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt, + std::optional pin_memory_opt); + +at::Tensor as_strided_openreg( + const at::Tensor& self, + c10::SymIntArrayRef size, + c10::SymIntArrayRef stride, + std::optional storage_offset); + +const at::Tensor& resize_openreg_( + const at::Tensor& self, + c10::SymIntArrayRef size, + ::std::optional memory_format); + +at::Tensor _reshape_alias_openreg( + const at::Tensor& self, + c10::SymIntArrayRef size, + c10::SymIntArrayRef stride); + +at::Tensor _copy_from_openreg( + const at::Tensor& self, + const at::Tensor& dst, + bool non_blocking); + +at::Tensor _copy_from_and_resize_openreg( + const at::Tensor& self, + const at::Tensor& dst); + +at::Scalar _local_scalar_dense_openreg(const at::Tensor& self); + +at::Tensor& set_source_Tensor_openreg_( + at::Tensor& self, + const at::Tensor& source); + +at::Tensor& set_source_Storage_openreg_(at::Tensor& self, at::Storage source); + +at::Tensor& set_source_Storage_storage_offset_openreg_( + at::Tensor& result, + at::Storage storage, + int64_t storage_offset, + c10::IntArrayRef size, + c10::IntArrayRef stride); + +at::Tensor view_openreg(const at::Tensor& self, c10::SymIntArrayRef size); + +void cpu_fallback_openreg( + const c10::OperatorHandle& op, + torch::jit::Stack* stack); + +} // namespace at::native diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegDeviceAllocator.cpp b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegDeviceAllocator.cpp new file mode 100644 index 0000000000000..3d35b677cd208 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegDeviceAllocator.cpp @@ -0,0 +1,8 @@ +#include "OpenRegDeviceAllocator.h" + +namespace c10::openreg { + +static OpenRegDeviceAllocator global_openreg_alloc; +REGISTER_ALLOCATOR(c10::DeviceType::PrivateUse1, &global_openreg_alloc); + +} // namespace c10::openreg diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegDeviceAllocator.h b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegDeviceAllocator.h new file mode 100644 index 0000000000000..c9aea4a913427 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegDeviceAllocator.h @@ -0,0 +1,43 @@ +#include + +#include +#include + +#include + +namespace c10::openreg { +struct OpenRegDeviceAllocator final : at::Allocator { + OpenRegDeviceAllocator() = default; + + static void ReportAndDelete(void* ptr) { + if (!ptr) { + return; + } + orFreeHost(ptr); + } + + at::DataPtr allocate(size_t nbytes) override { + int current_device_index = -1; + orGetDevice(¤t_device_index); + + auto curr_device = + c10::Device(c10::DeviceType::PrivateUse1, current_device_index); + void* data = nullptr; + if (nbytes > 0) { + orMalloc(&data, nbytes); + TORCH_CHECK( + data, "Failed to allocator ", nbytes, " bytes on openreg device."); + } + return {data, data, &ReportAndDelete, curr_device}; + } + + at::DeleterFnPtr raw_deleter() const override { + return &ReportAndDelete; + } + + void copy_data(void* dest, const void* src, std::size_t count) const final { + orMemcpy(dest, src, count, orMemcpyDeviceToDevice); + } +}; + +} // namespace c10::openreg diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp new file mode 100644 index 0000000000000..240c2d8ce1aad --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp @@ -0,0 +1,73 @@ +#include + +#include "OpenRegFunctions.h" + +namespace c10::openreg { + +orError_t GetDeviceCount(int* dev_count) { + return orGetDeviceCount(dev_count); +} + +orError_t GetDevice(c10::DeviceIndex* device) { + int tmp_device = -1; + auto err = orGetDevice(&tmp_device); + *device = static_cast(tmp_device); + return err; +} + +orError_t SetDevice(c10::DeviceIndex device) { + int cur_device = -1; + orGetDevice(&cur_device); + if (device == cur_device) { + return orSuccess; + } + return orSetDevice(device); +} + +int device_count_impl() { + int count = 0; + GetDeviceCount(&count); + return count; +} + +c10::DeviceIndex device_count() noexcept { + // initialize number of devices only once + static int count = []() { + try { + auto result = device_count_impl(); + TORCH_INTERNAL_ASSERT( + result <= std::numeric_limits::max(), + "Too many devices, DeviceIndex overflowed"); + return result; + } catch (const c10::Error& ex) { + // We don't want to fail, but still log the warning + // msg() returns the message without the stack trace + TORCH_WARN("Device initialization: ", ex.msg()); + return 0; + } + }(); + return static_cast(count); +} + +c10::DeviceIndex current_device() { + c10::DeviceIndex cur_device = -1; + GetDevice(&cur_device); + return cur_device; +} + +void set_device(c10::DeviceIndex device) { + SetDevice(device); +} + +DeviceIndex ExchangeDevice(DeviceIndex device) { + int current_device = -1; + orGetDevice(¤t_device); + + if (device != current_device) { + orSetDevice(device); + } + + return current_device; +} + +} // namespace c10::openreg diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.h b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.h new file mode 100644 index 0000000000000..b6b991ff6d3a3 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.h @@ -0,0 +1,16 @@ +#pragma once + +#include +#include + +#include + +namespace c10::openreg { + +c10::DeviceIndex device_count() noexcept; +DeviceIndex current_device(); +void set_device(c10::DeviceIndex device); + +DeviceIndex ExchangeDevice(DeviceIndex device); + +} // namespace c10::openreg diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGenerator.cpp b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGenerator.cpp new file mode 100644 index 0000000000000..c2e03f66adc41 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGenerator.cpp @@ -0,0 +1,28 @@ +#include "OpenRegGenerator.h" + +// Default, global generators, one per device. +static std::vector default_generators; + +namespace c10::openreg { + +const at::Generator& getDefaultOpenRegGenerator(c10::DeviceIndex device_index) { + static bool flag [[maybe_unused]] = []() { + auto deivce_nums = device_count(); + default_generators.resize(deivce_nums); + for (auto i = 0; i < deivce_nums; i++) { + default_generators[i] = at::make_generator(i); + default_generators[i].seed(); + } + return true; + }(); + + c10::DeviceIndex idx = device_index; + if (idx == -1) { + idx = current_device(); + } else { + TORCH_CHECK(idx >= 0 && idx < device_count()); + } + return default_generators[idx]; +} + +} // namespace c10::openreg diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGenerator.h b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGenerator.h new file mode 100644 index 0000000000000..877a9707306fc --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGenerator.h @@ -0,0 +1,21 @@ +#include +#include + +#include + +#include "OpenRegFunctions.h" + +namespace c10::openreg { +class OpenRegGeneratorImpl : public at::CPUGeneratorImpl { + public: + OpenRegGeneratorImpl(c10::DeviceIndex device_index) { + device_ = c10::Device(c10::DeviceType::PrivateUse1, device_index); + key_set_ = c10::DispatchKeySet(c10::DispatchKey::PrivateUse1); + } + ~OpenRegGeneratorImpl() override = default; +}; + +const at::Generator& getDefaultOpenRegGenerator( + c10::DeviceIndex device_index = -1); + +} // namespace c10::openreg diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGuard.cpp b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGuard.cpp new file mode 100644 index 0000000000000..d50e56e40942d --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGuard.cpp @@ -0,0 +1,7 @@ +#include "OpenRegGuard.h" + +namespace c10::openreg { + +C10_REGISTER_GUARD_IMPL(PrivateUse1, OpenRegGuardImpl); + +} // namespace c10::openreg diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGuard.h b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGuard.h new file mode 100644 index 0000000000000..f0150fe680fb8 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGuard.h @@ -0,0 +1,197 @@ +#include +#include + +#include + +#include "OpenRegFunctions.h" + +namespace c10::openreg { + +// Device guard registration +struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface { + static constexpr c10::DeviceType static_type = c10::DeviceType::PrivateUse1; + + OpenRegGuardImpl() = default; + explicit OpenRegGuardImpl(c10::DeviceType t) { + TORCH_INTERNAL_ASSERT(t == static_type); + } + + /** + * Return the type of device managed by this guard implementation. + */ + c10::DeviceType type() const override { + return static_type; + } + + /** + * Set the current device to Device, and return the previous c10::Device. + */ + c10::Device exchangeDevice(c10::Device d) const override { + TORCH_CHECK(d.is_privateuseone()); + + auto old_device_index = ExchangeDevice(d.index()); + return c10::Device(static_type, old_device_index); + } + + /** + * Get the current device. + */ + c10::Device getDevice() const override { + int device_index = current_device(); + return c10::Device(static_type, device_index); + } + + /** + * Set the current device to c10::Device. + */ + void setDevice(c10::Device d) const override { + TORCH_CHECK(d.is_privateuseone()); + + set_device(d.index()); + } + + /** + * Set the current device to c10::Device, without checking for errors + * (so, e.g., this can be called from a destructor). + */ + void uncheckedSetDevice(c10::Device d) const noexcept override { + TORCH_CHECK(d.is_privateuseone()); + + set_device(d.index()); + } + + /** + * Get the current stream for a given device. + */ + c10::Stream getStream(c10::Device d) const noexcept override { + return c10::Stream(c10::Stream::DEFAULT, d); + } + + /** + * Get the default stream for a given device. + */ + c10::Stream getDefaultStream(c10::Device d) const override { + return c10::Stream(c10::Stream::DEFAULT, d); + } + + /** + * Get a stream from the global pool for a given device. + */ + c10::Stream getStreamFromGlobalPool( + c10::Device d, + bool isHighPriority = false) const override { + return c10::Stream(c10::Stream::DEFAULT, d); + } + + /** + * Return a new stream for a given device and priority. The stream will be + * copied and shared around, device backend should be able to correctly handle + * the lifetime of the stream. + */ + c10::Stream getNewStream(c10::Device d, int priority = 0) const override { + return c10::Stream(c10::Stream::DEFAULT, d); + } + + /** + * Set a stream to be the thread local current stream for its device. + * Return the previous stream for that device. You are NOT required + * to set the current device to match the device of this stream. + */ + c10::Stream exchangeStream(c10::Stream s) const noexcept override { + return s; + } + + /** + * Destroys the given event. + */ + void destroyEvent(void* event, const c10::DeviceIndex device_index) + const noexcept override {} + + /** + * Increments the event's version and enqueues a job with this version + * in the stream's work queue. When the stream process that job + * it notifies all streams waiting on / blocked by that version of the + * event to continue and marks that version as recorded. + * */ + void record( + void** event, + const c10::Stream& stream, + const c10::DeviceIndex device_index, + const c10::EventFlag flag) const override { + static int event_id = 1; + + if (!*event) + *event = reinterpret_cast(event_id++); + } + + /** + * Does nothing if the event has not been scheduled to be recorded. + * If the event was previously enqueued to be recorded, a command + * to wait for the version of the event that exists at the time of this call + * is inserted in the stream's work queue. + * When the stream reaches this command it will stop processing + * additional commands until that version of the event is marked as recorded. + */ + void block(void* event, const c10::Stream& stream) const override {} + + /** + * Returns true if (and only if) + * (1) the event has never been scheduled to be recorded + * (2) the current version is marked as recorded. + * Returns false otherwise. + */ + bool queryEvent(void* event) const override { + return true; + } + + /** + * Get the number of devices. WARNING: This is REQUIRED to not raise + * an exception. If there is some sort of problem, e.g., driver error, + * you should report that there are zero available devices. + */ + c10::DeviceIndex deviceCount() const noexcept override { + int device_index = -1; + orGetDeviceCount(&device_index); + return device_index; + } + /** + * Return true if all the work previously enqueued on the stream for + * asynchronous execution has completed running on the device. + */ + bool queryStream(const c10::Stream& stream) const override { + return true; + } + + /** + * Wait (by blocking the calling thread) until all the work previously + * enqueued on the stream has completed running on the device. + */ + void synchronizeStream(const c10::Stream& stream) const override {} + + /** + * Wait (by blocking the calling thread) until all the work previously + * recorded on the event has completed running on the device. + */ + void synchronizeEvent(void* event) const override {} + + /** + * Ensure the caching allocator (if any) is aware that the given DataPtr is + * being used on the given stream, and that it should thus avoid recycling the + * DataPtr until all work on that stream is done. + */ + void recordDataPtrOnStream( + const c10::DataPtr& data_ptr, + const c10::Stream& stream) const override {} + + /** + * Fetch the elapsed time between two recorded events. + */ + double elapsedTime( + void* event1, + void* event2, + const c10::DeviceIndex device_index) const override { + return 1; + } +}; + +} // namespace c10::openreg diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegHooks.cpp b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegHooks.cpp new file mode 100644 index 0000000000000..57bc2d9f0d1bc --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegHooks.cpp @@ -0,0 +1,11 @@ +#include "OpenRegHooks.h" + +namespace c10::openreg { + +static bool register_hook_flag [[maybe_unused]] = []() { + at::RegisterPrivateUse1HooksInterface(new OpenRegHooksInterface()); + + return true; +}(); + +} // namespace c10::openreg diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegHooks.h b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegHooks.h new file mode 100644 index 0000000000000..656fba8eae484 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegHooks.h @@ -0,0 +1,41 @@ +#include +#include + +#include +#include + +#include + +#include "OpenRegGenerator.h" + +namespace c10::openreg { +struct OpenRegHooksInterface : public at::PrivateUse1HooksInterface { + OpenRegHooksInterface() {}; + ~OpenRegHooksInterface() override = default; + + bool hasPrimaryContext(c10::DeviceIndex device_index) const override { + return true; + } + + at::Allocator* getPinnedMemoryAllocator() const override { + return at::getHostAllocator(at::kPrivateUse1); + } + + bool isPinnedPtr(const void* data) const override { + orPointerAttributes attr{}; + orPointerGetAttributes(&attr, data); + + return attr.type == orMemoryTypeHost; + } + + const at::Generator& getDefaultGenerator( + c10::DeviceIndex device_index) const override { + return getDefaultOpenRegGenerator(device_index); + } + + at::Generator getNewGenerator(c10::DeviceIndex device_index) const override { + return at::make_generator(device_index); + } +}; + +} // namespace c10::openreg diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegHostAllocator.cpp b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegHostAllocator.cpp new file mode 100644 index 0000000000000..552638035c386 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegHostAllocator.cpp @@ -0,0 +1,8 @@ +#include "OpenRegHostAllocator.h" + +namespace c10::openreg { + +OpenRegHostAllocator caching_host_allocator; +REGISTER_HOST_ALLOCATOR(at::kPrivateUse1, &caching_host_allocator); + +} // namespace c10::openreg diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegHostAllocator.h b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegHostAllocator.h new file mode 100644 index 0000000000000..edef545a27835 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegHostAllocator.h @@ -0,0 +1,48 @@ +#include + +#include +#include + +#include + +namespace c10::openreg { +struct OpenRegHostAllocator final : at::HostAllocator { + OpenRegHostAllocator() = default; + + static void ReportAndDelete(void* ptr) { + if (!ptr) { + return; + } + orFreeHost(ptr); + } + + at::DataPtr allocate(size_t nbytes) override { + void* data = nullptr; + if (nbytes > 0) { + orMallocHost(&data, nbytes); + TORCH_CHECK(data, "Failed to allocator ", nbytes, " bytes on host."); + } + return {data, data, &ReportAndDelete, at::Device(at::kCPU)}; + } + + at::DeleterFnPtr raw_deleter() const override { + return &ReportAndDelete; + } + + void copy_data(void* dest, const void* src, std::size_t count) const final { + orMemcpy(dest, src, count, orMemcpyHostToHost); + } + + // ignore + bool record_event(void* ptr, void* ctx, c10::Stream stream) override { + return true; + } + void empty_cache() override {} + at::HostStats get_stats() override { + return at::HostStats(); + } + void reset_accumulated_stats() override {} + void reset_peak_stats() override {} +}; + +} // namespace c10::openreg diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegSerialization.cpp b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegSerialization.cpp new file mode 100644 index 0000000000000..43809d60604f8 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegSerialization.cpp @@ -0,0 +1,48 @@ +#include "OpenRegSerialization.h" + +namespace c10::openreg { +struct OpenRegBackendMeta : public c10::BackendMeta { + OpenRegBackendMeta(int version_number, int format_number) + : version_number_(version_number), format_number_(format_number) {} + + int version_number_{-1}; + int format_number_{-1}; +}; + +void for_serialization( + const at::Tensor& t, + std::unordered_map& m) { + auto meta_ptr = t.unsafeGetTensorImpl()->get_backend_meta(); + + if (meta_ptr != nullptr) { + auto o_meta_ptr = dynamic_cast(meta_ptr); + if (o_meta_ptr->version_number_ == 1) { + m["version_number"] = true; + } + if (o_meta_ptr->format_number_ == 29) { + m["format_number"] = true; + } + } +} + +void for_deserialization( + const at::Tensor& t, + std::unordered_map& m) { + int version_number{-1}; + int format_number{-1}; + + if (m.find("version_number") != m.end()) { + version_number = 1; + } + if (m.find("format_number") != m.end()) { + format_number = 29; + } + + c10::intrusive_ptr meta{std::unique_ptr( + new OpenRegBackendMeta(version_number, format_number))}; + t.unsafeGetTensorImpl()->set_backend_meta(meta); +} + +REGISTER_PRIVATEUSE1_SERIALIZATION(&for_serialization, &for_deserialization) + +} // namespace c10::openreg diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegSerialization.h b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegSerialization.h new file mode 100644 index 0000000000000..559e92ea82f7b --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegSerialization.h @@ -0,0 +1,10 @@ +#include + +#define REGISTER_PRIVATEUSE1_SERIALIZATION( \ + FOR_SERIALIZATION, FOR_DESERIALIZATION) \ + static int register_serialization() { \ + torch::jit::TensorBackendMetaRegistry( \ + c10::DeviceType::PrivateUse1, FOR_SERIALIZATION, FOR_DESERIALIZATION); \ + return 0; \ + } \ + static const int _temp = register_serialization(); diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/requirements.txt b/test/cpp_extensions/open_registration_extension/torch_openreg/requirements.txt new file mode 100644 index 0000000000000..42d5e8d799f4e --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/requirements.txt @@ -0,0 +1,2 @@ +torch +pybind11 diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/setup.py b/test/cpp_extensions/open_registration_extension/torch_openreg/setup.py new file mode 100644 index 0000000000000..38a866e4ce219 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/setup.py @@ -0,0 +1,102 @@ +import multiprocessing +import os +import shutil +import subprocess +import sys +import sysconfig +from distutils.command.clean import clean + +from setuptools import Extension, find_packages, setup + + +PACKAGE_NAME = "torch_openreg" +BASE_DIR = os.path.dirname(os.path.realpath(__file__)) + + +def get_pytorch_dir(): + import torch + + return os.path.dirname(os.path.realpath(torch.__file__)) + + +def build_deps(): + build_dir = os.path.join(BASE_DIR, "build") + os.makedirs(build_dir, exist_ok=True) + + cmake_args = [ + "-DCMAKE_INSTALL_PREFIX=" + + os.path.realpath(os.path.join(BASE_DIR, "torch_openreg")), + "-DPYTHON_INCLUDE_DIR=" + sysconfig.get_paths().get("include"), + "-DPYTORCH_INSTALL_DIR=" + get_pytorch_dir(), + ] + + subprocess.check_call( + ["cmake", BASE_DIR] + cmake_args, cwd=build_dir, env=os.environ + ) + + build_args = [ + "--build", + ".", + "--target", + "install", + "--", + ] + build_args += ["-j", str(multiprocessing.cpu_count())] + + command = ["cmake"] + build_args + subprocess.check_call(command, cwd=build_dir, env=os.environ) + + +class BuildClean(clean): + def run(self): + for i in ["build", "install", "torch_openreg.egg-info", "torch_openreg/lib"]: + dirs = os.path.join(BASE_DIR, i) + if os.path.exists(dirs) and os.path.isdir(dirs): + shutil.rmtree(dirs) + + for dirpath, _, filenames in os.walk(os.path.join(BASE_DIR, "torch_openreg")): + for filename in filenames: + if filename.endswith(".so"): + os.remove(os.path.join(dirpath, filename)) + + +RUN_BUILD_DEPS = any(arg == "clean" for arg in sys.argv) + + +def main(): + if not RUN_BUILD_DEPS: + build_deps() + + ext_modules = [ + Extension( + name="torch_openreg._C", + sources=["torch_openreg/csrc/stub.c"], + extra_compile_args=["-g", "-Wall", "-Werror"], + libraries=["torch_bindings"], + library_dirs=[os.path.join(BASE_DIR, "torch_openreg/lib")], + extra_link_args=["-Wl,-rpath,$ORIGIN/lib"], + ) + ] + + package_data = {PACKAGE_NAME: ["lib/*.so*"]} + + setup( + name=PACKAGE_NAME, + version="0.0.1", + author="PyTorch Core Team", + description="Example for PyTorch out of tree registration", + packages=find_packages(exclude=("test",)), + package_data=package_data, + install_requires=[ + "torch", + ], + ext_modules=ext_modules, + python_requires=">=3.8", + cmdclass={ + "clean": BuildClean, # type: ignore[misc] + }, + ) + + +if __name__ == "__main__": + main() diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/CMakeLists.txt b/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/CMakeLists.txt new file mode 100644 index 0000000000000..7fec109eeb1cd --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/CMakeLists.txt @@ -0,0 +1,11 @@ +set(LIBRARY_NAME openreg) + +file(GLOB_RECURSE SOURCE_FILES + "${CMAKE_CURRENT_SOURCE_DIR}/*.cpp" +) + +add_library(${LIBRARY_NAME} SHARED ${SOURCE_FILES}) + +target_include_directories(${LIBRARY_NAME} PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) + +install(TARGETS ${LIBRARY_NAME} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}) diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/README.md b/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/README.md new file mode 100644 index 0000000000000..af17ef3abdb1a --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/README.md @@ -0,0 +1,137 @@ +# OpenReg: An Accelerator Backend that Simulates CUDA Behavior on a CPU + +## Introduction + +OpenReg is a C++ backend library that simulates the behavior of a CUDA-like device on a CPU. Its core objective is **not to accelerate computation or improve performance**, but rather to **simulate modern CUDA programming, enabling developers to prototype and test in an environment without actual GPU hardware**. The current design principles are as follows: + +* **API Consistency**: Provide an interface consistent with the CUDA Runtime API, allowing upper-level applications (like PyTorch's PrivateUse1 backend) to switch and test seamlessly. +* **Functional Consistency**: Provide behavior consistent with the CUDA Runtime, such as memory isolation, device context management, etc. +* **Completeness**: Aim to support PrivateUse1 device integration and safeguard the third-party device integration mechanism, without striving to cover all capabilities of the CUDA Runtime. + +## Directory Structure + +The project's code is organized with a clear structure and separation of responsibilities: + +```text +openreg/ +├── CMakeLists.txt # Top-level CMake build script, used to compile and generate libopenreg.so +├── include/ +│ └── openreg.h # Public API header file, external users only need to include this file +└── csrc/ + ├── device.cpp # Implementation of device management-related APIs + └── memory.cpp # Implementation of APIs for memory management, copying, and protection +``` + +* `include/openreg.h`: Defines all externally exposed C-style APIs, data structures, and enums. It is the "public face" of this library. +* `csrc/`: Contains the C++ implementation source code for all core functionalities. + * `device.cpp`: Implements device discovery (`orGetDeviceCount`) and thread context management (`orSetDevice`/`orGetDevice`). + * `memory.cpp`: Implements the core functions of memory allocation (`orMalloc`/`orMallocHost`), deallocation, copying, and memory protection (`orMemoryProtect`, `orMemoryUnprotect`). +* `CMakeLists.txt`: Responsible for compiling and linking all source files under the `csrc/` directory to generate the final `libopenreg.so` shared library. + +## Implemented APIs + +OpenReg currently provides a set of APIs covering basic memory and device management. + +### Device Management APIs + +| OpenReg | CUDA | Feature Description | +| :------------------- | :------------------- | :------------------------------------------------ | +| `orGetDeviceCount` | `cudaGetDeviceCount` | Get the number of devices | +| `orSetDevice` | `cudaSetDevice` | Set the current device for the current thread | +| `orGetDevice` | `cudaGetDevice` | Get the current device for the current thread | + +### Memory Management APIs + +| OpenReg | CUDA | Feature Description | +| :----------------------- | :--------------------------- | :----------------------------------------- | +| `orMalloc` | `cudaMalloc` | Allocate device memory | +| `orFree` | `cudaFree` | Free device memory | +| `orMallocHost` | `cudaMallocHost` | Allocate page-locked (Pinned) host memory | +| `orFreeHost` | `cudaFreeHost` | Free page-locked host memory | +| `orMemcpy` | `cudaMemcpy` | Synchronous memory copy | +| `orMemcpyAsync` | `cudaMemcpyAsync` | Asynchronous memory copy | +| `orPointerGetAttributes` | `cudaPointerGetAttributes` | Get pointer attributes | +| `orMemoryUnprotect` | - | (Internal use) Unprotect memory | +| `orMemoryProtect` | - | (Internal use) Restore memory protection | + +## Implementation Principles + +### Device Management Principles + +Simulating multiple devices and thread-safe device context switching: + +1. **Device Count**: The total number of simulated devices is defined by the compile-time constant `constexpr int kDeviceCount`. +2. **Device Switching**: Device switching in multi-threaded scenarios is simulated using a **TLS (Thread-Local Storage) global variable**. + +### Memory Management Principles + +Simulating device memory, host memory, and memory copies: + +1. **Allocation**: A page-aligned memory block is allocated using `mmap` + `mprotect` with the permission flag `PROT_NONE`. Read, write, and execute operations on this memory region are all prohibited. +2. **Deallocation**: Memory is freed using `munmap`. +3. **Authorization**: When a legitimate memory access is required, an RAII guard restores the memory permissions to `PROT_READ | PROT_WRITE`. The permissions are automatically reverted to `PROT_NONE` when the scope is exited. + +## Usage Example + +The following is a simple code snippet demonstrating how to use the core features of the OpenReg library. + +```cpp +#include "openreg.h" +#include +#include +#include + +#define OR_CHECK(call) do { \ + orError_t err = call; \ + if (err != orSuccess) { \ + fprintf(stderr, "OR Error code %d in %s at line %d\n", err, __FILE__, __LINE__); \ + exit(EXIT_FAILURE); \ + } \ +} while (0) + +int main() { + int device_count = 0; + OR_CHECK(orGetDeviceCount(&device_count)); + std::cout << "Found " << device_count << " simulated devices." << std::endl; + + int current_device = -1; + OR_CHECK(orSetDevice(1)); + OR_CHECK(orGetDevice(¤t_device)); + std::cout << "Set current device to " << current_device << "." << std::endl; + + const int n = 1024; + const size_t size = n * sizeof(int); + int *h_a, *d_a; + OR_CHECK(orMallocHost((void**)&h_a, size)); + OR_CHECK(orMalloc((void**)&d_a, size)); + + orPointerAttributes attr; + OR_CHECK(orPointerGetAttributes(&attr, d_a)); + std::cout << "Pointer " << (void*)d_a << " is of type " << attr.type + << " on device " << attr.device << std::endl; + + for (int i = 0; i < n; ++i) { + h_a[i] = i; + } + OR_CHECK(orMemcpy(d_a, h_a, size, orMemcpyHostToDevice)); + std::cout << "Data copied from Host to Device." << std::endl; + + // std::cout << "Trying to access device memory directly from CPU..." << std::endl; + // int val = d_a[0]; // CRASH! + + // Clean up resources + OR_CHECK(orFree(d_a)); + OR_CHECK(orFreeHost(h_a)); + std::cout << "Resources freed." << std::endl; + + return 0; +} +``` + +## Next Steps + +To better support PrivateUse1 device integration, the following capabilities are planned for the future: + +* **Stream Support**: Provide the ability to simulate CUDA Streams. +* **Event Support**: Provide the ability to simulate CUDA Events. +* **Cross-Platform Support**: Add support for Windows and macOS (low priority). diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/csrc/device.cpp b/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/csrc/device.cpp new file mode 100644 index 0000000000000..3f1d43ea0b554 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/csrc/device.cpp @@ -0,0 +1,35 @@ +#include + +namespace { +// Total device numbers +constexpr int DEVICE_COUNT = 2; +// Current device index +thread_local int gCurrentDevice = 0; +} // namespace + +orError_t orGetDeviceCount(int* count) { + if (!count) { + return orErrorUnknown; + } + + *count = DEVICE_COUNT; + return orSuccess; +} + +orError_t orGetDevice(int* device) { + if (!device) { + return orErrorUnknown; + } + + *device = gCurrentDevice; + return orSuccess; +} + +orError_t orSetDevice(int device) { + if (device < 0 || device >= DEVICE_COUNT) { + return orErrorUnknown; + } + + gCurrentDevice = device; + return orSuccess; +} diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/csrc/memory.cpp b/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/csrc/memory.cpp new file mode 100644 index 0000000000000..762cd96d23bb8 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/csrc/memory.cpp @@ -0,0 +1,249 @@ +#include + +#include +#include +#include +#include +#include +#include + +namespace openreg { +namespace internal { + +class ScopedMemoryProtector { + public: + ScopedMemoryProtector(const orPointerAttributes& info) + : m_info(info), m_protected(false) { + if (m_info.type == orMemoryType::orMemoryTypeDevice) { + if (mprotect(m_info.pointer, m_info.size, PROT_READ | PROT_WRITE) == + 0) { + m_protected = true; + } + } + } + ~ScopedMemoryProtector() { + if (m_protected) { + mprotect(m_info.pointer, m_info.size, PROT_NONE); + } + } + ScopedMemoryProtector(const ScopedMemoryProtector&) = delete; + ScopedMemoryProtector& operator=(const ScopedMemoryProtector&) = delete; + + private: + orPointerAttributes m_info; + bool m_protected; +}; + +class MemoryManager { + public: + static MemoryManager& getInstance() { + static MemoryManager instance; + return instance; + } + + orError_t allocate(void** ptr, size_t size, orMemoryType type) { + if (!ptr || size == 0) + return orErrorUnknown; + + std::lock_guard lock(m_mutex); + long page_size = sysconf(_SC_PAGESIZE); + size_t aligned_size = ((size - 1) / page_size + 1) * page_size; + void* mem = nullptr; + int current_device = -1; + + if (type == orMemoryType::orMemoryTypeDevice) { + orGetDevice(¤t_device); + + mem = mmap( + nullptr, + aligned_size, + PROT_READ | PROT_WRITE, + MAP_PRIVATE | MAP_ANONYMOUS, + -1, + 0); + if (mem == MAP_FAILED) + return orErrorUnknown; + if (mprotect(mem, aligned_size, PROT_NONE) != 0) { + munmap(mem, aligned_size); + return orErrorUnknown; + } + } else { + if (posix_memalign(&mem, page_size, aligned_size) != 0) { + return orErrorUnknown; + } + } + + m_registry[mem] = {type, current_device, mem, aligned_size}; + *ptr = mem; + return orSuccess; + } + + orError_t free(void* ptr) { + if (!ptr) + return orSuccess; + + std::lock_guard lock(m_mutex); + auto it = m_registry.find(ptr); + if (it == m_registry.end()) + return orErrorUnknown; + const auto& info = it->second; + if (info.type == orMemoryType::orMemoryTypeDevice) { + mprotect(info.pointer, info.size, PROT_READ | PROT_WRITE); + munmap(info.pointer, info.size); + } else { + ::free(info.pointer); + } + m_registry.erase(it); + return orSuccess; + } + + orError_t memcpy( + void* dst, + const void* src, + size_t count, + orMemcpyKind kind) { + if (!dst || !src || count == 0) + return orErrorUnknown; + std::lock_guard lock(m_mutex); + orPointerAttributes dst_info = getPointerInfo(dst); + orPointerAttributes src_info = getPointerInfo(src); + switch (kind) { + case orMemcpyHostToDevice: + if (dst_info.type != orMemoryType::orMemoryTypeDevice || + src_info.type == orMemoryType::orMemoryTypeDevice) + return orErrorUnknown; + break; + case orMemcpyDeviceToHost: + if (dst_info.type == orMemoryType::orMemoryTypeDevice || + src_info.type != orMemoryType::orMemoryTypeDevice) + return orErrorUnknown; + break; + case orMemcpyDeviceToDevice: + if (dst_info.type != orMemoryType::orMemoryTypeDevice || + src_info.type != orMemoryType::orMemoryTypeDevice) + return orErrorUnknown; + break; + case orMemcpyHostToHost: + if (dst_info.type == orMemoryType::orMemoryTypeDevice || + src_info.type == orMemoryType::orMemoryTypeDevice) + return orErrorUnknown; + break; + } + { + ScopedMemoryProtector dst_protector(dst_info); + ScopedMemoryProtector src_protector(src_info); + ::memcpy(dst, src, count); + } + + return orSuccess; + } + + orError_t getPointerAttributes( + orPointerAttributes* attributes, + const void* ptr) { + if (!attributes || !ptr) + return orErrorUnknown; + + std ::lock_guard lock(m_mutex); + orPointerAttributes info = getPointerInfo(ptr); + + attributes->type = info.type; + if (info.type == orMemoryType::orMemoryTypeUnmanaged) { + attributes->device = -1; + attributes->pointer = const_cast(ptr); + attributes->size = 0; + } else { + attributes->device = info.device; + attributes->pointer = info.pointer; + attributes->size = info.size; + } + + return orSuccess; + } + + orError_t unprotect(void* ptr) { + std::lock_guard lock(m_mutex); + orPointerAttributes info = getPointerInfo(ptr); + if (info.type != orMemoryType::orMemoryTypeDevice) { + return orErrorUnknown; + } + if (mprotect(info.pointer, info.size, PROT_READ | PROT_WRITE) != 0) { + return orErrorUnknown; + } + return orSuccess; + } + + orError_t protect(void* ptr) { + std::lock_guard lock(m_mutex); + orPointerAttributes info = getPointerInfo(ptr); + if (info.type != orMemoryType::orMemoryTypeDevice) { + return orErrorUnknown; + } + if (mprotect(info.pointer, info.size, PROT_NONE) != 0) { + return orErrorUnknown; + } + return orSuccess; + } + + private: + MemoryManager() = default; + orPointerAttributes getPointerInfo(const void* ptr) { + auto it = m_registry.upper_bound(const_cast(ptr)); + if (it == m_registry.begin()) + return {}; + --it; + const char* p_char = static_cast(ptr); + const char* base_char = static_cast(it->first); + if (p_char >= base_char && p_char < (base_char + it->second.size)) { + return it->second; + } + return {}; + } + std::map m_registry; + std::mutex m_mutex; +}; + +} // namespace internal +} // namespace openreg + +orError_t orMalloc(void** devPtr, size_t size) { + return openreg::internal::MemoryManager::getInstance().allocate( + devPtr, size, orMemoryType::orMemoryTypeDevice); +} + +orError_t orFree(void* devPtr) { + return openreg::internal::MemoryManager::getInstance().free(devPtr); +} + +orError_t orMallocHost(void** hostPtr, size_t size) { + return openreg::internal::MemoryManager::getInstance().allocate( + hostPtr, size, orMemoryType::orMemoryTypeHost); +} + +orError_t orFreeHost(void* hostPtr) { + return openreg::internal::MemoryManager::getInstance().free(hostPtr); +} + +orError_t orMemcpy( + void* dst, + const void* src, + size_t count, + orMemcpyKind kind) { + return openreg::internal::MemoryManager::getInstance().memcpy( + dst, src, count, kind); +} + +orError_t orPointerGetAttributes( + orPointerAttributes* attributes, + const void* ptr) { + return openreg::internal::MemoryManager::getInstance().getPointerAttributes( + attributes, ptr); +} + +orError_t orMemoryUnprotect(void* devPtr) { + return openreg::internal::MemoryManager::getInstance().unprotect(devPtr); +} + +orError_t orMemoryProtect(void* devPtr) { + return openreg::internal::MemoryManager::getInstance().protect(devPtr); +} diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/include/openreg.h b/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/include/openreg.h new file mode 100644 index 0000000000000..b6b0b3da4295c --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/include/openreg.h @@ -0,0 +1,49 @@ +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +typedef enum orError_t { orSuccess = 0, orErrorUnknown = 1 } orError_t; + +typedef enum orMemcpyKind { + orMemcpyHostToHost = 0, + orMemcpyHostToDevice = 1, + orMemcpyDeviceToHost = 2, + orMemcpyDeviceToDevice = 3 +} orMemcpyKind; + +typedef enum orMemoryType { + orMemoryTypeUnmanaged = 0, + orMemoryTypeHost = 1, + orMemoryTypeDevice = 2 +} orMemoryType; + +struct orPointerAttributes { + orMemoryType type = orMemoryType::orMemoryTypeUnmanaged; + int device; + void* pointer; + size_t size; +}; + +orError_t orMalloc(void** devPtr, size_t size); +orError_t orFree(void* devPtr); +orError_t orMallocHost(void** hostPtr, size_t size); +orError_t orFreeHost(void* hostPtr); +orError_t orMemcpy(void* dst, const void* src, size_t count, orMemcpyKind kind); +orError_t orMemoryUnprotect(void* devPtr); +orError_t orMemoryProtect(void* devPtr); + +orError_t orGetDeviceCount(int* count); +orError_t orSetDevice(int device); +orError_t orGetDevice(int* device); + +orError_t orPointerGetAttributes( + orPointerAttributes* attributes, + const void* ptr); + +#ifdef __cplusplus +} // extern "C" +#endif diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/__init__.py b/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/__init__.py new file mode 100644 index 0000000000000..32bb170075efb --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/__init__.py @@ -0,0 +1,8 @@ +import torch +import torch_openreg._C # type: ignore[misc] +import torch_openreg.openreg + + +torch.utils.rename_privateuse1_backend("openreg") +torch._register_device_module("openreg", torch_openreg.openreg) +torch.utils.generate_methods_for_privateuse1_backend(for_storage=True) diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/CMakeLists.txt b/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/CMakeLists.txt new file mode 100644 index 0000000000000..574b5b1c748a3 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/CMakeLists.txt @@ -0,0 +1,12 @@ +set(LIBRARY_NAME torch_bindings) + +file(GLOB_RECURSE SOURCE_FILES + "${CMAKE_CURRENT_SOURCE_DIR}/*.cpp" +) + +add_library(${LIBRARY_NAME} SHARED ${SOURCE_FILES}) + +target_link_libraries(${LIBRARY_NAME} PRIVATE torch_python torch_openreg) +target_link_directories(${LIBRARY_NAME} PRIVATE ${PYTORCH_INSTALL_DIR}/lib) + +install(TARGETS ${LIBRARY_NAME} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}) diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/Module.cpp b/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/Module.cpp new file mode 100644 index 0000000000000..4acdbfc8e1dce --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/Module.cpp @@ -0,0 +1,99 @@ +#include + +#include +#include +#include +#include +#include + +#include + +static PyObject* _initExtension(PyObject* self, PyObject* noargs) { + HANDLE_TH_ERRORS + + at::globalContext().lazyInitDevice(c10::DeviceType::PrivateUse1); + + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +static PyObject* _getDefaultGenerator(PyObject* self, PyObject* arg) { + HANDLE_TH_ERRORS + TORCH_CHECK( + THPUtils_checkLong(arg), + "_get_default_generator expects an int, but got ", + THPUtils_typename(arg)); + auto idx = static_cast(THPUtils_unpackLong(arg)); + + return THPGenerator_initDefaultGenerator( + at::globalContext().defaultGenerator( + c10::Device(c10::DeviceType::PrivateUse1, idx))); + + END_HANDLE_TH_ERRORS +} + +PyObject* _setDevice(PyObject* self, PyObject* arg) { + HANDLE_TH_ERRORS + TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to setDevice"); + auto device = THPUtils_unpackLong(arg); + + torch::utils::device_lazy_init(at::kPrivateUse1); + c10::openreg::set_device(static_cast(device)); + + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +PyObject* _exchangeDevice(PyObject* self, PyObject* arg) { + HANDLE_TH_ERRORS + TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to exchangeDevice"); + auto device_index = THPUtils_unpackDeviceIndex(arg); + if (device_index < 0) { + return THPUtils_packInt32(-1); + } + + torch::utils::device_lazy_init(at::kPrivateUse1); + auto current_device = c10::openreg::ExchangeDevice(device_index); + + return THPUtils_packDeviceIndex(current_device); + END_HANDLE_TH_ERRORS +} + +PyObject* _getDevice(PyObject* self, PyObject* noargs) { + HANDLE_TH_ERRORS + torch::utils::device_lazy_init(at::kPrivateUse1); + auto device = static_cast(c10::openreg::current_device()); + return THPUtils_packInt32(device); + END_HANDLE_TH_ERRORS +} + +PyObject* _getDeviceCount(PyObject* self, PyObject* noargs) { + HANDLE_TH_ERRORS + return THPUtils_packUInt64(c10::openreg::device_count()); + END_HANDLE_TH_ERRORS +} + +static PyMethodDef methods[] = { + {"_init", _initExtension, METH_NOARGS, nullptr}, + {"_get_default_generator", _getDefaultGenerator, METH_O, nullptr}, + {"_get_device", _getDevice, METH_NOARGS, nullptr}, + {"_set_device", _setDevice, METH_O, nullptr}, + {"_exchangeDevice", _exchangeDevice, METH_O, nullptr}, + {"_get_device_count", _getDeviceCount, METH_NOARGS, nullptr}, + {nullptr, nullptr, 0, nullptr}}; + +/* + * When ASAN is enabled, PyTorch modifies the dlopen flag during import, + * causing all global and weak symbols in _C.so and its dependent libraries + * to be exposed to the global symbol scope, which in turn causes + * subsequent symbols with the same name in other libraries to be intercepted. + * Therefore, it cannot be named initModule here, otherwise initModule + * in torch/csrc/Module.cpp will be called, resulting in failure. + */ +extern "C" PyObject* initOpenRegModule(void) { + static struct PyModuleDef openreg_C_module = { + PyModuleDef_HEAD_INIT, "torch_openreg._C", nullptr, -1, methods}; + PyObject* mod = PyModule_Create(&openreg_C_module); + + return mod; +} diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/stub.c b/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/stub.c new file mode 100644 index 0000000000000..cd3eb4fe1ecc3 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/stub.c @@ -0,0 +1,15 @@ +#include + +extern PyObject* initOpenRegModule(void); + +#ifndef _WIN32 +#ifdef __cplusplus +extern "C" +#endif +__attribute__((visibility("default"))) PyObject* PyInit__C(void); +#endif + +PyMODINIT_FUNC PyInit__C(void) +{ + return initOpenRegModule(); +} diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/__init__.py b/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/__init__.py new file mode 100644 index 0000000000000..177468b8f41bd --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/__init__.py @@ -0,0 +1,72 @@ +import torch +import torch_openreg._C # type: ignore[misc] + + +_initialized = False + + +class device: + r"""Context-manager that changes the selected device. + + Args: + device (torch.device or int): device index to select. It's a no-op if + this argument is a negative integer or ``None``. + """ + + def __init__(self, device): + self.idx = torch.accelerator._get_device_index(device, optional=True) + self.prev_idx = -1 + + def __enter__(self): + self.prev_idx = torch_openreg._C._exchangeDevice(self.idx) + + def __exit__(self, type, value, traceback): + self.idx = torch_openreg._C._set_device(self.prev_idx) + return False + + +def is_available(): + return True + + +def device_count() -> int: + return torch_openreg._C._get_device_count() + + +def current_device(): + return torch_openreg._C._get_device() + + +def set_device(device) -> None: + return torch_openreg._C._set_device(device) + + +def is_initialized(): + return _initialized + + +def _lazy_init(): + global _initialized + if is_initialized(): + return + torch_openreg._C._init() + _initialized = True + + +from .random import * # noqa: F403 + + +__all__ = [ + "device", + "device_count", + "current_device", + "set_device", + "initial_seed", + "is_available", + "is_initialized", + "random", + "manual_seed", + "manual_seed_all", + "get_rng_state", + "set_rng_state", +] diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/random.py b/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/random.py new file mode 100644 index 0000000000000..5202145a55525 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/random.py @@ -0,0 +1,60 @@ +import torch +import torch_openreg._C # type: ignore[misc] + +from . import _lazy_init, current_device, device_count + + +__all__ = [ + "get_rng_state", + "set_rng_state", + "manual_seed", + "manual_seed_all", + "initial_seed", +] + + +def get_rng_state(device="openreg"): + if isinstance(device, str): + device = torch.device(device) + elif isinstance(device, int): + device = torch.device("openreg", device) + idx = device.index + if idx is None: + idx = current_device() + default_generator = torch_openreg._C._get_default_generator(idx) + return default_generator.get_state() + + +def set_rng_state(new_state, device="openreg"): + if isinstance(device, str): + device = torch.device(device) + elif isinstance(device, int): + device = torch.device("openreg", device) + idx = device.index + if idx is None: + idx = current_device() + default_generator = torch_openreg._C._get_default_generator(idx) + default_generator.set_state(new_state) + + +def initial_seed() -> int: + _lazy_init() + idx = current_device() + default_generator = torch_openreg._C._get_default_generator(idx) + return default_generator.initial_seed() + + +def manual_seed(seed: int) -> None: + seed = int(seed) + + idx = current_device() + default_generator = torch_openreg._C._get_default_generator(idx) + default_generator.manual_seed(seed) + + +def manual_seed_all(seed: int) -> None: + seed = int(seed) + + for idx in range(device_count()): + default_generator = torch_openreg._C._get_default_generator(idx) + default_generator.manual_seed(seed) diff --git a/test/run_test.py b/test/run_test.py index 7f810c039c7f1..26b10ac4ac61c 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -28,6 +28,7 @@ from torch.testing._internal.common_utils import ( get_report_path, IS_CI, + IS_LINUX, IS_MACOS, retry_shell, set_cwd, @@ -913,8 +914,12 @@ def _test_autoload(test_directory, options, enable=True): def run_test_with_openreg(test_module, test_directory, options): + # TODO(FFFrog): Will remove this later when windows/macos are supported. + if not IS_LINUX: + return 0 + openreg_dir = os.path.join( - test_directory, "cpp_extensions", "open_registration_extension" + test_directory, "cpp_extensions", "open_registration_extension", "torch_openreg" ) install_dir, return_code = install_cpp_extensions(openreg_dir) if return_code != 0: diff --git a/test/test_cpp_extensions_open_device_registration.py b/test/test_cpp_extensions_open_device_registration.py index dc44f66bcebc9..d6d7ad3dc4678 100644 --- a/test/test_cpp_extensions_open_device_registration.py +++ b/test/test_cpp_extensions_open_device_registration.py @@ -3,7 +3,7 @@ import os import unittest -import pytorch_openreg # noqa: F401 +import torch_openreg # noqa: F401 import torch import torch.testing._internal.common_utils as common diff --git a/test/test_openreg.py b/test/test_openreg.py index 1fab8c4261c7d..dc52231ff7bfe 100644 --- a/test/test_openreg.py +++ b/test/test_openreg.py @@ -10,7 +10,7 @@ import numpy as np import psutil -import pytorch_openreg # noqa: F401 +import torch_openreg # noqa: F401 import torch from torch.serialization import safe_globals @@ -285,7 +285,6 @@ def test_manual_seed(self): self.assertEqual(torch.openreg.initial_seed(), 2024) # type: ignore[misc] # Autograd - @unittest.skipIf(not IS_LINUX, "Only works on linux") def test_autograd_init(self): # Make sure autograd is initialized torch.ones(2, requires_grad=True, device="openreg").sum().backward() @@ -584,4 +583,5 @@ def test_open_device_dlpack(self): if __name__ == "__main__": - run_tests() + if IS_LINUX: + run_tests() diff --git a/test/test_transformers_privateuse1.py b/test/test_transformers_privateuse1.py index 728b0a1188252..0aa15260d0949 100644 --- a/test/test_transformers_privateuse1.py +++ b/test/test_transformers_privateuse1.py @@ -4,7 +4,7 @@ from collections import namedtuple from functools import partial -import pytorch_openreg # noqa: F401 +import torch_openreg # noqa: F401 import torch from torch.nn.attention import SDPBackend From 7f9fc7e67ce9853a1bb4b16c901c708f78c1c5cd Mon Sep 17 00:00:00 2001 From: Huamin Li Date: Tue, 15 Jul 2025 10:07:25 +0000 Subject: [PATCH 042/457] [Inductor] Add CPU_MAX_FIRST_DIMENSION_DECOMPOSITION and CPU_MAX_OTHER_DIMENSION_DECOMPOSITION for decompose_mm_pass (#158183) Differential Revision: D78209993 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158183 Approved by: https://github.com/houseroad --- .../fx_passes/decompose_mem_bound_mm.py | 33 ++++++++++++++++--- 1 file changed, 29 insertions(+), 4 deletions(-) diff --git a/torch/_inductor/fx_passes/decompose_mem_bound_mm.py b/torch/_inductor/fx_passes/decompose_mem_bound_mm.py index e6757c3ad9e31..30cfcdd615fbe 100644 --- a/torch/_inductor/fx_passes/decompose_mem_bound_mm.py +++ b/torch/_inductor/fx_passes/decompose_mem_bound_mm.py @@ -15,11 +15,17 @@ log = logging.getLogger(__name__) # TODO: need a better strategy for decomposing mm +# The following two constants are for CUDA device only MIN_FIRST_DIMENSION_DECOMPOSITION = 10240 MAX_OTHER_DIMENSION_DECOMPOSITION = 32 +# The following two constants are for CPU device only +CPU_MAX_FIRST_DIMENSION_DECOMPOSITION = 1 +CPU_MAX_OTHER_DIMENSION_DECOMPOSITION = 2048 min_first_dimension_decomposition = MIN_FIRST_DIMENSION_DECOMPOSITION max_other_dimension_decomposition = MAX_OTHER_DIMENSION_DECOMPOSITION +cpu_max_first_dimension_decomposition = CPU_MAX_FIRST_DIMENSION_DECOMPOSITION +cpu_max_other_dimension_decomposition = CPU_MAX_OTHER_DIMENSION_DECOMPOSITION if "decompose_mm_pass" in config.post_grad_fusion_options: min_first_dimension_decomposition = config.post_grad_fusion_options[ "decompose_mm_pass" @@ -27,6 +33,16 @@ max_other_dimension_decomposition = config.post_grad_fusion_options[ "decompose_mm_pass" ].get("max_other_dimension_decomposition", MAX_OTHER_DIMENSION_DECOMPOSITION) + cpu_max_first_dimension_decomposition = config.post_grad_fusion_options[ + "decompose_mm_pass" + ].get( + "cpu_max_first_dimension_decomposition", CPU_MAX_FIRST_DIMENSION_DECOMPOSITION + ) + cpu_max_other_dimension_decomposition = config.post_grad_fusion_options[ + "decompose_mm_pass" + ].get( + "cpu_max_other_dimension_decomposition", CPU_MAX_OTHER_DIMENSION_DECOMPOSITION + ) def check_device(a: Tensor, b: Tensor, device="cuda") -> bool: @@ -57,7 +73,10 @@ def should_decompose_bmm(mat1, mat2) -> bool: return False return True elif check_device(mat1, mat2, device="cpu"): - if mat1.shape[0] == 1 and mat2.shape[0] == 1: + if ( + mat1.shape[0] <= cpu_max_first_dimension_decomposition + and mat2.shape[0] <= cpu_max_first_dimension_decomposition + ): return True return False @@ -77,9 +96,15 @@ def should_decompose_mm(mat1, mat2) -> bool: and statically_known_true(mat2.shape[1] < max_other_dimension_decomposition) ) or ( check_device(mat1, mat2, device="cpu") - and statically_known_true(mat1.shape[0] == 1) - and statically_known_true(mat2.shape[0] <= 128) - and statically_known_true(mat2.shape[1] <= 512) + and statically_known_true( + mat1.shape[0] <= cpu_max_first_dimension_decomposition + ) + and statically_known_true( + mat2.shape[0] <= cpu_max_other_dimension_decomposition + ) + and statically_known_true( + mat2.shape[1] <= cpu_max_other_dimension_decomposition + ) ) From e241a07e6b88aa49d604803bc5a6562f0d9f94d2 Mon Sep 17 00:00:00 2001 From: "Yu, Guangye" Date: Tue, 15 Jul 2025 10:11:46 +0000 Subject: [PATCH 043/457] Refactor CUDAAllocatorConfig to reuse AcceleratorAllocatorConfig (#150312) # Motivation Refactor `CUDAAllocatorConfig` to reuse `AcceleratorAllocatorConfig` and `ConfigTokenizer`. We would deprecate those option that overleap with `AcceleratorAllocatorConfig` in the following PR and keep them only for BC. Pull Request resolved: https://github.com/pytorch/pytorch/pull/150312 Approved by: https://github.com/albanD --- c10/cuda/CUDAAllocatorConfig.cpp | 469 ++++++------------------------ c10/cuda/CUDAAllocatorConfig.h | 130 +++++---- c10/cuda/CUDACachingAllocator.cpp | 50 +--- c10/cuda/CUDACachingAllocator.h | 4 +- 4 files changed, 158 insertions(+), 495 deletions(-) diff --git a/c10/cuda/CUDAAllocatorConfig.cpp b/c10/cuda/CUDAAllocatorConfig.cpp index d2efb8c593e44..49fa2e1e95ed3 100644 --- a/c10/cuda/CUDAAllocatorConfig.cpp +++ b/c10/cuda/CUDAAllocatorConfig.cpp @@ -1,389 +1,119 @@ #include -#include -#include #if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) #include #endif -namespace c10::cuda::CUDACachingAllocator { - -constexpr size_t kRoundUpPowerOfTwoIntervals = 16; - -CUDAAllocatorConfig::CUDAAllocatorConfig() - : m_max_split_size(std::numeric_limits::max()), - m_max_non_split_rounding_size(kLargeBuffer), - m_garbage_collection_threshold(0), - m_pinned_num_register_threads(1), - m_expandable_segments(false), -#if CUDA_VERSION >= 12030 - m_expandable_segments_handle_type( - Expandable_Segments_Handle_Type::UNSPECIFIED), -#else - m_expandable_segments_handle_type( - Expandable_Segments_Handle_Type::POSIX_FD), -#endif - m_release_lock_on_cudamalloc(false), - m_pinned_use_cuda_host_register(false), - m_pinned_use_background_threads(false) { - m_roundup_power2_divisions.assign(kRoundUpPowerOfTwoIntervals, 0); -} - -size_t CUDAAllocatorConfig::roundup_power2_divisions(size_t size) { - size_t log_size = (63 - llvm::countLeadingZeros(size)); - - // Our intervals start at 1MB and end at 64GB - const size_t interval_start = - 63 - llvm::countLeadingZeros(static_cast(1048576)); - const size_t interval_end = - 63 - llvm::countLeadingZeros(static_cast(68719476736)); - TORCH_CHECK( - (interval_end - interval_start == kRoundUpPowerOfTwoIntervals), - "kRoundUpPowerOfTwoIntervals mismatch"); - - int index = static_cast(log_size) - static_cast(interval_start); - - index = std::max(0, index); - index = std::min(index, static_cast(kRoundUpPowerOfTwoIntervals) - 1); - return instance().m_roundup_power2_divisions[index]; -} - -void CUDAAllocatorConfig::lexArgs( - const std::string& env, - std::vector& config) { - std::vector buf; - - for (char ch : env) { - if (ch == ',' || ch == ':' || ch == '[' || ch == ']') { - if (!buf.empty()) { - config.emplace_back(buf.begin(), buf.end()); - buf.clear(); - } - config.emplace_back(1, ch); - } else if (ch != ' ') { - buf.emplace_back(ch); - } - } - if (!buf.empty()) { - config.emplace_back(buf.begin(), buf.end()); - } -} - -void CUDAAllocatorConfig::consumeToken( - const std::vector& config, - size_t i, - const char c) { - TORCH_CHECK( - i < config.size() && config[i] == std::string(1, c), - "Error parsing CachingAllocator settings, expected ", - c, - ""); -} - -size_t CUDAAllocatorConfig::parseMaxSplitSize( - const std::vector& config, - size_t i) { - consumeToken(config, ++i, ':'); - constexpr int mb = 1024 * 1024; - if (++i < config.size()) { - size_t val1 = stoi(config[i]); - TORCH_CHECK( - val1 > kLargeBuffer / mb, - "CachingAllocator option max_split_size_mb too small, must be > ", - kLargeBuffer / mb, - ""); - val1 = std::max(val1, kLargeBuffer / mb); - val1 = std::min(val1, (std::numeric_limits::max() / mb)); - m_max_split_size = val1 * 1024 * 1024; - } else { - TORCH_CHECK(false, "Error, expecting max_split_size_mb value", ""); - } - return i; -} - -size_t CUDAAllocatorConfig::parseMaxNonSplitRoundingSize( - const std::vector& config, - size_t i) { - consumeToken(config, ++i, ':'); - constexpr int mb = 1024 * 1024; - if (++i < config.size()) { - size_t val1 = stoi(config[i]); - TORCH_CHECK( - val1 > kLargeBuffer / mb, - "CachingAllocator option max_non_split_rounding_mb too small, must be > ", - kLargeBuffer / mb, - ""); - val1 = std::max(val1, kLargeBuffer / mb); - val1 = std::min(val1, (std::numeric_limits::max() / mb)); - m_max_non_split_rounding_size = val1 * 1024 * 1024; - } else { - TORCH_CHECK(false, "Error, expecting max_non_split_rounding_mb value", ""); - } - return i; -} - -size_t CUDAAllocatorConfig::parseGarbageCollectionThreshold( - const std::vector& config, - size_t i) { - consumeToken(config, ++i, ':'); - if (++i < config.size()) { - double val1 = stod(config[i]); - TORCH_CHECK( - val1 > 0, "garbage_collect_threshold too small, set it 0.0~1.0", ""); - TORCH_CHECK( - val1 < 1.0, "garbage_collect_threshold too big, set it 0.0~1.0", ""); - m_garbage_collection_threshold = val1; - } else { - TORCH_CHECK( - false, "Error, expecting garbage_collection_threshold value", ""); - } - return i; -} - -size_t CUDAAllocatorConfig::parseRoundUpPower2Divisions( - const std::vector& config, - size_t i) { - consumeToken(config, ++i, ':'); - bool first_value = true; - - if (++i < config.size()) { - if (std::string_view(config[i]) == "[") { - size_t last_index = 0; - // NOLINTNEXTLINE(bugprone-inc-dec-in-conditions) - while (++i < config.size() && std::string_view(config[i]) != "]") { - const std::string& val1 = config[i]; - size_t val2 = 0; - - consumeToken(config, ++i, ':'); - if (++i < config.size()) { - val2 = stoi(config[i]); - } else { - TORCH_CHECK( - false, "Error parsing roundup_power2_divisions value", ""); - } - TORCH_CHECK( - val2 == 0 || llvm::isPowerOf2_64(val2), - "For roundups, the divisions has to be power of 2 or 0 to disable roundup ", - ""); +#include - if (std::string_view(val1) == ">") { - std::fill( - std::next( - m_roundup_power2_divisions.begin(), - static_cast::difference_type>( - last_index)), - m_roundup_power2_divisions.end(), - val2); - } else { - size_t val1_long = stoul(val1); - TORCH_CHECK( - llvm::isPowerOf2_64(val1_long), - "For roundups, the intervals have to be power of 2 ", - ""); - - size_t index = 63 - llvm::countLeadingZeros(val1_long); - index = std::max((size_t)0, index); - index = std::min(index, m_roundup_power2_divisions.size() - 1); - - if (first_value) { - std::fill( - m_roundup_power2_divisions.begin(), - std::next( - m_roundup_power2_divisions.begin(), - static_cast::difference_type>( - index)), - val2); - first_value = false; - } - if (index < m_roundup_power2_divisions.size()) { - m_roundup_power2_divisions[index] = val2; - } - last_index = index; - } - - if (std::string_view(config[i + 1]) != "]") { - consumeToken(config, ++i, ','); - } - } - } else { // Keep this for backwards compatibility - size_t val1 = stoi(config[i]); - TORCH_CHECK( - llvm::isPowerOf2_64(val1), - "For roundups, the divisions has to be power of 2 ", - ""); - std::fill( - m_roundup_power2_divisions.begin(), - m_roundup_power2_divisions.end(), - val1); - } - } else { - TORCH_CHECK(false, "Error, expecting roundup_power2_divisions value", ""); - } - return i; -} +namespace c10::cuda::CUDACachingAllocator { size_t CUDAAllocatorConfig::parseAllocatorConfig( - const std::vector& config, - size_t i, - bool& used_cudaMallocAsync) { + const c10::CachingAllocator::ConfigTokenizer& tokenizer, + size_t i) { // For ease of maintenance and understanding, the CUDA and ROCm // implementations of this function are separated. This avoids having many // #ifdef's throughout. -#ifdef USE_ROCM // Ease burden on ROCm users by allowing either cuda or hip tokens. // cuda token is broken up to prevent hipify matching it. #define PYTORCH_TOKEN1 \ "cud" \ "aMallocAsync" #define PYTORCH_TOKEN2 "hipMallocAsync" - consumeToken(config, ++i, ':'); - if (++i < config.size()) { + tokenizer.checkToken(++i, ":"); + i++; // Move to the value after the colon + TORCH_CHECK( + ((tokenizer[i] == "native") || (tokenizer[i] == PYTORCH_TOKEN1) || + (tokenizer[i] == PYTORCH_TOKEN2)), + "Unknown allocator backend, " + "options are native, " PYTORCH_TOKEN1 ", and " PYTORCH_TOKEN2); + if (m_is_allocator_loaded) { + bool aync_allocator_at_runtime = (tokenizer[i] != "native"); TORCH_CHECK( - ((config[i] == "native") || (config[i] == PYTORCH_TOKEN1) || - (config[i] == PYTORCH_TOKEN2)), - "Unknown allocator backend, " - "options are native, " PYTORCH_TOKEN1 ", and " PYTORCH_TOKEN2); - used_cudaMallocAsync = - (config[i] == PYTORCH_TOKEN1 || config[i] == PYTORCH_TOKEN2); - TORCH_INTERNAL_ASSERT( - config[i] == get()->name() || - (config[i] == PYTORCH_TOKEN1 && get()->name() == PYTORCH_TOKEN2), - "Allocator backend parsed at runtime != " - "allocator backend parsed at load time, ", - config[i], + aync_allocator_at_runtime == m_use_async_allocator, + "Allocator async backend parsed at runtime != allocator async backend parsed at load time, ", + aync_allocator_at_runtime, " != ", - get()->name()); - } else { - TORCH_CHECK(false, "Error parsing backend value", ""); + m_use_async_allocator); } - return i; -#undef PYTORCH_TOKEN1 -#undef PYTORCH_TOKEN2 -#else // USE_ROCM - consumeToken(config, ++i, ':'); - if (++i < config.size()) { - TORCH_CHECK( - ((config[i] == "native") || (config[i] == "cudaMallocAsync")), - "Unknown allocator backend, " - "options are native and cudaMallocAsync"); - used_cudaMallocAsync = (config[i] == "cudaMallocAsync"); - if (used_cudaMallocAsync) { + m_use_async_allocator = + (tokenizer[i] == PYTORCH_TOKEN1 || tokenizer[i] == PYTORCH_TOKEN2); + // CUDA allocator is always loaded at the start of the program + m_is_allocator_loaded = true; + +#if defined(CUDA_VERSION) + if (m_use_async_allocator) { #if CUDA_VERSION >= 11040 - int version = 0; - C10_CUDA_CHECK(cudaDriverGetVersion(&version)); - TORCH_CHECK( - version >= 11040, - "backend:cudaMallocAsync requires CUDA runtime " - "11.4 or newer, but cudaDriverGetVersion returned ", - version); + int version = 0; + C10_CUDA_CHECK(cudaDriverGetVersion(&version)); + TORCH_CHECK( + version >= 11040, + "backend:cudaMallocAsync requires CUDA runtime " + "11.4 or newer, but cudaDriverGetVersion returned ", + version); #else - TORCH_CHECK( - false, - "backend:cudaMallocAsync requires PyTorch to be built with " - "CUDA 11.4 or newer, but CUDA_VERSION is ", - CUDA_VERSION); + TORCH_CHECK( + false, + "backend:cudaMallocAsync requires PyTorch to be built with " + "CUDA 11.4 or newer, but CUDA_VERSION is ", + CUDA_VERSION); #endif - } - TORCH_INTERNAL_ASSERT( - config[i] == get()->name(), - "Allocator backend parsed at runtime != " - "allocator backend parsed at load time"); - } else { - TORCH_CHECK(false, "Error parsing backend value", ""); } +#endif + return i; -#endif // USE_ROCM +#undef PYTORCH_TOKEN1 +#undef PYTORCH_TOKEN2 } -void CUDAAllocatorConfig::parseArgs(const std::optional& env) { +void CUDAAllocatorConfig::parseArgs(const std::string& env) { // If empty, set the default values - m_max_split_size = std::numeric_limits::max(); - m_roundup_power2_divisions.assign(kRoundUpPowerOfTwoIntervals, 0); - m_garbage_collection_threshold = 0; - bool used_cudaMallocAsync = false; bool used_native_specific_option = false; - if (!env.has_value()) { - return; - } - { - std::lock_guard lock(m_last_allocator_settings_mutex); - m_last_allocator_settings = env.value(); - } - - std::vector config; - lexArgs(env.value(), config); - - for (size_t i = 0; i < config.size(); i++) { - std::string_view config_item_view(config[i]); - if (config_item_view == "max_split_size_mb") { - i = parseMaxSplitSize(config, i); - used_native_specific_option = true; - } else if (config_item_view == "max_non_split_rounding_mb") { - i = parseMaxNonSplitRoundingSize(config, i); - used_native_specific_option = true; - } else if (config_item_view == "garbage_collection_threshold") { - i = parseGarbageCollectionThreshold(config, i); - used_native_specific_option = true; - } else if (config_item_view == "roundup_power2_divisions") { - i = parseRoundUpPower2Divisions(config, i); - used_native_specific_option = true; - } else if (config_item_view == "backend") { - i = parseAllocatorConfig(config, i, used_cudaMallocAsync); - } else if (config_item_view == "expandable_segments") { - used_native_specific_option = true; - consumeToken(config, ++i, ':'); - ++i; - TORCH_CHECK( - i < config.size() && - (std::string_view(config[i]) == "True" || - std::string_view(config[i]) == "False"), - "Expected a single True/False argument for expandable_segments"); - config_item_view = config[i]; - m_expandable_segments = (config_item_view == "True"); + c10::CachingAllocator::ConfigTokenizer tokenizer(env); + for (size_t i = 0; i < tokenizer.size(); i++) { + const auto& key = tokenizer[i]; + if (key == "backend") { + i = parseAllocatorConfig(tokenizer, i); } else if ( // ROCm build's hipify step will change "cuda" to "hip", but for ease of // use, accept both. We must break up the string to prevent hipify here. - config_item_view == "release_lock_on_hipmalloc" || - config_item_view == + key == "release_lock_on_hipmalloc" || + key == "release_lock_on_c" "udamalloc") { used_native_specific_option = true; - consumeToken(config, ++i, ':'); - ++i; - TORCH_CHECK( - i < config.size() && - (std::string_view(config[i]) == "True" || - std::string_view(config[i]) == "False"), - "Expected a single True/False argument for release_lock_on_cudamalloc"); - config_item_view = config[i]; - m_release_lock_on_cudamalloc = (config_item_view == "True"); + tokenizer.checkToken(++i, ":"); + m_release_lock_on_cudamalloc = tokenizer.toBool(++i); } else if ( // ROCm build's hipify step will change "cuda" to "hip", but for ease of // use, accept both. We must break up the string to prevent hipify here. - config_item_view == "pinned_use_hip_host_register" || - config_item_view == + key == "pinned_use_hip_host_register" || + key == "pinned_use_c" "uda_host_register") { - i = parsePinnedUseCudaHostRegister(config, i); + i = parsePinnedUseCudaHostRegister(tokenizer, i); used_native_specific_option = true; - } else if (config_item_view == "pinned_num_register_threads") { - i = parsePinnedNumRegisterThreads(config, i); - used_native_specific_option = true; - } else if (config_item_view == "pinned_use_background_threads") { - i = parsePinnedUseBackgroundThreads(config, i); + } else if (key == "pinned_num_register_threads") { + i = parsePinnedNumRegisterThreads(tokenizer, i); used_native_specific_option = true; } else { + const auto& keys = + c10::CachingAllocator::AcceleratorAllocatorConfig::getKeys(); TORCH_CHECK( - false, "Unrecognized CachingAllocator option: ", config_item_view); + keys.find(key) != keys.end(), + "Unrecognized key '", + key, + "' in Accelerator allocator config."); + i = tokenizer.skipKey(i); } - if (i + 1 < config.size()) { - consumeToken(config, ++i, ','); + if (i + 1 < tokenizer.size()) { + tokenizer.checkToken(++i, ","); } } - if (used_cudaMallocAsync && used_native_specific_option) { + if (m_use_async_allocator && used_native_specific_option) { TORCH_WARN( "backend:cudaMallocAsync ignores max_split_size_mb," "roundup_power2_divisions, and garbage_collect_threshold."); @@ -391,64 +121,33 @@ void CUDAAllocatorConfig::parseArgs(const std::optional& env) { } size_t CUDAAllocatorConfig::parsePinnedUseCudaHostRegister( - const std::vector& config, + const c10::CachingAllocator::ConfigTokenizer& tokenizer, size_t i) { - consumeToken(config, ++i, ':'); - if (++i < config.size()) { - TORCH_CHECK( - (config[i] == "True" || config[i] == "False"), - "Expected a single True/False argument for pinned_use_cuda_host_register"); - m_pinned_use_cuda_host_register = (config[i] == "True"); - } else { - TORCH_CHECK( - false, "Error, expecting pinned_use_cuda_host_register value", ""); - } - return i; -} + tokenizer.checkToken(++i, ":"); + m_pinned_use_cuda_host_register = tokenizer.toBool(++i); -size_t CUDAAllocatorConfig::parsePinnedNumRegisterThreads( - const std::vector& config, - size_t i) { - consumeToken(config, ++i, ':'); - if (++i < config.size()) { - size_t val2 = stoi(config[i]); - TORCH_CHECK( - llvm::isPowerOf2_64(val2), - "Number of register threads has to be power of 2 ", - ""); - auto maxThreads = CUDAAllocatorConfig::pinned_max_register_threads(); - TORCH_CHECK( - val2 <= maxThreads, - "Number of register threads should be less than or equal to " + - std::to_string(maxThreads), - ""); - m_pinned_num_register_threads = val2; - } else { - TORCH_CHECK( - false, "Error, expecting pinned_num_register_threads value", ""); - } return i; } -size_t CUDAAllocatorConfig::parsePinnedUseBackgroundThreads( - const std::vector& config, +size_t CUDAAllocatorConfig::parsePinnedNumRegisterThreads( + const c10::CachingAllocator::ConfigTokenizer& tokenizer, size_t i) { - consumeToken(config, ++i, ':'); - if (++i < config.size()) { - TORCH_CHECK( - (config[i] == "True" || config[i] == "False"), - "Expected a single True/False argument for pinned_use_background_threads"); - m_pinned_use_background_threads = (config[i] == "True"); - } else { - TORCH_CHECK( - false, "Error, expecting pinned_use_background_threads value", ""); - } + tokenizer.checkToken(++i, ":"); + size_t val2 = tokenizer.toSizeT(++i); + TORCH_CHECK( + llvm::isPowerOf2_64(val2), + "Number of register threads has to be power of 2 ", + ""); + auto maxThreads = CUDAAllocatorConfig::pinned_max_register_threads(); + TORCH_CHECK( + val2 <= maxThreads, + "Number of register threads should be less than or equal to " + + std::to_string(maxThreads), + ""); + m_pinned_num_register_threads = val2; return i; } -// General caching allocator utilities -void setAllocatorSettings(const std::string& env) { - CUDACachingAllocator::CUDAAllocatorConfig::instance().parseArgs(env.c_str()); -} +REGISTER_ALLOCATOR_CONFIG_PARSE_HOOK(CUDAAllocatorConfig) } // namespace c10::cuda::CUDACachingAllocator diff --git a/c10/cuda/CUDAAllocatorConfig.h b/c10/cuda/CUDAAllocatorConfig.h index fda3cc02e5d0a..f96ae5e56ba6c 100644 --- a/c10/cuda/CUDAAllocatorConfig.h +++ b/c10/cuda/CUDAAllocatorConfig.h @@ -1,16 +1,11 @@ #pragma once +#include +#include #include #include #include -#include -#include -#include -#include -#include -#include - namespace c10::cuda::CUDACachingAllocator { enum class Expandable_Segments_Handle_Type : int { @@ -23,20 +18,23 @@ enum class Expandable_Segments_Handle_Type : int { class C10_CUDA_API CUDAAllocatorConfig { public: static size_t max_split_size() { - return instance().m_max_split_size; + return c10::CachingAllocator::AcceleratorAllocatorConfig::max_split_size(); } static double garbage_collection_threshold() { - return instance().m_garbage_collection_threshold; + return c10::CachingAllocator::AcceleratorAllocatorConfig:: + garbage_collection_threshold(); } static bool expandable_segments() { + bool enabled = c10::CachingAllocator::AcceleratorAllocatorConfig:: + use_expandable_segments(); #ifndef PYTORCH_C10_DRIVER_API_SUPPORTED - if (instance().m_expandable_segments) { + if (enabled) { TORCH_WARN_ONCE("expandable_segments not supported on this platform") } return false; #else - return instance().m_expandable_segments; + return enabled; #endif } @@ -63,7 +61,8 @@ class C10_CUDA_API CUDAAllocatorConfig { } static bool pinned_use_background_threads() { - return instance().m_pinned_use_background_threads; + return c10::CachingAllocator::AcceleratorAllocatorConfig:: + pinned_use_background_threads(); } static size_t pinned_max_register_threads() { @@ -77,88 +76,97 @@ class C10_CUDA_API CUDAAllocatorConfig { // More description below in function roundup_power2_next_division // As an example, if we want 4 divisions between 2's power, this can be done // using env variable: PYTORCH_CUDA_ALLOC_CONF=roundup_power2_divisions:4 - static size_t roundup_power2_divisions(size_t size); + static size_t roundup_power2_divisions(size_t size) { + return c10::CachingAllocator::AcceleratorAllocatorConfig:: + roundup_power2_divisions(size); + } static std::vector roundup_power2_divisions() { - return instance().m_roundup_power2_divisions; + return c10::CachingAllocator::AcceleratorAllocatorConfig:: + roundup_power2_divisions(); } static size_t max_non_split_rounding_size() { - return instance().m_max_non_split_rounding_size; + return c10::CachingAllocator::AcceleratorAllocatorConfig:: + max_non_split_rounding_size(); } static std::string last_allocator_settings() { - std::lock_guard lock( - instance().m_last_allocator_settings_mutex); - return instance().m_last_allocator_settings; + return c10::CachingAllocator::getAllocatorSettings(); + } + + static bool use_async_allocator() { + return instance().m_use_async_allocator; + } + + static const std::unordered_set& getKeys() { + return instance().keys_; } static CUDAAllocatorConfig& instance() { static CUDAAllocatorConfig* s_instance = ([]() { auto inst = new CUDAAllocatorConfig(); - auto env = c10::utils::get_env("PYTORCH_CUDA_ALLOC_CONF"); + auto env = c10::utils::get_env("PYTORCH_ALLOC_CONF"); + if (!env.has_value()) { + // For backward compatibility, check for the old environment variable + // PYTORCH_CUDA_ALLOC_CONF. + env = c10::utils::get_env("PYTORCH_CUDA_ALLOC_CONF"); + } #ifdef USE_ROCM // convenience for ROCm users, allow alternative HIP token if (!env.has_value()) { env = c10::utils::get_env("PYTORCH_HIP_ALLOC_CONF"); } #endif - inst->parseArgs(env); + if (env.has_value()) { + inst->parseArgs(env.value()); + } return inst; })(); return *s_instance; } - void parseArgs(const std::optional& env); + void parseArgs(const std::string& env); private: - CUDAAllocatorConfig(); - - static void lexArgs(const std::string& env, std::vector& config); - static void consumeToken( - const std::vector& config, - size_t i, - const char c); - size_t parseMaxSplitSize(const std::vector& config, size_t i); - size_t parseMaxNonSplitRoundingSize( - const std::vector& config, - size_t i); - size_t parseGarbageCollectionThreshold( - const std::vector& config, - size_t i); - size_t parseRoundUpPower2Divisions( - const std::vector& config, - size_t i); + CUDAAllocatorConfig() = default; + size_t parseAllocatorConfig( - const std::vector& config, - size_t i, - bool& used_cudaMallocAsync); + const c10::CachingAllocator::ConfigTokenizer& tokenizer, + size_t i); size_t parsePinnedUseCudaHostRegister( - const std::vector& config, + const c10::CachingAllocator::ConfigTokenizer& tokenizer, size_t i); size_t parsePinnedNumRegisterThreads( - const std::vector& config, - size_t i); - size_t parsePinnedUseBackgroundThreads( - const std::vector& config, + const c10::CachingAllocator::ConfigTokenizer& tokenizer, size_t i); - std::atomic m_max_split_size; - std::atomic m_max_non_split_rounding_size; - std::vector m_roundup_power2_divisions; - std::atomic m_garbage_collection_threshold; - std::atomic m_pinned_num_register_threads; - std::atomic m_expandable_segments; - std::atomic - m_expandable_segments_handle_type; - std::atomic m_release_lock_on_cudamalloc; - std::atomic m_pinned_use_cuda_host_register; - std::atomic m_pinned_use_background_threads; - std::string m_last_allocator_settings; - std::mutex m_last_allocator_settings_mutex; + std::atomic m_pinned_num_register_threads{1}; + std::atomic m_expandable_segments_handle_type +#if CUDA_VERSION >= 12030 + {Expandable_Segments_Handle_Type::UNSPECIFIED}; +#else + {Expandable_Segments_Handle_Type::POSIX_FD}; +#endif + std::atomic m_release_lock_on_cudamalloc{false}; + std::atomic m_pinned_use_cuda_host_register{false}; + std::atomic m_use_async_allocator{false}; + std::atomic m_is_allocator_loaded{false}; + std::unordered_set keys_{ + "backend", + // keep BC for Rocm: `cuda` -> `cud` `a`, to avoid hipify issues + // NOLINTBEGIN(bugprone-suspicious-missing-comma,-warnings-as-errors) + "release_lock_on_cud" + "amalloc", + "pinned_use_cud" + "a_host_register", + // NOLINTEND(bugprone-suspicious-missing-comma,-warnings-as-errors) + "release_lock_on_hipmalloc", + "pinned_use_hip_host_register", + "pinned_num_register_threads"}; }; -// General caching allocator utilities -C10_CUDA_API void setAllocatorSettings(const std::string& env); +// Keep this for backwards compatibility +using c10::CachingAllocator::setAllocatorSettings; } // namespace c10::cuda::CUDACachingAllocator diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp index 4d58c11c5c9bc..ed6914c350599 100644 --- a/c10/cuda/CUDACachingAllocator.cpp +++ b/c10/cuda/CUDACachingAllocator.cpp @@ -1,7 +1,6 @@ #include #include -#include #include #include #include @@ -64,10 +63,6 @@ namespace cuda::CUDACachingAllocator { using namespace c10::CachingAllocator; using namespace c10::CachingDeviceAllocator; -// Included here as this is externally used in CUDAAllocatorConfig -const size_t kLargeBuffer = - 20971520; // "large" allocations may be packed in 20 MiB blocks - namespace Native { // @@ -4130,49 +4125,10 @@ CUDAAllocator* allocator(); } // namespace CudaMallocAsync struct BackendStaticInitializer { - // Parses env for backend at load time, duplicating some logic from - // CUDAAllocatorConfig. CUDAAllocatorConfig double-checks it later (at - // runtime). Defers verbose exceptions and error checks, including Cuda - // version checks, to CUDAAllocatorConfig's runtime doublecheck. If this - // works, maybe we should move all of CUDAAllocatorConfig here? CUDAAllocator* parseEnvForBackend() { - auto val = c10::utils::get_env("PYTORCH_CUDA_ALLOC_CONF"); -#ifdef USE_ROCM - // convenience for ROCm users to allow either CUDA or HIP env var - if (!val.has_value()) { - val = c10::utils::get_env("PYTORCH_HIP_ALLOC_CONF"); - } -#endif - if (val.has_value()) { - const std::string& config = val.value(); - - std::regex exp("[\\s,]+"); - std::sregex_token_iterator it(config.begin(), config.end(), exp, -1); - std::sregex_token_iterator end; - std::vector options(it, end); - - for (auto option : options) { - std::regex exp2("[:]+"); - std::sregex_token_iterator it2(option.begin(), option.end(), exp2, -1); - std::sregex_token_iterator end2; - std::vector kv(it2, end2); - if (kv.size() >= 2) { - if (kv[0] == "backend") { -#ifdef USE_ROCM - // convenience for ROCm users to allow either CUDA or HIP env var - if (kv[1] == - "cud" - "aMallocAsync" || - kv[1] == "hipMallocAsync") -#else - if (kv[1] == "cudaMallocAsync") -#endif - return CudaMallocAsync::allocator(); - if (kv[1] == "native") - return &Native::allocator; - } - } - } + // If the environment variable is set, we use the CudaMallocAsync allocator. + if (CUDAAllocatorConfig::use_async_allocator()) { + return CudaMallocAsync::allocator(); } return &Native::allocator; } diff --git a/c10/cuda/CUDACachingAllocator.h b/c10/cuda/CUDACachingAllocator.h index a6fa61110d675..956411fe22827 100644 --- a/c10/cuda/CUDACachingAllocator.h +++ b/c10/cuda/CUDACachingAllocator.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include #include @@ -49,10 +50,9 @@ namespace c10::cuda::CUDACachingAllocator { // Preserved only for BC reasons // NOLINTNEXTLINE(misc-unused-using-decls) +using c10::CachingAllocator::kLargeBuffer; using c10::CachingDeviceAllocator::DeviceStats; -extern const size_t kLargeBuffer; - typedef std::shared_ptr (*CreateContextFn)(); // Struct containing info of an allocation block (i.e. a fractional part of a From e40ade5182233f548b25f2732effe3719d16e9ad Mon Sep 17 00:00:00 2001 From: "Yu, Guangye" Date: Tue, 15 Jul 2025 10:11:48 +0000 Subject: [PATCH 044/457] Deprecate overleap functions in CUDAAllocatorConfig, use AcceleratorAllocatorConfig instead (#156165) Pull Request resolved: https://github.com/pytorch/pytorch/pull/156165 Approved by: https://github.com/albanD ghstack dependencies: #150312 --- aten/src/ATen/cuda/CachingHostAllocator.cpp | 2 +- c10/cuda/CUDAAllocatorConfig.h | 19 +++++++-- c10/cuda/CUDACachingAllocator.cpp | 47 +++++++++++---------- c10/xpu/XPUCachingAllocator.cpp | 3 +- torch/csrc/cuda/Module.cpp | 5 +-- 5 files changed, 44 insertions(+), 32 deletions(-) diff --git a/aten/src/ATen/cuda/CachingHostAllocator.cpp b/aten/src/ATen/cuda/CachingHostAllocator.cpp index 6a80342e10240..b5e5f84cde13f 100644 --- a/aten/src/ATen/cuda/CachingHostAllocator.cpp +++ b/aten/src/ATen/cuda/CachingHostAllocator.cpp @@ -162,7 +162,7 @@ struct CUDACachingHostAllocatorImpl } bool pinned_use_background_threads() override { - return c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig:: + return c10::CachingAllocator::AcceleratorAllocatorConfig:: pinned_use_background_threads(); } diff --git a/c10/cuda/CUDAAllocatorConfig.h b/c10/cuda/CUDAAllocatorConfig.h index f96ae5e56ba6c..6254f85cd5b86 100644 --- a/c10/cuda/CUDAAllocatorConfig.h +++ b/c10/cuda/CUDAAllocatorConfig.h @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -17,9 +18,13 @@ enum class Expandable_Segments_Handle_Type : int { // Environment config parser class C10_CUDA_API CUDAAllocatorConfig { public: + C10_DEPRECATED_MESSAGE( + "c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::max_split_size() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::max_split_size() instead.") static size_t max_split_size() { return c10::CachingAllocator::AcceleratorAllocatorConfig::max_split_size(); } + C10_DEPRECATED_MESSAGE( + "c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::garbage_collection_threshold() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::garbage_collection_threshold() instead.") static double garbage_collection_threshold() { return c10::CachingAllocator::AcceleratorAllocatorConfig:: garbage_collection_threshold(); @@ -60,6 +65,8 @@ class C10_CUDA_API CUDAAllocatorConfig { return instance().m_pinned_num_register_threads; } + C10_DEPRECATED_MESSAGE( + "c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::pinned_use_background_threads() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::pinned_use_background_threads() instead.") static bool pinned_use_background_threads() { return c10::CachingAllocator::AcceleratorAllocatorConfig:: pinned_use_background_threads(); @@ -72,25 +79,29 @@ class C10_CUDA_API CUDAAllocatorConfig { return 128; } - // This is used to round-up allocation size to nearest power of 2 divisions. - // More description below in function roundup_power2_next_division - // As an example, if we want 4 divisions between 2's power, this can be done - // using env variable: PYTORCH_CUDA_ALLOC_CONF=roundup_power2_divisions:4 + C10_DEPRECATED_MESSAGE( + "c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::roundup_power2_divisions() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::roundup_power2_divisions() instead.") static size_t roundup_power2_divisions(size_t size) { return c10::CachingAllocator::AcceleratorAllocatorConfig:: roundup_power2_divisions(size); } + C10_DEPRECATED_MESSAGE( + "c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::roundup_power2_divisions() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::roundup_power2_divisions() instead.") static std::vector roundup_power2_divisions() { return c10::CachingAllocator::AcceleratorAllocatorConfig:: roundup_power2_divisions(); } + C10_DEPRECATED_MESSAGE( + "c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::max_non_split_rounding_size() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::max_non_split_rounding_size() instead.") static size_t max_non_split_rounding_size() { return c10::CachingAllocator::AcceleratorAllocatorConfig:: max_non_split_rounding_size(); } + C10_DEPRECATED_MESSAGE( + "c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::last_allocator_settings() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::last_allocator_settings() instead.") static std::string last_allocator_settings() { return c10::CachingAllocator::getAllocatorSettings(); } diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp index ed6914c350599..5ae04bcd3f53c 100644 --- a/c10/cuda/CUDACachingAllocator.cpp +++ b/c10/cuda/CUDACachingAllocator.cpp @@ -1226,7 +1226,7 @@ class DeviceCachingAllocator { DeviceCachingAllocator() : large_blocks(/*small=*/false), small_blocks(/*small=*/true) { stats.max_split_size = - static_cast(CUDAAllocatorConfig::max_split_size()); + static_cast(AcceleratorAllocatorConfig::max_split_size()); context_recorder_.store(nullptr); } @@ -1351,7 +1351,8 @@ class DeviceCachingAllocator { // Do garbage collection if the flag is set. if (C10_UNLIKELY( set_fraction && - CUDAAllocatorConfig::garbage_collection_threshold() > 0.0)) { + AcceleratorAllocatorConfig::garbage_collection_threshold() > + 0.0)) { garbage_collect_cached_blocks(context); } // Attempt allocate @@ -1603,7 +1604,7 @@ class DeviceCachingAllocator { stats.active_bytes[stat_type].increase(block->size); stats.requested_bytes[stat_type].increase(block->requested_size); }); - if (block->size >= CUDAAllocatorConfig::max_split_size()) + if (block->size >= AcceleratorAllocatorConfig::max_split_size()) stats.oversize_allocations.increase(1); auto allocated_bytes_gauge = @@ -1654,7 +1655,7 @@ class DeviceCachingAllocator { block->pool->owner_MempoolId(), context ? context : block->context_when_allocated); - if (block->size >= CUDAAllocatorConfig::max_split_size()) + if (block->size >= AcceleratorAllocatorConfig::max_split_size()) stats.oversize_allocations.decrease(1); if (!block->stream_uses.empty()) { @@ -2204,7 +2205,8 @@ class DeviceCachingAllocator { if (size < kMinBlockSize) { return kMinBlockSize; } else { - auto divisions = CUDAAllocatorConfig::roundup_power2_divisions(size); + auto divisions = + AcceleratorAllocatorConfig::roundup_power2_divisions(size); if (divisions > 1 && size > (kMinBlockSize * divisions)) { return roundup_power2_next_division(size, divisions); } else { @@ -2694,7 +2696,7 @@ class DeviceCachingAllocator { if (block->pool->is_small || CUDAAllocatorConfig::expandable_segments()) { return remaining >= kMinBlockSize; } else { - return (size < CUDAAllocatorConfig::max_split_size()) && + return (size < AcceleratorAllocatorConfig::max_split_size()) && (remaining > kSmallSize); } } @@ -2714,7 +2716,7 @@ class DeviceCachingAllocator { if (C10_UNLIKELY( set_fraction && - CUDAAllocatorConfig::garbage_collection_threshold() > 0.0)) { + AcceleratorAllocatorConfig::garbage_collection_threshold() > 0.0)) { // Track block reuse interval only when garbage collection is enabled. ++pool.get_free_blocks_call_count; } @@ -2756,13 +2758,13 @@ class DeviceCachingAllocator { } // Do not return an oversized block for a large request - if ((p.size() < CUDAAllocatorConfig::max_split_size()) && - ((*it)->size >= CUDAAllocatorConfig::max_split_size())) + if ((p.size() < AcceleratorAllocatorConfig::max_split_size()) && + ((*it)->size >= AcceleratorAllocatorConfig::max_split_size())) return false; // Allow oversized block size to be rounded up but within a limit - if ((p.size() >= CUDAAllocatorConfig::max_split_size()) && + if ((p.size() >= AcceleratorAllocatorConfig::max_split_size()) && ((*it)->size >= - p.size() + CUDAAllocatorConfig::max_non_split_rounding_size())) + p.size() + AcceleratorAllocatorConfig::max_non_split_rounding_size())) return false; p.block = *it; pool.blocks.erase(it); @@ -2785,7 +2787,7 @@ class DeviceCachingAllocator { // therefore should be of less overheads. size_t gc_threshold = static_cast( - CUDAAllocatorConfig::garbage_collection_threshold() * + AcceleratorAllocatorConfig::garbage_collection_threshold() * static_cast(allowed_memory_maximum)); // No need to trigger GC yet if (total_allocated_memory <= gc_threshold) { @@ -2933,7 +2935,7 @@ class DeviceCachingAllocator { stats.segment[stat_type].increase(1); stats.reserved_bytes[stat_type].increase(size); }); - if (size >= CUDAAllocatorConfig::max_split_size()) + if (size >= AcceleratorAllocatorConfig::max_split_size()) stats.oversize_segments.increase(1); auto reserved_bytes_gauge = STATIC_GAUGE(pytorch.CUDACachingAllocator.reserved_bytes); @@ -2962,7 +2964,7 @@ class DeviceCachingAllocator { bool release_available_cached_blocks( const AllocParams& p, const std::shared_ptr& context) { - if (CUDAAllocatorConfig::max_split_size() == + if (AcceleratorAllocatorConfig::max_split_size() == std::numeric_limits::max()) return false; BlockPool& pool = *p.pool; @@ -2970,8 +2972,8 @@ class DeviceCachingAllocator { // because of std::unique_ptr, block cannot be trivially copied // Use constructor for search key. Block key(p.search_key.device, p.search_key.stream, p.search_key.size); - key.size = (key.size < CUDAAllocatorConfig::max_split_size()) - ? CUDAAllocatorConfig::max_split_size() + key.size = (key.size < AcceleratorAllocatorConfig::max_split_size()) + ? AcceleratorAllocatorConfig::max_split_size() : key.size; auto it = pool.blocks.lower_bound(&key); if (it == pool.blocks.end() || (*it)->stream != p.stream() || @@ -2984,7 +2986,7 @@ class DeviceCachingAllocator { --it; // Back up one item. Now on the largest block for the correct // stream while ((totalReleased < key.size) && - ((*it)->size >= CUDAAllocatorConfig::max_split_size()) && + ((*it)->size >= AcceleratorAllocatorConfig::max_split_size()) && ((*it)->stream == p.stream())) { auto cur = it; bool is_first = cur == pool.blocks.begin(); @@ -3109,7 +3111,7 @@ class DeviceCachingAllocator { stats.reserved_bytes[static_cast(StatType::AGGREGATE)] .current); - if (block->size >= CUDAAllocatorConfig::max_split_size()) + if (block->size >= AcceleratorAllocatorConfig::max_split_size()) stats.oversize_segments.decrease(1); pool->blocks.erase(block); delete block; @@ -3736,8 +3738,8 @@ class NativeCachingAllocator : public CUDAAllocator { auto& md = result.config_metadata; md.garbage_collection_threshold = - CUDAAllocatorConfig::garbage_collection_threshold(); - md.max_split_size = CUDAAllocatorConfig::max_split_size(); + AcceleratorAllocatorConfig::garbage_collection_threshold(); + md.max_split_size = AcceleratorAllocatorConfig::max_split_size(); md.pinned_num_register_threads = CUDAAllocatorConfig::pinned_num_register_threads(); md.expandable_segments = CUDAAllocatorConfig::expandable_segments(); @@ -3745,9 +3747,10 @@ class NativeCachingAllocator : public CUDAAllocator { CUDAAllocatorConfig::release_lock_on_cudamalloc(); md.pinned_use_host_register = CUDAAllocatorConfig::pinned_use_cuda_host_register(); - md.last_allocator_settings = CUDAAllocatorConfig::last_allocator_settings(); + md.last_allocator_settings = + AcceleratorAllocatorConfig::last_allocator_settings(); md.roundup_power2_divisions = - CUDAAllocatorConfig::roundup_power2_divisions(); + AcceleratorAllocatorConfig::roundup_power2_divisions(); return result; } diff --git a/c10/xpu/XPUCachingAllocator.cpp b/c10/xpu/XPUCachingAllocator.cpp index 543b48f081135..afae32d92a4b4 100644 --- a/c10/xpu/XPUCachingAllocator.cpp +++ b/c10/xpu/XPUCachingAllocator.cpp @@ -1,3 +1,4 @@ +#include #include #include #include @@ -20,8 +21,6 @@ constexpr size_t kMinBlockSize = 512; constexpr size_t kSmallSize = 1048576; // "small" allocations are packed in 2 MiB blocks constexpr size_t kSmallBuffer = 2097152; -// "large" allocations may be packed in 20 MiB blocks -constexpr size_t kLargeBuffer = 20971520; // allocations between 1 and 10 MiB may use kLargeBuffer constexpr size_t kMinLargeAlloc = 10485760; // round up large allocations to 2 MiB diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp index b44ce311ecd92..ead46337ff090 100644 --- a/torch/csrc/cuda/Module.cpp +++ b/torch/csrc/cuda/Module.cpp @@ -20,8 +20,8 @@ #include #include #include +#include #include -#include #include #include #include @@ -426,8 +426,7 @@ PyObject* THCPModule_cudaCachingAllocator_set_allocator_settings( PyObject* _unused, PyObject* env) { HANDLE_TH_ERRORS - c10::cuda::CUDACachingAllocator::setAllocatorSettings( - THPUtils_unpackString(env)); + c10::CachingAllocator::setAllocatorSettings(THPUtils_unpackString(env)); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } From 6200584193b770411b7f91880bbff6f746acfcb0 Mon Sep 17 00:00:00 2001 From: henrylhtsang Date: Tue, 15 Jul 2025 01:55:15 +0000 Subject: [PATCH 045/457] [cutlass backend][BE] remove force disable cache in tests (#158053) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158053 Approved by: https://github.com/coconutruben --- test/inductor/test_cutlass_backend.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/test/inductor/test_cutlass_backend.py b/test/inductor/test_cutlass_backend.py index c80bd13af361d..bb27b3d1a68c4 100644 --- a/test/inductor/test_cutlass_backend.py +++ b/test/inductor/test_cutlass_backend.py @@ -421,7 +421,9 @@ def forward(self, a, b, c): 2, 4, ], # guarantees > 1 choices - "force_disable_caches": True, + "fx_graph_cache": False, + "fx_graph_remote_cache": False, + "autotune_local_cache": False, } ): from torch._inductor.utils import run_and_get_code @@ -1530,7 +1532,8 @@ def mm(a, b): "max_autotune_gemm_backends": "ATEN,TRITON,CUTLASS", "cuda.cutlass_max_profiling_configs": 2, # needed for log searching - "force_disable_caches": True, + "fx_graph_cache": False, + "fx_graph_remote_cache": False, } ): with ( From 156a377f4cf9b5b5255575e26d27f745c111a6ae Mon Sep 17 00:00:00 2001 From: "Xiangyang (Mark) Guo" Date: Tue, 15 Jul 2025 10:51:43 +0000 Subject: [PATCH 046/457] [AOTI][CPP] add flag TORCHINDUCTOR_CPP_FORCE_INLINE_KERNEL (#157949) Summary: Add flag TORCHINDUCTOR_CPP_FORCE_INLINE_KERNEL to force inline the kernel function when TORCHINDUCTOR_CPP_FORCE_INLINE_KERNEL=1. It's disabled by default because force inlining may increase the build time. Differential Revision: D77915987 Pull Request resolved: https://github.com/pytorch/pytorch/pull/157949 Approved by: https://github.com/desertfire --- torch/_inductor/codegen/cpp.py | 5 ++++- torch/_inductor/config.py | 9 +++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 06467f06fc028..4b15618c12bf0 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -5282,8 +5282,11 @@ def codegen_group(self, name=None) -> str: arg_defs, _, _ = self.args.cpp_argdefs() arg_defs = ",\n".ljust(25).join(arg_defs) func_export_decl = get_export_declaration() + inline_attr = ( + "C10_ALWAYS_INLINE_ATTRIBUTE" if config.cpp.force_inline_kernel else "" + ) code.writeline( - f'extern "C" {func_export_decl} void {kernel_decl_name}({arg_defs})' + f'extern "C" {func_export_decl} void {inline_attr} {kernel_decl_name}({arg_defs})' ) # 3. Function body diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 826324b6a2044..5c7a53683db3b 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -1004,6 +1004,11 @@ def decide_compile_threads() -> int: # config specific to codegen/cpp.py class cpp: + """ + Settings for cpp backend. + This class provides a centralized location for managing cpp backend settings. + """ + # set to torch.get_num_threads() threads = -1 @@ -1119,6 +1124,10 @@ class cpp: # Use a small dequant buffer for wgt of woq int4 size as: [q_group_size, Nr] use_small_dequant_buffer = False + force_inline_kernel = ( + os.environ.get("TORCHINDUCTOR_CPP_FORCE_INLINE_KERNEL", "0") == "1" + ) + class triton: """ From 4e13eca713c60ca63c1116823b99d2461a7422ef Mon Sep 17 00:00:00 2001 From: Andrey Talman Date: Tue, 15 Jul 2025 11:52:08 +0000 Subject: [PATCH 047/457] [BE] Remove CUDA 11.8 artifacts (#158303) We are including cufile by default in all CUDA 12+ builds. Since CUDA 11.8 is removed we can safely remove this code Pull Request resolved: https://github.com/pytorch/pytorch/pull/158303 Approved by: https://github.com/Camyll, https://github.com/cyyever --- torch/__init__.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/torch/__init__.py b/torch/__init__.py index 95459337c2eda..99cb83db84b81 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -367,14 +367,8 @@ def _load_global_deps() -> None: "nccl": "libnccl.so.*[0-9]", "nvtx": "libnvToolsExt.so.*[0-9]", "nvshmem": "libnvshmem_host.so.*[0-9]", + "cufile": "libcufile.so.*[0-9]", } - # cufiile is only available on cuda 12+ - # TODO: Remove once CUDA 11.8 binaries are deprecated - if cuda_version is not None: - t_version = cuda_version.split(".") - t_major = int(t_version[0]) # type: ignore[operator] - if t_major >= 12: - cuda_libs["cufile"] = "libcufile.so.*[0-9]" is_cuda_lib_err = [ lib for lib in cuda_libs.values() if lib.split(".")[0] in err.args[0] From 90618581e971d28ac6950305d72521af05ed3a42 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleksandar=20Samard=C5=BEi=C4=87?= Date: Tue, 15 Jul 2025 07:50:26 +0000 Subject: [PATCH 048/457] Fix grouped MM output strides when compiled but not max-autotuned (#158143) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158143 Approved by: https://github.com/ngimel --- test/test_matmul_cuda.py | 17 ++++++++++++----- torch/_inductor/kernel/mm_grouped.py | 26 +++++++++++++++++++------- torch/_meta_registrations.py | 8 +++++--- 3 files changed, 36 insertions(+), 15 deletions(-) diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index 33127c689e20e..31f36681bc3a4 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -496,7 +496,8 @@ def test_grouped_gemm_3d_2d(self, strided, a_row_major, b_row_major): @parametrize("op", ["2d/2d", "2d/3d", "3d/2d", "3d/3d"]) @parametrize("a_row_major", [False, True]) @parametrize("b_row_major", [False, True]) - def test_grouped_gemm_compiled(self, op, a_row_major, b_row_major): + @parametrize("max_autotune", [False, True]) + def test_grouped_gemm_compiled(self, op, a_row_major, b_row_major, max_autotune): torch._dynamo.reset() device = "cuda" @@ -506,12 +507,18 @@ def test_grouped_gemm_compiled(self, op, a_row_major, b_row_major): align = 16 // dtype_AB.itemsize f_ref = torch._grouped_mm + + options = {} + if max_autotune: + options.update( + { + "max_autotune": True, + "max_autotune_gemm_backends": "TRITON", + } + ) f = torch.compile( f_ref, - options={ - "max_autotune": True, - "max_autotune_gemm_backends": "TRITON", - }, + options=options, ) if op == "2d/2d": diff --git a/torch/_inductor/kernel/mm_grouped.py b/torch/_inductor/kernel/mm_grouped.py index d311b62950bd2..19ca389c2a53b 100644 --- a/torch/_inductor/kernel/mm_grouped.py +++ b/torch/_inductor/kernel/mm_grouped.py @@ -22,6 +22,7 @@ get_num_sms, has_free_symbols, use_aten_gemm_kernels, + use_triton_template, ) from .mm_common import ( _is_static_problem, @@ -434,23 +435,30 @@ def grouped_mm_args( if out_dtype is None: out_dtype = mat1.get_dtype() + alignment = 16 // out_dtype.itemsize - dims = [] if m1dim == 2: if m2dim == 2: assert offs is not None - dims = [offs.get_size()[0], mat1_size[0], mat2_size[1]] + out_size = [offs.get_size()[0], mat1_size[0], mat2_size[1]] else: - dims = [mat1_size[0], mat2_size[-1]] + out_size = [mat1_size[0], mat2_size[-1]] else: if m2dim == 2: - dims = [mat1_size[1], mat2_size[1]] + out_size = [mat1_size[1], mat2_size[1]] else: - dims = [mat1_size[0], mat1_size[1], mat2_size[-1]] + out_size = [mat1_size[0], mat1_size[1], mat2_size[-1]] + size_padded = (out_size[-1] + alignment - 1) // alignment * alignment + if len(out_size) == 2: + out_stride = [size_padded, 1] + else: + out_stride = [out_size[1] * size_padded, size_padded, 1] + layout = FixedLayout( mat1.get_device(), out_dtype, - dims, + out_size, + out_stride, ) else: assert out_dtype is None, "out_dtype is ignored if layout is specified." @@ -604,7 +612,11 @@ def _tuned_grouped_mm_common( # Checking only for the equality of corresponding dims of # multiplicands here, relying on meta function checks for # everything else. - if is_nonzero and can_use_triton_kernel(mat_a, mat_b, offs, bias, scale_result): + if ( + is_nonzero + and use_triton_template(layout) + and can_use_triton_kernel(mat_a, mat_b, offs, bias, scale_result) + ): scaled = scale_a is not None if len(m1_size) == 2: if len(m2_size) == 2: diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 4d8079d9a7618..ae87e0e17fb37 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -7502,18 +7502,20 @@ def _create_grouped_mm_output_tensor(mat1, mat2, offs, out_dtype): out_size = [offs.size(0), mat1.size(0), mat2.size(1)] else: torch._check( - offs.size(0) == mat2.size(0), "matrix batch sizes have to match" + offs.size(0) == mat2.size(0), lambda: "matrix batch sizes have to match" ) out_size = [mat1.size(0), mat2.size(-1)] else: if mat2_is_2d: torch._check( - offs.size(0) == mat1.size(0), "matrix batch sizes have to match" + offs.size(0) == mat1.size(0), lambda: "matrix batch sizes have to match" ) out_size = [mat1.size(1), mat2.size(1)] else: # regular bmm - torch._check(mat1.size(0) == mat2.size(0), "batched dimension has to match") + torch._check( + mat1.size(0) == mat2.size(0), lambda: "batched dimension has to match" + ) out_size = [mat1.size(0), mat1.size(1), mat2.size(-1)] out_dtype = out_dtype or mat1.dtype From 5a54db14e3843cfa87fd8d27487dbf2f2dfb6c47 Mon Sep 17 00:00:00 2001 From: IvanKobzarev Date: Tue, 15 Jul 2025 03:57:51 -0700 Subject: [PATCH 049/457] [simple_fsdp][inductor_collectives] rewrite reorder_collectives, sink_waits_iterative (#158062) Differential Revision: [D78159013](https://our.internmc.facebook.com/intern/diff/D78159013) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158062 Approved by: https://github.com/wconstab --- .../pr_time_benchmarks/expected_results.csv | 2 +- test/distributed/test_inductor_collectives.py | 76 ++- torch/_inductor/comms.py | 475 +++++++++++------- torch/_inductor/dependencies.py | 8 +- torch/_inductor/utils.py | 3 +- 5 files changed, 357 insertions(+), 207 deletions(-) diff --git a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv index 9e5521f94b43e..7afa727a7ce48 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv +++ b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv @@ -4,7 +4,7 @@ add_loop_inductor,compile_time_instruction_count,33090000000,0.015 add_loop_inductor_dynamic_gpu,compile_time_instruction_count,42660000000,0.025 add_loop_inductor_gpu,compile_time_instruction_count,29690000000,0.015 basic_modules_ListOfLinears_eager,compile_time_instruction_count,947600000,0.015 -basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18830000000,0.015 +basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18030000000,0.015 basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17460000000,0.015 basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,11020000000,0.2 update_hint_regression,compile_time_instruction_count,1673000000,0.02 diff --git a/test/distributed/test_inductor_collectives.py b/test/distributed/test_inductor_collectives.py index fad2f8195600c..1f09d72ea2b1a 100644 --- a/test/distributed/test_inductor_collectives.py +++ b/test/distributed/test_inductor_collectives.py @@ -19,6 +19,7 @@ from torch._inductor.comms import ( _reorder_communication_preserving_peak_memory_internal, ReorderInfo, + sink_waits_iterative, ) from torch._inductor.compile_fx import compile_fx as inductor_compile_fx from torch._inductor.scheduler import BaseSchedulerNode @@ -1621,7 +1622,7 @@ def test_reorder_peak_memory_bucketed(self): comm from moving due to data dependency. """ - def func(x, w, ag_0, ag_1, *, tag, ranks, group_size): + def func(x, w, ag_0, ag_1, ag_2, ag_3, *, tag, ranks, group_size): # do some unrelated matmuls y = torch.mm(x, w) @@ -1654,14 +1655,52 @@ def func(x, w, ag_0, ag_1, *, tag, ranks, group_size): # wait op rs_0_out = torch.ops.c10d_functional.wait_tensor(rs_0_out) rs_1_out = torch.ops.c10d_functional.wait_tensor(rs_1_out) + y += torch.mm(2 * x, 2 * w) + + # cast the inputs + ag_2_cast = ag_2.to(torch.bfloat16) + ag_3_cast = ag_3.to(torch.bfloat16) + ag_2_out = torch.ops._c10d_functional.all_gather_into_tensor( + ag_2_cast, group_size, group_name + ) + ag_3_out = torch.ops._c10d_functional.all_gather_into_tensor( + ag_3_cast, group_size, group_name + ) + + # wait op + ag_2_out = torch.ops.c10d_functional.wait_tensor(ag_2_out) + ag_3_out = torch.ops.c10d_functional.wait_tensor(ag_3_out) + + # + rs_2_out = torch.ops._c10d_functional.reduce_scatter_tensor( + ag_2_cast, "sum", group_size, group_name + ) + rs_3_out = torch.ops._c10d_functional.reduce_scatter_tensor( + ag_3_cast, "sum", group_size, group_name + ) - return y, ag_0_out, ag_1_out, rs_0_out, rs_1_out + # wait op + rs_2_out = torch.ops.c10d_functional.wait_tensor(rs_2_out) + rs_3_out = torch.ops.c10d_functional.wait_tensor(rs_3_out) + return ( + y, + ag_0_out, + ag_1_out, + ag_2_out, + ag_3_out, + rs_0_out, + rs_1_out, + rs_2_out, + rs_3_out, + ) x = torch.ones(4, 384, device="cuda", dtype=torch.float32) w = torch.ones(384, 512, device="cuda", dtype=torch.float32) - ag_0 = torch.ones(384, 512, device="cuda", dtype=torch.float32) - ag_1 = torch.ones(512, device="cuda", dtype=torch.float32) - inputs = [x, w, ag_0, ag_1] + ag_0 = torch.ones(1024, 512, device="cuda", dtype=torch.float32) + ag_1 = torch.ones(512, 1024, device="cuda", dtype=torch.float32) + ag_2 = torch.ones(1024, 512, device="cuda", dtype=torch.float32) + ag_3 = torch.ones(512, 1024, device="cuda", dtype=torch.float32) + inputs = [x, w, ag_0, ag_1, ag_2, ag_3] # get stats directly from the internal helper without affecting the real pass's signature node_stats: Optional[dict[BaseSchedulerNode, ReorderInfo]] = None @@ -1679,11 +1718,15 @@ def _reorder_communication_preserving_peak_memory( with torch._inductor.config.patch( { "bucket_all_gathers_fx": "all", + "bucket_all_gathers_fx_bucket_size_determinator": lambda _: 2, "bucket_reduce_scatters_fx": "all", + "bucket_reduce_scatters_fx_bucket_size_determinator": lambda _: 2, "reorder_for_compute_comm_overlap": True, "reorder_for_compute_comm_overlap_passes": [ + sink_waits_iterative, _reorder_communication_preserving_peak_memory, ], + "allow_buffer_reuse": False, } ): compiled = torch.compile(func) @@ -1694,30 +1737,29 @@ def _reorder_communication_preserving_peak_memory( FileCheck() .check_count( "torch.ops._c10d_functional.all_gather_into_tensor_out.default(", - count=1, + count=2, exactly=True, ) + .check( + "extern_kernels.mm", + ) + .check( + "extern_kernels.addmm", + ) .run(code) ) ( FileCheck() .check_count( "torch.ops._c10d_functional.reduce_scatter_tensor.default(", - count=1, + count=2, exactly=True, ) - .run(code) - ) - ( - FileCheck() - .check( - "torch.ops._c10d_functional.all_gather_into_tensor_out.default(", - ) .check( - "torch.ops._c10d_functional.reduce_scatter_tensor.default(", + "extern_kernels.mm", ) .check( - "extern_kernels.mm", + "extern_kernels.addmm", ) .run(code) ) @@ -1726,7 +1768,7 @@ def _reorder_communication_preserving_peak_memory( assert same(out, correct), f"{out} va {correct}" assert node_stats is not None self.assertTrue(isinstance(node_stats, dict)) - self.assertEqual(len(node_stats), 2) + self.assertEqual(len(node_stats), 4) it = iter(node_stats.values()) node_stat0 = next(it) self.assertTrue(node_stat0.moves > 0) diff --git a/torch/_inductor/comms.py b/torch/_inductor/comms.py index caaf43dba5904..7f31a2fc2e1d5 100644 --- a/torch/_inductor/comms.py +++ b/torch/_inductor/comms.py @@ -4,7 +4,6 @@ import heapq import importlib -import itertools import logging import operator import sys @@ -149,9 +148,8 @@ def is_gemm_like(node: Optional[Union[IRNode, Operation]]) -> bool: return True if ( - hasattr(node, "python_kernel_name") - and node.python_kernel_name == "extern_kernels.mm" - ): + python_kernel_name := getattr(node, "python_kernel_name", None) + ) and "extern_kernels" in python_kernel_name: return True return False @@ -189,15 +187,24 @@ def _group_name(snode, with_bufs=False) -> str: def _reorder_communication_preserving_peak_memory_internal( snodes: list[BaseSchedulerNode], ) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, ReorderInfo]]: - from torch._inductor.scheduler import GroupedSchedulerNode, init_group_node - - original_snodes_num = len(snodes) """ Internal testing helper that also returns debug info. Returns: - reordered snodes list - dict {snode: ReorderInfo} """ + # Short circuit to not regress compilation time for non distributed cases. + has_collectives: bool = False + for snode in snodes: + if contains_collective(snode): + has_collectives = True + break + if not has_collectives: + return snodes, {} + + from torch._inductor.scheduler import GroupedSchedulerNode + + original_snodes_num = len(snodes) # heuristic to avoid degenerating to quadratic time graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys()) graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names()) @@ -208,7 +215,8 @@ def _reorder_communication_preserving_peak_memory_internal( snodes, name_to_freeable_input_buf, graph_outputs ) runtimes = {snode: estimate_op_runtime(snode) for snode in snodes} - snode_to_curr_memory = dict(zip(snodes, curr_memory)) + _curr_memory = dict(zip(snodes, curr_memory)) + _curr_memory[None] = 0 # type: ignore[index] # debug stats stats: dict[BaseSchedulerNode, ReorderInfo] = {} @@ -232,153 +240,151 @@ def accumulate_time(_snode): _temp_group_visit_leaves(snode, accumulate_time) return max(0, comm_time - compute_time) - MOVE_LIMIT = len(snodes) * 100 total_moves = 0 - # TODO - experiment with whether this limit is useful, setting `len(snodes)` disables it - PER_COLLECTIVE_PREFETCH_LIMIT = len(snodes) - if config.reorder_prefetch_limit is not None: - PER_COLLECTIVE_PREFETCH_LIMIT = config.reorder_prefetch_limit # Dicts to keep track of "next" and "previous" as double-linked structure during grouping - _prev: dict[BaseSchedulerNode, Optional[BaseSchedulerNode]] = {} - _next: dict[BaseSchedulerNode, Optional[BaseSchedulerNode]] = {} + _prev: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {} + _next: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {} for i, snode in enumerate(snodes): _prev[snode] = snodes[i - 1] if i > 0 else None _next[snode] = snodes[i + 1] if i < len(snodes) - 1 else None - - gsnodes: list[GroupedSchedulerNode] = [ - GroupedSchedulerNode(snode.scheduler, [snode], temp_grouping=True) - for snode in snodes - ] - for i, gsnode in enumerate(gsnodes): - snode = gsnode.snodes[0] # type: ignore[attr-defined] - if contains_collective(snode): - reorder_info = stats[snode] = ReorderInfo() + _curr_memory = dict(zip(snodes, curr_memory)) + _curr_memory[None] = 0 # type: ignore[index] + + _head = snodes[0] + + def _group_nodes(head, tail): + ret = [] + n = head + while True: + if n is not None: + ret.append(n) + if n == tail: + break + n = _next[n] + return ret + + def _group_names(head, tail): + ret = "" + for n in _group_nodes(head, tail): + if ret: + ret += "~" + ret += n.get_name() + return ret + + curr = _head + while _next[curr] is not None: + if contains_collective(curr): + reorder_info = stats[curr] = ReorderInfo() reorder_info.initial_exposed = reorder_info.final_exposed = ( - exposed_communication_time(snode, snodes[i + 1 :]) + exposed_communication_time(curr, _group_nodes(_next[curr], None)) ) - if total_moves >= MOVE_LIMIT: - reorder_info.limiting_factor = "move limit" - continue - for j in range(i - 1, -1, -1): - prev_gsnode = gsnodes[j] - if len(prev_gsnode.snodes) == 0: - continue - - if j < max(0, i - PER_COLLECTIVE_PREFETCH_LIMIT): - reorder_info.limiting_factor = "prefetch limit" - break - if contains_collective(prev_gsnode): + candidate = _prev[curr] + group_head = curr + group_tail = curr + group_peak_memory = _curr_memory[curr] + while candidate is not None: + if contains_collective(candidate): reorder_info.limiting_factor = "collective ordering" break - dep_names = OrderedSet([s.name for s in snode.unmet_dependencies]) - prev_outs = prev_gsnode.get_outputs() + group = GroupedSchedulerNode( + curr.scheduler, + _group_nodes(group_head, group_tail), + temp_grouping=True, + ) + + data_deps = {s.name: s for s in group.unmet_dependencies} + candidate_outs = candidate.get_outputs() data_dep = None - for o in prev_outs: - if o.get_name() in dep_names: - data_dep = o.get_name() + for o in candidate_outs: + if d := data_deps.get(o.get_name(), None): + if isinstance(d, WeakDep) and d.is_fake: + continue + data_dep = d break if data_dep is not None: - def is_groupable(prev_gsnode): + def is_groupable(candidate): # preserve ordering - if contains_collective(prev_gsnode): - return False - - if contains_gemm_like(prev_gsnode): - return False - return True - - if is_groupable(prev_gsnode): - new_snodes = prev_gsnode.snodes + gsnode.snodes - init_group_node(gsnode, gsnode.scheduler, new_snodes) - prev_gsnode.snodes = [] + if contains_collective(candidate): + return False, "contains_collective" + + if contains_gemm_like(candidate): + return False, "contains_gemm_like" + return True, None + + is_grp, grp_reason = is_groupable(candidate) + if is_grp: + group_head = candidate + group_peak_memory = max( + group_peak_memory, _curr_memory[candidate] + ) reorder_info.grouped += 1 - reorder_info.grouped_info = gsnode.get_name() + reorder_info.grouped_info = _group_names(group_head, group_tail) + candidate = _prev[candidate] continue else: msg = ( - f"data dependency {data_dep}(dep_names:{dep_names})" - f" prev_gsnode.outputs:{[o.get_name() for o in prev_outs]}" + f"data dependency {data_dep}(dep_names:{list(data_deps.keys())})" + f"\n candidate:{candidate.get_name()}(os:{[candidate.get_buffer_names()]})" + f"dep on {_group_names(group_head, group_tail)}" + f"\n non_group_reason:{grp_reason}" ) reorder_info.limiting_factor = msg break - if peak_memory - curr_memory[j] < curr_memory[j - 1] - curr_memory[j]: + delta_memory_candidate = ( + _curr_memory[candidate] - _curr_memory[_prev[candidate]] # type: ignore[index] + ) + + if group_peak_memory - delta_memory_candidate > peak_memory: reorder_info.limiting_factor = "peak memory" break - if reorder_info.final_exposed > runtimes[snode]: - reorder_info.limiting_factor = "sufficient overlapping" - break + reorder_info.moves += 1 total_moves += 1 - # swapping nodes j and j+1 affects curr memory at j only - # j_plus_one_alloc = curr_memory[j + 1] - curr_memory[j] - # j_alloc = curr_memory[j] - curr_memory[j - 1] - # curr_memory[j] = curr_memory[j] - j_alloc + j_plus_one_alloc - def swap_curr_memory_with_previous( - snode_j_plus_one, snode_j, snode_j_minus_one - ): - curr_memory_j_plus_one = snode_to_curr_memory[snode_j_plus_one] - curr_memory_j = snode_to_curr_memory[snode_j] - curr_memory_j_minus_one = ( - snode_to_curr_memory[snode_j_minus_one] - if snode_j_minus_one is not None - else 0 - ) - j_plus_one_alloc = curr_memory_j_plus_one - curr_memory_j - j_alloc = curr_memory_j - curr_memory_j_minus_one - snode_to_curr_memory[snode_j] = ( - curr_memory_j - j_alloc + j_plus_one_alloc - ) - - # Recompuing curr_mem for swapping grouped nodes j (group A) and j + 1 (group B) - # swap([A0, A1, A2], [B0, B1]) --> [B0, B1], [A0, A1, A2] - # decomposing to: - # swap(A2, B0) -> A0, A1, B0, A2, B1 - # swap(A2, B1) -> A0, A1, B0, B1, A2 - # swap(A1, B0) -> A0, B0, A1, B1, A2 - # swap(A1, B1) -> A0, B0, B1, A1, A2 - # swap(A0, B0) -> B0, A0, B1, A1, A2 - # swap(A0, B1) -> B0, B1, A0, A1, A2 - for _j in range(len(gsnodes[j].snodes) - 1, -1, -1): # group A - snode_j = gsnodes[j].snodes[_j] - for _i, snode_i in enumerate(gsnode.snodes): # group B - swap_curr_memory_with_previous( - snode_j_plus_one=snode_i, - snode_j=snode_j, - snode_j_minus_one=_prev[snode_j], - ) + mem_deltas = {} + for n in [candidate, *_group_nodes(group_head, group_tail)]: + mem_deltas[n] = _curr_memory[n] - _curr_memory[_prev[n]] # type: ignore[index] + # swap (candidate, group_head...group_tail) + # Before: + # candidate_prev -0-> candidate -1-> group_head...group_tail -2-> group_tail_next + # After: + # candidate_prev -0-> group_head...group_tail -1-> candidate -2-> group_tail_next + # 0 + candidate_prev = _prev[candidate] + if candidate_prev: + _next[candidate_prev] = group_head + _prev[group_head] = candidate_prev + + # 2 + group_tail_next = _next[group_tail] + if group_tail_next: + _prev[group_tail_next] = candidate + _next[candidate] = group_tail_next + + # 1 + _prev[candidate] = group_tail + _next[group_tail] = candidate + + if _head == candidate: + _head = group_head - # Update _next and _prev for swap [snode_j, snode_i] -> [snode_i, snode_j] - first = snode_j - second = snode_i - first_prev = _prev[first] - second_next = _next[second] - if first_prev: - _next[first_prev] = second - _prev[second] = first_prev - - if second_next: - _prev[second_next] = first - _next[first] = second_next - - _next[second] = first - _prev[first] = second - - tmp = gsnodes[j] - gsnodes[j] = gsnodes[j + 1] - gsnodes[j + 1] = tmp reorder_info.final_exposed = exposed_communication_time( - snode, - itertools.chain( - gsnode.snodes[1:], *[n.snodes for n in gsnodes[j + 1 :]] - ), + curr, _group_nodes(_next[curr], None) ) + # Recompute curr_memory + _prev_curr_memory = _curr_memory[_prev[group_head]] # type: ignore[index] + for n in _group_nodes(group_head, candidate): + _curr_memory[n] = _prev_curr_memory = ( + _prev_curr_memory + mem_deltas[n] + ) + candidate = _prev[group_head] + curr = _next[curr] # type: ignore[assignment] node_stats = stats improvement = {snode: node_stats[snode].improvement for snode in node_stats} @@ -426,17 +432,13 @@ def swap_curr_memory_with_previous( reorder_log_str += str(headers) + "\n" reorder_log_str += "\n".join(map(str, rows)) - grouping_logs: list[str] = [] - flatten_gsnodes: list[BaseSchedulerNode] = [] - for i, gsnode in enumerate(gsnodes): - if isinstance(gsnode, GroupedSchedulerNode) and gsnode.temp_grouping: - flatten_gsnodes.extend(gsnode.snodes) - else: - flatten_gsnodes.append(gsnode) - - grouping_log_str = "\n".join(grouping_logs) - reorder_log_str += "\n" - reorder_log_str += grouping_log_str + new_snodes = _group_nodes(_head, None) + assert len(new_snodes) == original_snodes_num + new_peak_memory, curr_memory = estimate_peak_memory( + new_snodes, name_to_freeable_input_buf, graph_outputs + ) + reorder_log_str += f"\n peak_memory_before:{peak_memory}" + reorder_log_str += f"\n peak_memory_after:{new_peak_memory}" overlap_log.info(reorder_log_str) trace_structured( @@ -448,8 +450,7 @@ def swap_curr_memory_with_previous( payload_fn=lambda: reorder_log_str, ) - assert len(flatten_gsnodes) == original_snodes_num - return flatten_gsnodes, stats + return new_snodes, stats def _schedule_for_comm( @@ -623,7 +624,9 @@ def decide_global_ordering_of_comms( # Enforce ordering by making previous comm a `WeakDep` dependency of the next comm mutating_buf = next(iter(comm_nodes[i].get_buffer_names())) for buf in comm_nodes[i - 1].get_buffer_names(): - comm_nodes[i].add_fake_dep(WeakDep(buf, mutating_buf=mutating_buf)) + comm_nodes[i].add_fake_dep( + WeakDep(buf, mutating_buf=mutating_buf, is_fake=True) + ) return nodes @@ -640,66 +643,166 @@ class SinkWaitInfo: def _sink_waits_iterative_internal( snodes: list[BaseSchedulerNode], ) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, SinkWaitInfo]]: - from torch._inductor.scheduler import GroupedSchedulerNode, init_group_node + from torch._inductor.scheduler import GroupedSchedulerNode + + original_snodes_num = len(snodes) + if original_snodes_num == 0: + return snodes, {} + graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys()) + graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names()) + name_to_freeable_input_buf: dict[str, FreeableInputBuffer] = get_freeable_input_buf( + snodes, graph_inputs + ) + peak_memory, curr_memory = estimate_peak_memory( + snodes, name_to_freeable_input_buf, graph_outputs + ) - n = len(snodes) stats: dict[BaseSchedulerNode, SinkWaitInfo] = {} - gsnodes: list[GroupedSchedulerNode] = [ - GroupedSchedulerNode(snode.scheduler, [snode], temp_grouping=True) - for snode in snodes - ] - for i in range(n - 1, -1, -1): - gsnode = gsnodes[i] - if contains_wait(gsnode): - info = stats[gsnode.snodes[0]] = SinkWaitInfo() - for j in range(i + 1, n): - wait_gsnode = gsnodes[j - 1] - wait_outs = wait_gsnode.get_outputs() - next_gsnode = gsnodes[j] - dep_names = OrderedSet([s.name for s in next_gsnode.unmet_dependencies]) + _prev: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {} + _next: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {} + _head = snodes[0] + for i, snode in enumerate(snodes): + _prev[snode] = snodes[i - 1] if i > 0 else None + _next[snode] = snodes[i + 1] if i < len(snodes) - 1 else None + _curr_memory = dict(zip(snodes, curr_memory)) + _curr_memory[None] = 0 # type: ignore[index] + + def _group_nodes(head, tail): + ret = [] + n = head + while True: + if n is not None: + ret.append(n) + if n == tail: + break + n = _next[n] + return ret + + def _group_names(head, tail): + ret = "" + for n in _group_nodes(head, tail): + if ret: + ret += "~" + ret += n.get_name() + return ret + + curr = snodes[-1] + + processed_waits = OrderedSet() # type: ignore[var-annotated] + while _prev[curr] is not None: + if contains_wait(curr) and curr not in processed_waits: + processed_waits.add(curr) + info = stats[curr] = SinkWaitInfo() + candidate = _next[curr] + wait_snode = curr + group_head = curr + group_tail = curr + group_peak_memory = _curr_memory[curr] + while candidate is not None: + group = GroupedSchedulerNode( + wait_snode.scheduler, + _group_nodes(group_head, group_tail), + temp_grouping=True, + ) + group_outs = group.get_outputs() + + data_deps = {s.name: s for s in candidate.unmet_dependencies} data_dep = None - for o in wait_outs: - if o.get_name() in dep_names: - data_dep = o.get_name() + for o in group_outs: + if d := data_deps.get(o.get_name(), None): + if isinstance(d, WeakDep) and d.is_fake: + continue + data_dep = d break # 1. If we have data_dep - we can not swap => trying to group # 2. If swap candidate and current node both contain collectives => trying to group if data_dep is not None or ( both_contain_comms := ( - contains_collective(wait_gsnode) - and contains_collective(next_gsnode) + contains_collective(group) and contains_collective(candidate) ) ): def is_groupable(snode): - return not contains_gemm_like(snode) - - if is_groupable(next_gsnode): - new_snodes = wait_gsnode.snodes + next_gsnode.snodes - init_group_node(next_gsnode, gsnode.scheduler, new_snodes) - wait_gsnode.snodes = [] + # We do not want to group with collectives to not reorder them forward. + if contains_collective(snode): + return ( + False, + f"candidate contains collective {snode.get_name()}", + ) + if contains_gemm_like(snode): + return ( + False, + f"candidate contains gemm_like {snode.get_name()}", + ) + return True, None + + is_grp, grp_reason = is_groupable(candidate) + if is_grp: + group_tail = candidate + group_peak_memory = max( + group_peak_memory, _curr_memory[candidate] + ) info.grouped += 1 - info.grouped_info = _group_name(next_gsnode) + info.grouped_info = _group_names(group_head, group_tail) + candidate = _next[candidate] continue elif (data_dep is None) and both_contain_comms: info.limiting_factor = ( - f"collective ordering {_group_name(wait_gsnode)}" - f" with candidate:{_group_name(next_gsnode)}" + f"collective ordering {_group_names(group_head, group_tail)}" + f" with candidate:{candidate.get_name()}" ) + break else: info.limiting_factor = ( - f"data dependency {data_dep}(dep_names:{dep_names})" - f" candidate:{_group_name(next_gsnode)} dep on {_group_name(wait_gsnode)}" - f" outs:{[o.get_name() for o in wait_outs]}" + f"data dependency {data_dep}(dep_names:{list(data_deps.keys())})" + f"\n candidate:{candidate.get_name()}(os:{[candidate.get_buffer_names()]})" + f"dep on {_group_names(group_head, group_tail)}" + f"\n outs:{[o.get_name() for o in group_outs]}" + f"\n non_group_reason:{grp_reason}" ) break + candidate_delta_memory = ( + _curr_memory[candidate] - _curr_memory[_prev[candidate]] # type: ignore[index] + ) + if group_peak_memory + candidate_delta_memory > peak_memory: + info.limiting_factor = "peak_memory" + break + info.moves += 1 - info.moves_info += f"+{_group_name(next_gsnode)}" + info.moves_info += f"+{candidate.get_name()}" + + # group_head_prev -0-> candidate -1-> group_head...group_tail -2-> candidate_next + mem_deltas = {} + for n in [candidate, *_group_nodes(group_head, group_tail)]: + mem_deltas[n] = _curr_memory[n] - _curr_memory[_prev[n]] # type: ignore[index] + # 0: + group_head_prev = _prev[group_head] + if group_head_prev: + _next[group_head_prev] = candidate + _prev[candidate] = group_head_prev + + # 2: + candidate_next = _next[candidate] + if candidate_next: + _prev[candidate_next] = group_tail + _next[group_tail] = candidate_next + + # 1: + _prev[group_head] = candidate + _next[candidate] = group_head + if group_head == _head: + _head = candidate + + # Recompute curr_memory + _prev_curr_memory = _curr_memory[_prev[candidate]] # type: ignore[index] + for n in _group_nodes(candidate, group_tail): + _curr_memory[n] = _prev_curr_memory = ( + _prev_curr_memory + mem_deltas[n] + ) + + candidate = _next[group_tail] + curr = _prev[curr] # type: ignore[assignment] - # Swapping snodes j and j - 1 - tmp = gsnodes[j - 1] - gsnodes[j - 1] = gsnodes[j] - gsnodes[j] = tmp headers = [ "Wait node", "grouped", @@ -732,16 +835,13 @@ def is_groupable(snode): log_str += str(headers) + "\n" log_str += "\n".join(map(str, rows)) overlap_log.info(log_str) - grouping_logs = [] - flatten_snodes = [] - for i, gsnode in enumerate(gsnodes): - grouping_logs.append(f"gsnode[{i}]:{_group_name(gsnode, with_bufs=True)}") - if isinstance(gsnode, GroupedSchedulerNode) and gsnode.temp_grouping: - flatten_snodes.extend(gsnode.snodes) - else: - flatten_snodes.append(gsnode) - grouping_log_str = "\n".join(grouping_logs) - log_str += grouping_log_str + new_snodes = _group_nodes(_head, None) + assert len(new_snodes) == original_snodes_num + new_peak_memory, curr_memory = estimate_peak_memory( + new_snodes, name_to_freeable_input_buf, graph_outputs + ) + log_str += f"\n peak_memory_before:{peak_memory}" + log_str += f"\n peak_memory_after:{new_peak_memory}" trace_structured( "artifact", metadata_fn=lambda: { @@ -750,8 +850,7 @@ def is_groupable(snode): }, payload_fn=lambda: log_str, ) - assert len(flatten_snodes) == n - return flatten_snodes, stats + return new_snodes, stats def sink_waits_iterative( @@ -777,7 +876,9 @@ def node_summary(snode): if len(snodes) == 1: detail = "" if isinstance(snode.node, (ir.ExternKernelOut, ir._CollectiveKernel)): - detail = f" ({snode.node.python_kernel_name})" + outs_str = f"outs:{[o.get_name() for o in snode.get_outputs()]}" + ins_str = f"ins:{[d.name for d in snode.unmet_dependencies]}" + detail = f" {snode.get_name()} ({snode.node.python_kernel_name})\n {outs_str}\n ({ins_str})" layouts = [child.node.get_output_spec() for child in snode.get_nodes()] out_tensor_info = ",".join( [ @@ -1352,7 +1453,7 @@ def _create_group_node(snodes_to_group): mutating_buf = next(iter(ag_group_node.get_buffer_names())) for o in prev_ag_wait.get_outputs(): ag_group_node.add_fake_dep( - WeakDep(o.get_name(), mutating_buf=mutating_buf) + WeakDep(o.get_name(), mutating_buf=mutating_buf, is_fake=True) ) prev_ag_wait = wait_group_node @@ -1364,7 +1465,7 @@ def _create_group_node(snodes_to_group): mutating_buf = next(iter(rs_group_node.get_buffer_names())) for o in prev_rs_wait.get_outputs(): rs_group_node.add_fake_dep( - WeakDep(o.get_name(), mutating_buf=mutating_buf) + WeakDep(o.get_name(), mutating_buf=mutating_buf, is_fake=True) ) prev_rs_wait = wait_group_node diff --git a/torch/_inductor/dependencies.py b/torch/_inductor/dependencies.py index 9de52061c6489..8a374f5bab35c 100644 --- a/torch/_inductor/dependencies.py +++ b/torch/_inductor/dependencies.py @@ -342,6 +342,12 @@ class WeakDep(Dep): name: str # Buffer that is doing the mutation mutating_buf: str + # WeakDep's are also used to add dependencies to prevent some specific reordering, + # E.g. collectives global ordering. + # But if other pass guarantees proper ordering by its logic, + # This additional "fake" deps will be holding optimizations. + # This flag is used to identify those additional deps. + is_fake: bool = False @property def index(self) -> sympy.Expr: @@ -352,7 +358,7 @@ def get_numel(self) -> sympy.Expr: def rename(self, renames: dict[str, str]) -> "WeakDep": if self.name in renames: - return WeakDep(renames[self.name], self.mutating_buf) + return WeakDep(renames[self.name], self.mutating_buf, self.is_fake) return self def numbytes_hint(self) -> int: diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index d22d67cecff21..a9aa28bb47508 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -2346,7 +2346,7 @@ def is_collective( from . import ir - return ( + ret = ( isinstance(node, ir._CollectiveKernel) and not isinstance(node, ir._WaitKernel) and (op is None or node.op_overload is op) @@ -2373,6 +2373,7 @@ def is_collective( ) ) ) + return ret def is_wait(node: Optional[Union[IRNode, Operation]]) -> bool: From 05d7288e316ae5c9c661c4529f9f130a46263e5b Mon Sep 17 00:00:00 2001 From: dsashidh Date: Tue, 15 Jul 2025 16:25:01 +0000 Subject: [PATCH 050/457] Fix incorrect bin edge description in histogramdd docs (#158275) Fixes #124435 This updates the torch.histogramdd documentation to correctly state that bins are inclusive of their left edges, not exclusive as currently written. There was a previous PR addressing this but it was closed due to inactivity. This picks that up and applies the fix. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158275 Approved by: https://github.com/albanD --- torch/_torch_docs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 9a96f3c097a51..0766bf7742864 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -5128,7 +5128,7 @@ def merge_dicts(*dicts): If :attr:`bins` is a sequence of N 1D tensors, it explicitly specifies the N sequences of bin edges. Each 1D tensor should contain a strictly increasing sequence with at least one element. A sequence of K bin edges defines K-1 bins, explicitly specifying -the left and right edges of all bins. Every bin is exclusive of its left edge. Only +the left and right edges of all bins. Every bin is inclusive of its left edge. Only the rightmost bin is inclusive of its right edge. If :attr:`bins` is a sequence of N ints, it specifies the number of equal-width bins From 4f36743f5eef2d9c40357eb5d8d8b1aeeacfbb2a Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 15 Jul 2025 16:31:13 +0000 Subject: [PATCH 051/457] Revert "[simple_fsdp][inductor_collectives] rewrite reorder_collectives, sink_waits_iterative (#158062)" This reverts commit 5a54db14e3843cfa87fd8d27487dbf2f2dfb6c47. Reverted https://github.com/pytorch/pytorch/pull/158062 on behalf of https://github.com/clee2000 due to sorry I want to revert something else and this is causing a merge conflict, all you should need to do is rebase and remerged ([comment](https://github.com/pytorch/pytorch/pull/158062#issuecomment-3074342140)) --- .../pr_time_benchmarks/expected_results.csv | 2 +- test/distributed/test_inductor_collectives.py | 76 +-- torch/_inductor/comms.py | 475 +++++++----------- torch/_inductor/dependencies.py | 8 +- torch/_inductor/utils.py | 3 +- 5 files changed, 207 insertions(+), 357 deletions(-) diff --git a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv index 7afa727a7ce48..9e5521f94b43e 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv +++ b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv @@ -4,7 +4,7 @@ add_loop_inductor,compile_time_instruction_count,33090000000,0.015 add_loop_inductor_dynamic_gpu,compile_time_instruction_count,42660000000,0.025 add_loop_inductor_gpu,compile_time_instruction_count,29690000000,0.015 basic_modules_ListOfLinears_eager,compile_time_instruction_count,947600000,0.015 -basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18030000000,0.015 +basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18830000000,0.015 basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17460000000,0.015 basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,11020000000,0.2 update_hint_regression,compile_time_instruction_count,1673000000,0.02 diff --git a/test/distributed/test_inductor_collectives.py b/test/distributed/test_inductor_collectives.py index 1f09d72ea2b1a..fad2f8195600c 100644 --- a/test/distributed/test_inductor_collectives.py +++ b/test/distributed/test_inductor_collectives.py @@ -19,7 +19,6 @@ from torch._inductor.comms import ( _reorder_communication_preserving_peak_memory_internal, ReorderInfo, - sink_waits_iterative, ) from torch._inductor.compile_fx import compile_fx as inductor_compile_fx from torch._inductor.scheduler import BaseSchedulerNode @@ -1622,7 +1621,7 @@ def test_reorder_peak_memory_bucketed(self): comm from moving due to data dependency. """ - def func(x, w, ag_0, ag_1, ag_2, ag_3, *, tag, ranks, group_size): + def func(x, w, ag_0, ag_1, *, tag, ranks, group_size): # do some unrelated matmuls y = torch.mm(x, w) @@ -1655,52 +1654,14 @@ def func(x, w, ag_0, ag_1, ag_2, ag_3, *, tag, ranks, group_size): # wait op rs_0_out = torch.ops.c10d_functional.wait_tensor(rs_0_out) rs_1_out = torch.ops.c10d_functional.wait_tensor(rs_1_out) - y += torch.mm(2 * x, 2 * w) - - # cast the inputs - ag_2_cast = ag_2.to(torch.bfloat16) - ag_3_cast = ag_3.to(torch.bfloat16) - ag_2_out = torch.ops._c10d_functional.all_gather_into_tensor( - ag_2_cast, group_size, group_name - ) - ag_3_out = torch.ops._c10d_functional.all_gather_into_tensor( - ag_3_cast, group_size, group_name - ) - - # wait op - ag_2_out = torch.ops.c10d_functional.wait_tensor(ag_2_out) - ag_3_out = torch.ops.c10d_functional.wait_tensor(ag_3_out) - - # - rs_2_out = torch.ops._c10d_functional.reduce_scatter_tensor( - ag_2_cast, "sum", group_size, group_name - ) - rs_3_out = torch.ops._c10d_functional.reduce_scatter_tensor( - ag_3_cast, "sum", group_size, group_name - ) - # wait op - rs_2_out = torch.ops.c10d_functional.wait_tensor(rs_2_out) - rs_3_out = torch.ops.c10d_functional.wait_tensor(rs_3_out) - return ( - y, - ag_0_out, - ag_1_out, - ag_2_out, - ag_3_out, - rs_0_out, - rs_1_out, - rs_2_out, - rs_3_out, - ) + return y, ag_0_out, ag_1_out, rs_0_out, rs_1_out x = torch.ones(4, 384, device="cuda", dtype=torch.float32) w = torch.ones(384, 512, device="cuda", dtype=torch.float32) - ag_0 = torch.ones(1024, 512, device="cuda", dtype=torch.float32) - ag_1 = torch.ones(512, 1024, device="cuda", dtype=torch.float32) - ag_2 = torch.ones(1024, 512, device="cuda", dtype=torch.float32) - ag_3 = torch.ones(512, 1024, device="cuda", dtype=torch.float32) - inputs = [x, w, ag_0, ag_1, ag_2, ag_3] + ag_0 = torch.ones(384, 512, device="cuda", dtype=torch.float32) + ag_1 = torch.ones(512, device="cuda", dtype=torch.float32) + inputs = [x, w, ag_0, ag_1] # get stats directly from the internal helper without affecting the real pass's signature node_stats: Optional[dict[BaseSchedulerNode, ReorderInfo]] = None @@ -1718,15 +1679,11 @@ def _reorder_communication_preserving_peak_memory( with torch._inductor.config.patch( { "bucket_all_gathers_fx": "all", - "bucket_all_gathers_fx_bucket_size_determinator": lambda _: 2, "bucket_reduce_scatters_fx": "all", - "bucket_reduce_scatters_fx_bucket_size_determinator": lambda _: 2, "reorder_for_compute_comm_overlap": True, "reorder_for_compute_comm_overlap_passes": [ - sink_waits_iterative, _reorder_communication_preserving_peak_memory, ], - "allow_buffer_reuse": False, } ): compiled = torch.compile(func) @@ -1737,29 +1694,30 @@ def _reorder_communication_preserving_peak_memory( FileCheck() .check_count( "torch.ops._c10d_functional.all_gather_into_tensor_out.default(", - count=2, + count=1, exactly=True, ) - .check( - "extern_kernels.mm", - ) - .check( - "extern_kernels.addmm", - ) .run(code) ) ( FileCheck() .check_count( "torch.ops._c10d_functional.reduce_scatter_tensor.default(", - count=2, + count=1, exactly=True, ) + .run(code) + ) + ( + FileCheck() + .check( + "torch.ops._c10d_functional.all_gather_into_tensor_out.default(", + ) .check( - "extern_kernels.mm", + "torch.ops._c10d_functional.reduce_scatter_tensor.default(", ) .check( - "extern_kernels.addmm", + "extern_kernels.mm", ) .run(code) ) @@ -1768,7 +1726,7 @@ def _reorder_communication_preserving_peak_memory( assert same(out, correct), f"{out} va {correct}" assert node_stats is not None self.assertTrue(isinstance(node_stats, dict)) - self.assertEqual(len(node_stats), 4) + self.assertEqual(len(node_stats), 2) it = iter(node_stats.values()) node_stat0 = next(it) self.assertTrue(node_stat0.moves > 0) diff --git a/torch/_inductor/comms.py b/torch/_inductor/comms.py index 7f31a2fc2e1d5..caaf43dba5904 100644 --- a/torch/_inductor/comms.py +++ b/torch/_inductor/comms.py @@ -4,6 +4,7 @@ import heapq import importlib +import itertools import logging import operator import sys @@ -148,8 +149,9 @@ def is_gemm_like(node: Optional[Union[IRNode, Operation]]) -> bool: return True if ( - python_kernel_name := getattr(node, "python_kernel_name", None) - ) and "extern_kernels" in python_kernel_name: + hasattr(node, "python_kernel_name") + and node.python_kernel_name == "extern_kernels.mm" + ): return True return False @@ -187,24 +189,15 @@ def _group_name(snode, with_bufs=False) -> str: def _reorder_communication_preserving_peak_memory_internal( snodes: list[BaseSchedulerNode], ) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, ReorderInfo]]: + from torch._inductor.scheduler import GroupedSchedulerNode, init_group_node + + original_snodes_num = len(snodes) """ Internal testing helper that also returns debug info. Returns: - reordered snodes list - dict {snode: ReorderInfo} """ - # Short circuit to not regress compilation time for non distributed cases. - has_collectives: bool = False - for snode in snodes: - if contains_collective(snode): - has_collectives = True - break - if not has_collectives: - return snodes, {} - - from torch._inductor.scheduler import GroupedSchedulerNode - - original_snodes_num = len(snodes) # heuristic to avoid degenerating to quadratic time graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys()) graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names()) @@ -215,8 +208,7 @@ def _reorder_communication_preserving_peak_memory_internal( snodes, name_to_freeable_input_buf, graph_outputs ) runtimes = {snode: estimate_op_runtime(snode) for snode in snodes} - _curr_memory = dict(zip(snodes, curr_memory)) - _curr_memory[None] = 0 # type: ignore[index] + snode_to_curr_memory = dict(zip(snodes, curr_memory)) # debug stats stats: dict[BaseSchedulerNode, ReorderInfo] = {} @@ -240,151 +232,153 @@ def accumulate_time(_snode): _temp_group_visit_leaves(snode, accumulate_time) return max(0, comm_time - compute_time) + MOVE_LIMIT = len(snodes) * 100 total_moves = 0 + # TODO - experiment with whether this limit is useful, setting `len(snodes)` disables it + PER_COLLECTIVE_PREFETCH_LIMIT = len(snodes) + if config.reorder_prefetch_limit is not None: + PER_COLLECTIVE_PREFETCH_LIMIT = config.reorder_prefetch_limit # Dicts to keep track of "next" and "previous" as double-linked structure during grouping - _prev: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {} - _next: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {} + _prev: dict[BaseSchedulerNode, Optional[BaseSchedulerNode]] = {} + _next: dict[BaseSchedulerNode, Optional[BaseSchedulerNode]] = {} for i, snode in enumerate(snodes): _prev[snode] = snodes[i - 1] if i > 0 else None _next[snode] = snodes[i + 1] if i < len(snodes) - 1 else None - _curr_memory = dict(zip(snodes, curr_memory)) - _curr_memory[None] = 0 # type: ignore[index] - - _head = snodes[0] - - def _group_nodes(head, tail): - ret = [] - n = head - while True: - if n is not None: - ret.append(n) - if n == tail: - break - n = _next[n] - return ret - - def _group_names(head, tail): - ret = "" - for n in _group_nodes(head, tail): - if ret: - ret += "~" - ret += n.get_name() - return ret - - curr = _head - while _next[curr] is not None: - if contains_collective(curr): - reorder_info = stats[curr] = ReorderInfo() + + gsnodes: list[GroupedSchedulerNode] = [ + GroupedSchedulerNode(snode.scheduler, [snode], temp_grouping=True) + for snode in snodes + ] + for i, gsnode in enumerate(gsnodes): + snode = gsnode.snodes[0] # type: ignore[attr-defined] + if contains_collective(snode): + reorder_info = stats[snode] = ReorderInfo() reorder_info.initial_exposed = reorder_info.final_exposed = ( - exposed_communication_time(curr, _group_nodes(_next[curr], None)) + exposed_communication_time(snode, snodes[i + 1 :]) ) + if total_moves >= MOVE_LIMIT: + reorder_info.limiting_factor = "move limit" + continue - candidate = _prev[curr] - group_head = curr - group_tail = curr - group_peak_memory = _curr_memory[curr] - while candidate is not None: - if contains_collective(candidate): + for j in range(i - 1, -1, -1): + prev_gsnode = gsnodes[j] + if len(prev_gsnode.snodes) == 0: + continue + + if j < max(0, i - PER_COLLECTIVE_PREFETCH_LIMIT): + reorder_info.limiting_factor = "prefetch limit" + break + if contains_collective(prev_gsnode): reorder_info.limiting_factor = "collective ordering" break - group = GroupedSchedulerNode( - curr.scheduler, - _group_nodes(group_head, group_tail), - temp_grouping=True, - ) - - data_deps = {s.name: s for s in group.unmet_dependencies} - candidate_outs = candidate.get_outputs() + dep_names = OrderedSet([s.name for s in snode.unmet_dependencies]) + prev_outs = prev_gsnode.get_outputs() data_dep = None - for o in candidate_outs: - if d := data_deps.get(o.get_name(), None): - if isinstance(d, WeakDep) and d.is_fake: - continue - data_dep = d + for o in prev_outs: + if o.get_name() in dep_names: + data_dep = o.get_name() break if data_dep is not None: - def is_groupable(candidate): + def is_groupable(prev_gsnode): # preserve ordering - if contains_collective(candidate): - return False, "contains_collective" - - if contains_gemm_like(candidate): - return False, "contains_gemm_like" - return True, None - - is_grp, grp_reason = is_groupable(candidate) - if is_grp: - group_head = candidate - group_peak_memory = max( - group_peak_memory, _curr_memory[candidate] - ) + if contains_collective(prev_gsnode): + return False + + if contains_gemm_like(prev_gsnode): + return False + return True + + if is_groupable(prev_gsnode): + new_snodes = prev_gsnode.snodes + gsnode.snodes + init_group_node(gsnode, gsnode.scheduler, new_snodes) + prev_gsnode.snodes = [] reorder_info.grouped += 1 - reorder_info.grouped_info = _group_names(group_head, group_tail) - candidate = _prev[candidate] + reorder_info.grouped_info = gsnode.get_name() continue else: msg = ( - f"data dependency {data_dep}(dep_names:{list(data_deps.keys())})" - f"\n candidate:{candidate.get_name()}(os:{[candidate.get_buffer_names()]})" - f"dep on {_group_names(group_head, group_tail)}" - f"\n non_group_reason:{grp_reason}" + f"data dependency {data_dep}(dep_names:{dep_names})" + f" prev_gsnode.outputs:{[o.get_name() for o in prev_outs]}" ) reorder_info.limiting_factor = msg break - delta_memory_candidate = ( - _curr_memory[candidate] - _curr_memory[_prev[candidate]] # type: ignore[index] - ) - - if group_peak_memory - delta_memory_candidate > peak_memory: + if peak_memory - curr_memory[j] < curr_memory[j - 1] - curr_memory[j]: reorder_info.limiting_factor = "peak memory" break - + if reorder_info.final_exposed > runtimes[snode]: + reorder_info.limiting_factor = "sufficient overlapping" + break reorder_info.moves += 1 total_moves += 1 - mem_deltas = {} - for n in [candidate, *_group_nodes(group_head, group_tail)]: - mem_deltas[n] = _curr_memory[n] - _curr_memory[_prev[n]] # type: ignore[index] - # swap (candidate, group_head...group_tail) - # Before: - # candidate_prev -0-> candidate -1-> group_head...group_tail -2-> group_tail_next - # After: - # candidate_prev -0-> group_head...group_tail -1-> candidate -2-> group_tail_next - # 0 - candidate_prev = _prev[candidate] - if candidate_prev: - _next[candidate_prev] = group_head - _prev[group_head] = candidate_prev - - # 2 - group_tail_next = _next[group_tail] - if group_tail_next: - _prev[group_tail_next] = candidate - _next[candidate] = group_tail_next - - # 1 - _prev[candidate] = group_tail - _next[group_tail] = candidate - - if _head == candidate: - _head = group_head + # swapping nodes j and j+1 affects curr memory at j only + # j_plus_one_alloc = curr_memory[j + 1] - curr_memory[j] + # j_alloc = curr_memory[j] - curr_memory[j - 1] + # curr_memory[j] = curr_memory[j] - j_alloc + j_plus_one_alloc + def swap_curr_memory_with_previous( + snode_j_plus_one, snode_j, snode_j_minus_one + ): + curr_memory_j_plus_one = snode_to_curr_memory[snode_j_plus_one] + curr_memory_j = snode_to_curr_memory[snode_j] + curr_memory_j_minus_one = ( + snode_to_curr_memory[snode_j_minus_one] + if snode_j_minus_one is not None + else 0 + ) + j_plus_one_alloc = curr_memory_j_plus_one - curr_memory_j + j_alloc = curr_memory_j - curr_memory_j_minus_one + snode_to_curr_memory[snode_j] = ( + curr_memory_j - j_alloc + j_plus_one_alloc + ) + + # Recompuing curr_mem for swapping grouped nodes j (group A) and j + 1 (group B) + # swap([A0, A1, A2], [B0, B1]) --> [B0, B1], [A0, A1, A2] + # decomposing to: + # swap(A2, B0) -> A0, A1, B0, A2, B1 + # swap(A2, B1) -> A0, A1, B0, B1, A2 + # swap(A1, B0) -> A0, B0, A1, B1, A2 + # swap(A1, B1) -> A0, B0, B1, A1, A2 + # swap(A0, B0) -> B0, A0, B1, A1, A2 + # swap(A0, B1) -> B0, B1, A0, A1, A2 + for _j in range(len(gsnodes[j].snodes) - 1, -1, -1): # group A + snode_j = gsnodes[j].snodes[_j] + for _i, snode_i in enumerate(gsnode.snodes): # group B + swap_curr_memory_with_previous( + snode_j_plus_one=snode_i, + snode_j=snode_j, + snode_j_minus_one=_prev[snode_j], + ) + # Update _next and _prev for swap [snode_j, snode_i] -> [snode_i, snode_j] + first = snode_j + second = snode_i + first_prev = _prev[first] + second_next = _next[second] + if first_prev: + _next[first_prev] = second + _prev[second] = first_prev + + if second_next: + _prev[second_next] = first + _next[first] = second_next + + _next[second] = first + _prev[first] = second + + tmp = gsnodes[j] + gsnodes[j] = gsnodes[j + 1] + gsnodes[j + 1] = tmp reorder_info.final_exposed = exposed_communication_time( - curr, _group_nodes(_next[curr], None) + snode, + itertools.chain( + gsnode.snodes[1:], *[n.snodes for n in gsnodes[j + 1 :]] + ), ) - # Recompute curr_memory - _prev_curr_memory = _curr_memory[_prev[group_head]] # type: ignore[index] - for n in _group_nodes(group_head, candidate): - _curr_memory[n] = _prev_curr_memory = ( - _prev_curr_memory + mem_deltas[n] - ) - candidate = _prev[group_head] - curr = _next[curr] # type: ignore[assignment] node_stats = stats improvement = {snode: node_stats[snode].improvement for snode in node_stats} @@ -432,13 +426,17 @@ def is_groupable(candidate): reorder_log_str += str(headers) + "\n" reorder_log_str += "\n".join(map(str, rows)) - new_snodes = _group_nodes(_head, None) - assert len(new_snodes) == original_snodes_num - new_peak_memory, curr_memory = estimate_peak_memory( - new_snodes, name_to_freeable_input_buf, graph_outputs - ) - reorder_log_str += f"\n peak_memory_before:{peak_memory}" - reorder_log_str += f"\n peak_memory_after:{new_peak_memory}" + grouping_logs: list[str] = [] + flatten_gsnodes: list[BaseSchedulerNode] = [] + for i, gsnode in enumerate(gsnodes): + if isinstance(gsnode, GroupedSchedulerNode) and gsnode.temp_grouping: + flatten_gsnodes.extend(gsnode.snodes) + else: + flatten_gsnodes.append(gsnode) + + grouping_log_str = "\n".join(grouping_logs) + reorder_log_str += "\n" + reorder_log_str += grouping_log_str overlap_log.info(reorder_log_str) trace_structured( @@ -450,7 +448,8 @@ def is_groupable(candidate): payload_fn=lambda: reorder_log_str, ) - return new_snodes, stats + assert len(flatten_gsnodes) == original_snodes_num + return flatten_gsnodes, stats def _schedule_for_comm( @@ -624,9 +623,7 @@ def decide_global_ordering_of_comms( # Enforce ordering by making previous comm a `WeakDep` dependency of the next comm mutating_buf = next(iter(comm_nodes[i].get_buffer_names())) for buf in comm_nodes[i - 1].get_buffer_names(): - comm_nodes[i].add_fake_dep( - WeakDep(buf, mutating_buf=mutating_buf, is_fake=True) - ) + comm_nodes[i].add_fake_dep(WeakDep(buf, mutating_buf=mutating_buf)) return nodes @@ -643,166 +640,66 @@ class SinkWaitInfo: def _sink_waits_iterative_internal( snodes: list[BaseSchedulerNode], ) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, SinkWaitInfo]]: - from torch._inductor.scheduler import GroupedSchedulerNode - - original_snodes_num = len(snodes) - if original_snodes_num == 0: - return snodes, {} - graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys()) - graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names()) - name_to_freeable_input_buf: dict[str, FreeableInputBuffer] = get_freeable_input_buf( - snodes, graph_inputs - ) - peak_memory, curr_memory = estimate_peak_memory( - snodes, name_to_freeable_input_buf, graph_outputs - ) + from torch._inductor.scheduler import GroupedSchedulerNode, init_group_node + n = len(snodes) stats: dict[BaseSchedulerNode, SinkWaitInfo] = {} - _prev: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {} - _next: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {} - _head = snodes[0] - for i, snode in enumerate(snodes): - _prev[snode] = snodes[i - 1] if i > 0 else None - _next[snode] = snodes[i + 1] if i < len(snodes) - 1 else None - _curr_memory = dict(zip(snodes, curr_memory)) - _curr_memory[None] = 0 # type: ignore[index] - - def _group_nodes(head, tail): - ret = [] - n = head - while True: - if n is not None: - ret.append(n) - if n == tail: - break - n = _next[n] - return ret - - def _group_names(head, tail): - ret = "" - for n in _group_nodes(head, tail): - if ret: - ret += "~" - ret += n.get_name() - return ret - - curr = snodes[-1] - - processed_waits = OrderedSet() # type: ignore[var-annotated] - while _prev[curr] is not None: - if contains_wait(curr) and curr not in processed_waits: - processed_waits.add(curr) - info = stats[curr] = SinkWaitInfo() - candidate = _next[curr] - wait_snode = curr - group_head = curr - group_tail = curr - group_peak_memory = _curr_memory[curr] - while candidate is not None: - group = GroupedSchedulerNode( - wait_snode.scheduler, - _group_nodes(group_head, group_tail), - temp_grouping=True, - ) - group_outs = group.get_outputs() - - data_deps = {s.name: s for s in candidate.unmet_dependencies} + gsnodes: list[GroupedSchedulerNode] = [ + GroupedSchedulerNode(snode.scheduler, [snode], temp_grouping=True) + for snode in snodes + ] + for i in range(n - 1, -1, -1): + gsnode = gsnodes[i] + if contains_wait(gsnode): + info = stats[gsnode.snodes[0]] = SinkWaitInfo() + for j in range(i + 1, n): + wait_gsnode = gsnodes[j - 1] + wait_outs = wait_gsnode.get_outputs() + next_gsnode = gsnodes[j] + dep_names = OrderedSet([s.name for s in next_gsnode.unmet_dependencies]) data_dep = None - for o in group_outs: - if d := data_deps.get(o.get_name(), None): - if isinstance(d, WeakDep) and d.is_fake: - continue - data_dep = d + for o in wait_outs: + if o.get_name() in dep_names: + data_dep = o.get_name() break # 1. If we have data_dep - we can not swap => trying to group # 2. If swap candidate and current node both contain collectives => trying to group if data_dep is not None or ( both_contain_comms := ( - contains_collective(group) and contains_collective(candidate) + contains_collective(wait_gsnode) + and contains_collective(next_gsnode) ) ): def is_groupable(snode): - # We do not want to group with collectives to not reorder them forward. - if contains_collective(snode): - return ( - False, - f"candidate contains collective {snode.get_name()}", - ) - if contains_gemm_like(snode): - return ( - False, - f"candidate contains gemm_like {snode.get_name()}", - ) - return True, None - - is_grp, grp_reason = is_groupable(candidate) - if is_grp: - group_tail = candidate - group_peak_memory = max( - group_peak_memory, _curr_memory[candidate] - ) + return not contains_gemm_like(snode) + + if is_groupable(next_gsnode): + new_snodes = wait_gsnode.snodes + next_gsnode.snodes + init_group_node(next_gsnode, gsnode.scheduler, new_snodes) + wait_gsnode.snodes = [] info.grouped += 1 - info.grouped_info = _group_names(group_head, group_tail) - candidate = _next[candidate] + info.grouped_info = _group_name(next_gsnode) continue elif (data_dep is None) and both_contain_comms: info.limiting_factor = ( - f"collective ordering {_group_names(group_head, group_tail)}" - f" with candidate:{candidate.get_name()}" + f"collective ordering {_group_name(wait_gsnode)}" + f" with candidate:{_group_name(next_gsnode)}" ) - break else: info.limiting_factor = ( - f"data dependency {data_dep}(dep_names:{list(data_deps.keys())})" - f"\n candidate:{candidate.get_name()}(os:{[candidate.get_buffer_names()]})" - f"dep on {_group_names(group_head, group_tail)}" - f"\n outs:{[o.get_name() for o in group_outs]}" - f"\n non_group_reason:{grp_reason}" + f"data dependency {data_dep}(dep_names:{dep_names})" + f" candidate:{_group_name(next_gsnode)} dep on {_group_name(wait_gsnode)}" + f" outs:{[o.get_name() for o in wait_outs]}" ) break - candidate_delta_memory = ( - _curr_memory[candidate] - _curr_memory[_prev[candidate]] # type: ignore[index] - ) - if group_peak_memory + candidate_delta_memory > peak_memory: - info.limiting_factor = "peak_memory" - break - info.moves += 1 - info.moves_info += f"+{candidate.get_name()}" - - # group_head_prev -0-> candidate -1-> group_head...group_tail -2-> candidate_next - mem_deltas = {} - for n in [candidate, *_group_nodes(group_head, group_tail)]: - mem_deltas[n] = _curr_memory[n] - _curr_memory[_prev[n]] # type: ignore[index] - # 0: - group_head_prev = _prev[group_head] - if group_head_prev: - _next[group_head_prev] = candidate - _prev[candidate] = group_head_prev - - # 2: - candidate_next = _next[candidate] - if candidate_next: - _prev[candidate_next] = group_tail - _next[group_tail] = candidate_next - - # 1: - _prev[group_head] = candidate - _next[candidate] = group_head - if group_head == _head: - _head = candidate - - # Recompute curr_memory - _prev_curr_memory = _curr_memory[_prev[candidate]] # type: ignore[index] - for n in _group_nodes(candidate, group_tail): - _curr_memory[n] = _prev_curr_memory = ( - _prev_curr_memory + mem_deltas[n] - ) - - candidate = _next[group_tail] - curr = _prev[curr] # type: ignore[assignment] + info.moves_info += f"+{_group_name(next_gsnode)}" + # Swapping snodes j and j - 1 + tmp = gsnodes[j - 1] + gsnodes[j - 1] = gsnodes[j] + gsnodes[j] = tmp headers = [ "Wait node", "grouped", @@ -835,13 +732,16 @@ def is_groupable(snode): log_str += str(headers) + "\n" log_str += "\n".join(map(str, rows)) overlap_log.info(log_str) - new_snodes = _group_nodes(_head, None) - assert len(new_snodes) == original_snodes_num - new_peak_memory, curr_memory = estimate_peak_memory( - new_snodes, name_to_freeable_input_buf, graph_outputs - ) - log_str += f"\n peak_memory_before:{peak_memory}" - log_str += f"\n peak_memory_after:{new_peak_memory}" + grouping_logs = [] + flatten_snodes = [] + for i, gsnode in enumerate(gsnodes): + grouping_logs.append(f"gsnode[{i}]:{_group_name(gsnode, with_bufs=True)}") + if isinstance(gsnode, GroupedSchedulerNode) and gsnode.temp_grouping: + flatten_snodes.extend(gsnode.snodes) + else: + flatten_snodes.append(gsnode) + grouping_log_str = "\n".join(grouping_logs) + log_str += grouping_log_str trace_structured( "artifact", metadata_fn=lambda: { @@ -850,7 +750,8 @@ def is_groupable(snode): }, payload_fn=lambda: log_str, ) - return new_snodes, stats + assert len(flatten_snodes) == n + return flatten_snodes, stats def sink_waits_iterative( @@ -876,9 +777,7 @@ def node_summary(snode): if len(snodes) == 1: detail = "" if isinstance(snode.node, (ir.ExternKernelOut, ir._CollectiveKernel)): - outs_str = f"outs:{[o.get_name() for o in snode.get_outputs()]}" - ins_str = f"ins:{[d.name for d in snode.unmet_dependencies]}" - detail = f" {snode.get_name()} ({snode.node.python_kernel_name})\n {outs_str}\n ({ins_str})" + detail = f" ({snode.node.python_kernel_name})" layouts = [child.node.get_output_spec() for child in snode.get_nodes()] out_tensor_info = ",".join( [ @@ -1453,7 +1352,7 @@ def _create_group_node(snodes_to_group): mutating_buf = next(iter(ag_group_node.get_buffer_names())) for o in prev_ag_wait.get_outputs(): ag_group_node.add_fake_dep( - WeakDep(o.get_name(), mutating_buf=mutating_buf, is_fake=True) + WeakDep(o.get_name(), mutating_buf=mutating_buf) ) prev_ag_wait = wait_group_node @@ -1465,7 +1364,7 @@ def _create_group_node(snodes_to_group): mutating_buf = next(iter(rs_group_node.get_buffer_names())) for o in prev_rs_wait.get_outputs(): rs_group_node.add_fake_dep( - WeakDep(o.get_name(), mutating_buf=mutating_buf, is_fake=True) + WeakDep(o.get_name(), mutating_buf=mutating_buf) ) prev_rs_wait = wait_group_node diff --git a/torch/_inductor/dependencies.py b/torch/_inductor/dependencies.py index 8a374f5bab35c..9de52061c6489 100644 --- a/torch/_inductor/dependencies.py +++ b/torch/_inductor/dependencies.py @@ -342,12 +342,6 @@ class WeakDep(Dep): name: str # Buffer that is doing the mutation mutating_buf: str - # WeakDep's are also used to add dependencies to prevent some specific reordering, - # E.g. collectives global ordering. - # But if other pass guarantees proper ordering by its logic, - # This additional "fake" deps will be holding optimizations. - # This flag is used to identify those additional deps. - is_fake: bool = False @property def index(self) -> sympy.Expr: @@ -358,7 +352,7 @@ def get_numel(self) -> sympy.Expr: def rename(self, renames: dict[str, str]) -> "WeakDep": if self.name in renames: - return WeakDep(renames[self.name], self.mutating_buf, self.is_fake) + return WeakDep(renames[self.name], self.mutating_buf) return self def numbytes_hint(self) -> int: diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index a9aa28bb47508..d22d67cecff21 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -2346,7 +2346,7 @@ def is_collective( from . import ir - ret = ( + return ( isinstance(node, ir._CollectiveKernel) and not isinstance(node, ir._WaitKernel) and (op is None or node.op_overload is op) @@ -2373,7 +2373,6 @@ def is_collective( ) ) ) - return ret def is_wait(node: Optional[Union[IRNode, Operation]]) -> bool: From 26807dcf277feb2d99ab88d7b6da526488baea93 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 15 Jul 2025 16:35:55 +0000 Subject: [PATCH 052/457] Revert "[PT2][fusion] ban fusions with large accumulated reads (#157563)" This reverts commit c062550a3598d27c2d6572db7c0f4ff90a84cc84. Reverted https://github.com/pytorch/pytorch/pull/157563 on behalf of https://github.com/clee2000 due to broke test_linear_and_cel on main https://hud.pytorch.org/pytorch/pytorch/commit/c062550a3598d27c2d6572db7c0f4ff90a84cc84, caused OOM? Also broken on PR, Dr. CI classification is wrong (claims the test is disabled by an issue but the issue is for a different test). Also I'm pretty sure the expected results json is supposed to have a ton of empty lines, its to prevent merge conflicts, I will add it to the linter ([comment](https://github.com/pytorch/pytorch/pull/157563#issuecomment-3074355331)) --- .../pr_time_benchmarks/expected_results.csv | 86 ++++++++++++++++--- test/inductor/test_memory.py | 51 ----------- test/inductor/test_online_softmax.py | 8 +- torch/_inductor/choices.py | 4 - torch/_inductor/config.py | 1 - torch/_inductor/graph.py | 21 ----- torch/_inductor/ir.py | 11 --- torch/_inductor/memory.py | 13 ++- torch/_inductor/scheduler.py | 29 ++++--- 9 files changed, 106 insertions(+), 118 deletions(-) diff --git a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv index 9e5521f94b43e..edc9d0f73d161 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv +++ b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv @@ -1,23 +1,89 @@ -add_loop_eager,compile_time_instruction_count,2996000000,0.015 +add_loop_eager,compile_time_instruction_count,3017000000,0.015 + + + add_loop_eager_dynamic,compile_time_instruction_count,4352000000,0.025 -add_loop_inductor,compile_time_instruction_count,33090000000,0.015 -add_loop_inductor_dynamic_gpu,compile_time_instruction_count,42660000000,0.025 -add_loop_inductor_gpu,compile_time_instruction_count,29690000000,0.015 + + + +add_loop_inductor,compile_time_instruction_count,29490000000,0.015 + + + +add_loop_inductor_dynamic_gpu,compile_time_instruction_count,38760000000,0.025 + + + +add_loop_inductor_gpu,compile_time_instruction_count,26000000000,0.015 + + + basic_modules_ListOfLinears_eager,compile_time_instruction_count,947600000,0.015 -basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18830000000,0.015 -basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17460000000,0.015 -basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,11020000000,0.2 + + + +basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18490000000,0.015 + + + +basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17020000000,0.015 + + + +basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,10297683772,0.2 + + + update_hint_regression,compile_time_instruction_count,1673000000,0.02 + + + sum_floordiv_regression,compile_time_instruction_count,986800000,0.015 -symint_sum,compile_time_instruction_count,3184000000,0.015 + + + +symint_sum,compile_time_instruction_count,3166000000,0.015 + + + symint_sum_loop,compile_time_instruction_count,4202000000,0.015 + + + aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2103000000,0.015 + + + aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,6004000000,0.015 + + + aotdispatcher_partitioner_cpu,compile_time_instruction_count,8783000000,0.015 + + + aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1940000000,0.015 + + + aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3885000000,0.015 + + + aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10470000000,0.015 -mm_loop_inductor_gpu,compile_time_instruction_count,4365000000,0.015 -mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,8184000000,0.015 + + + +mm_loop_inductor_gpu,compile_time_instruction_count,4324000000,0.015 + + + +mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,8116000000,0.015 + + + basic_NestedModule_eager,compile_time_instruction_count,8152524390,0.015 + + + basic_InlineMod_eager,compile_time_instruction_count,7255000000,0.015 diff --git a/test/inductor/test_memory.py b/test/inductor/test_memory.py index 489ba4ffeb0df..eaff539f7a493 100644 --- a/test/inductor/test_memory.py +++ b/test/inductor/test_memory.py @@ -306,57 +306,6 @@ def f(a, b, c): expected_bound = a.size(0) * c.size(1) * a.dtype.itemsize * 2 self.assertLess(peak_mem, expected_bound) - def test_fusion_acc_large_reads(self): - def f(x, y, z): - res = torch.zeros_like(x[0]) - for i in range(4): - temp = torch.matmul(x, y) + z - res = res + temp - return res - - N = 128 - x = torch.rand(N, N, dtype=torch.float32, device=GPU_TYPE) - y = torch.rand(N, N, dtype=torch.float32, device=GPU_TYPE) - z = torch.rand(N, N, dtype=torch.float32, device=GPU_TYPE) - - # CASE 1: no restriction on the amount of accumulation - with config.patch({"realize_acc_reads_size_threshold": float("inf")}): - f_compiled = torch.compile(f) - code = run_and_get_triton_code(f_compiled, x, y, z) - ( - FileCheck() - .check("triton_poi_fused_add_0.run(buf4, arg2_1, buf1, buf2, buf3") - .run(code) - ) - - # CASE 2: for tensors with the same size as x (which is 4 * N**2 bytes) - # at most 12 / 4 = 3 reads can be accumulated during fusion - with config.patch({"realize_acc_reads_size_threshold": 12 * N**2}): - f_compiled = torch.compile(f) - code = run_and_get_triton_code(f_compiled, x, y, z) - ( - FileCheck() - .check("triton_poi_fused_add_0.run(buf3, arg2_1, buf1, buf2,") - .check("triton_poi_fused_add_1.run(buf5, buf4, arg2_1,") - .run(code) - ) - - # CASE 3: no such fusion allowed - with config.patch({"realize_acc_reads_size_threshold": N**2}): - f_compiled = torch.compile(f) - code = run_and_get_triton_code(f_compiled, x, y, z) - ( - FileCheck() - .check("triton_poi_fused_add_0.run(buf1, arg2_1,") - .check("triton_poi_fused_add_0.run(buf3, arg2_1,") - .check("triton_poi_fused_add_0.run(buf4, buf3,") - .check("triton_poi_fused_add_0.run(buf6, arg2_1,") - .check("triton_poi_fused_add_0.run(buf7, buf6,") - .check("triton_poi_fused_add_0.run(buf9, arg2_1,") - .check("triton_poi_fused_add_0.run(buf10, buf9,") - .run(code) - ) - if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/test/inductor/test_online_softmax.py b/test/inductor/test_online_softmax.py index 37959c241113f..798d86b0dd617 100644 --- a/test/inductor/test_online_softmax.py +++ b/test/inductor/test_online_softmax.py @@ -13,7 +13,6 @@ instantiate_parametrized_tests, IS_LINUX, parametrize, - serialTest, ) from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA @@ -78,17 +77,12 @@ def f(x): out, source_codes = run_and_get_code(f, x) return source_codes[0] - @serialTest() def test_codegen_3pass_softmax_due_to_disable(self): - with inductor_config.patch( - online_softmax=False, - realize_acc_reads_size_threshold=float("inf"), - ): + with inductor_config.patch(online_softmax=False): wrapper_code = self.get_softmax_wrapper() self.assertEqual(wrapper_code.count("for r0_offset in"), 3) - @serialTest() @parametrize("V", [2048, 50304]) @parametrize("use_log_softmax", [False, True]) def test_codegen_online_softmax(self, use_log_softmax, V): diff --git a/torch/_inductor/choices.py b/torch/_inductor/choices.py index 9096ba6dd0393..b7bab02da5e4b 100644 --- a/torch/_inductor/choices.py +++ b/torch/_inductor/choices.py @@ -365,10 +365,6 @@ def can_fuse( WhyNoFuse(node1, node2)("Fusion will increase peak memory") return False - if scheduler.fusion_accumulate_large_reads(node1, node2): - WhyNoFuse(node1, node2)("Fusion accumulate large amount of reads") - return False - return True @staticmethod diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 5c7a53683db3b..2e189c102db34 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -574,7 +574,6 @@ def use_autoheuristic(name: str) -> bool: # Threshold to prevent excessive accumulation of ops in one buffer during lowering realize_acc_reads_threshold = 8 -realize_acc_reads_size_threshold = 3 * (1024**3) # fallback to eager for random/dropout, this is slow but useful for debugging fallback_random = False diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index ac299d5b0c2d0..e2cc101533f28 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -123,7 +123,6 @@ from torch.fx.graph import Graph from .codegen.wrapper import PythonWrapperCodegen - from .dependencies import Dep from .scheduler import BaseSchedulerNode CompiledModule = Union[ModuleType, FileBackedGraphModule] @@ -486,9 +485,6 @@ def __init__( self.bw_donated_idxs = get_donated_idxs() - # Cache for dep size hints to avoid expensive recomputation - self.dep_size_hint_cache: dict[Dep, int] = {} - def freeze_runtime_asserts(self) -> None: self._shape_env.freeze_runtime_asserts() @@ -574,23 +570,6 @@ def has_feature( assert isinstance(feature, BackendFeature), feature return feature in self.get_backend_features(get_device_type(device)) - def get_dep_size_hint(self, dep: Dep) -> int: - """ - Get the size hint for a dependency with caching to avoid expensive recomputation. - """ - if dep not in self.dep_size_hint_cache: - res = 0 - try: - if not dep.has_unbacked_symbols(): - res = dep.numbytes_hint() - except KeyError: - # In at least one test (test/inductor/test_torchbind.py) we - # create a StarDep that doesn't exist in the graph and calling - # `has_unbacked_symbols()` throws an error. - pass - self.dep_size_hint_cache[dep] = res - return self.dep_size_hint_cache[dep] - def get_current_device_or_throw(self) -> torch.device: if device := self.current_device: return device diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index d6dd82aa52f2d..1edbb214ae2ad 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -7829,10 +7829,6 @@ def create(data: IRNode) -> Union[TensorBox, ShapeAsConstantBuffer]: class StorageBox(MutableBox): - """ - StorageBox allow in-place mutation of Tensors - """ - def is_input_buffer(self) -> bool: if isinstance(self.data, (InputBuffer, ReinterpretView)): return self.data.get_name() in V.graph.graph_inputs @@ -7882,17 +7878,10 @@ def realize_hint(self) -> None: ): self.realize() - def has_accumulated_enough_reads_by_size(self) -> bool: - return ( - sum(V.graph.get_dep_size_hint(dep) for dep in self.get_reads()) - > config.realize_acc_reads_size_threshold - ) - def has_exceeded_max_reads(self) -> bool: return isinstance(self.data, Pointwise) and ( self.num_reads() > config.realize_acc_reads_threshold or self.has_large_inner_fn() - or self.has_accumulated_enough_reads_by_size() ) def should_realize_on_reuse(self, users: int) -> bool: diff --git a/torch/_inductor/memory.py b/torch/_inductor/memory.py index d287208419a9f..5601bc4adcee4 100644 --- a/torch/_inductor/memory.py +++ b/torch/_inductor/memory.py @@ -78,8 +78,19 @@ def get_freeable_input_buf( A dictionary containing all freeble input buffers, keyed by their names. """ + # this function is copied from torch/_inductor/scheduler.py + # TODO: would be nice to remove the try/except block for both places def _dep_size_hint(dep: Dep) -> int: - return V.graph.get_dep_size_hint(dep) + res = 0 + try: + if not dep.has_unbacked_symbols(): + res = dep.numbytes_hint() + except KeyError: + # In at least one test (test/inductor/test_torchbind.py) we + # create a StarDep that doesn't exist in the graph and calling + # `has_unbacked_symbols()` throws an error. + pass + return res # get freeable input buffers' successor nodes and their sizes # note that different deps can have the same name, so we use name as keys diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 34f15869085f0..5c7a16d25bc64 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -2051,12 +2051,15 @@ class Scheduler: optimizations such as fusion, reorder, and graph partition. """ + __dep_size_hint_cache: dict[Dep, int] + def __init__(self, nodes: list[ir.Operation]) -> None: with dynamo_timed("Scheduler.__init__"): self._init(nodes) def _init(self, nodes: list[ir.Operation]) -> None: super().__init__() + self.__dep_size_hint_cache = {} V.graph.scheduler = self self.backends: dict[torch.device, BaseScheduling] = {} self.post_grad_graph_id = next(_post_grad_graph_counter) @@ -3502,17 +3505,6 @@ def _find_single_user_inputs( return True return False - def fusion_accumulate_large_reads( - self, node1: BaseSchedulerNode, node2: BaseSchedulerNode - ) -> bool: - all_reads = (node1.read_writes.reads | node2.read_writes.reads) - ( - node1.read_writes.writes | node2.read_writes.writes - ) - return ( - sum(self.dep_size_hint(dep) for dep in all_reads) - > config.realize_acc_reads_size_threshold - ) - def are_long_distant_nodes( self, node1: BaseSchedulerNode, node2: BaseSchedulerNode ) -> bool: @@ -4018,7 +4010,20 @@ def fusable_read_and_write(self, read: Dep, write: MemoryDep) -> bool: return False def dep_size_hint(self, dep: Dep) -> int: - return V.graph.get_dep_size_hint(dep) + res = 0 + if dep not in self.__dep_size_hint_cache: + try: + if not dep.has_unbacked_symbols(): + res = dep.numbytes_hint() + except KeyError: + # In at least one test (test/inductor/test_torchbind.py) we + # create a StarDep that doesn't exist in the graph and calling + # `has_unbacked_symbols()` throws an error. + pass + self.__dep_size_hint_cache[dep] = res + else: + res = self.__dep_size_hint_cache[dep] + return res def score_fusion_memory( self, node1: BaseSchedulerNode, node2: BaseSchedulerNode From b7b1109f49f5d0bd6145ae47c5c7d7d18c5659b0 Mon Sep 17 00:00:00 2001 From: vfdev Date: Tue, 15 Jul 2025 17:46:39 +0000 Subject: [PATCH 053/457] Expose opt_einsum in torch.backends (#157740) Fixes the following issue: ``` :/tmp# python -c "import torch; print(torch.__version__)" 2.7.1+cu126 :/tmp# python -c "import torch; print(torch.backends.opt_einsum.is_available())" Traceback (most recent call last): File "", line 1, in AttributeError: module 'torch.backends' has no attribute 'opt_einsum' ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/157740 Approved by: https://github.com/Skylion007, https://github.com/benjaminglass1 --- torch/backends/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch/backends/__init__.py b/torch/backends/__init__.py index de194b12d02c3..c5ab7640386af 100644 --- a/torch/backends/__init__.py +++ b/torch/backends/__init__.py @@ -136,5 +136,6 @@ def __init__(self, m, name): mps as mps, nnpack as nnpack, openmp as openmp, + opt_einsum as opt_einsum, quantized as quantized, ) From 243b12e5657a516d6e7b1a0a3f55851ce99bd4cb Mon Sep 17 00:00:00 2001 From: Menglu Yu Date: Tue, 15 Jul 2025 17:50:20 +0000 Subject: [PATCH 054/457] [Optimus] add einsum_to_pointwise_pass pattern (#155666) Summary: More context: https://docs.google.com/document/d/1ipiskqG13ZKNX1SGygB3QnHcSyXNQ8pACazPIcS4bnI/edit?tab=t.0 Test Plan: ### how to enable ``` torch._inductor.config.pre_grad_fusion_options={ "einsum_to_pointwise_pass": {}, }, ``` ### unit test ``` CUDA_VISIBLE_DEVICES=3 OC_CAUSE=1 buck2 test 'fbcode//mode/dev-nosan' //caffe2/test/inductor:kernel_optimization ``` Buck UI: https://www.internalfb.com/buck2/267263ff-6f5b-4fff-bfc0-d8f013440ba0 Test UI: https://www.internalfb.com/intern/testinfra/testrun/5629499820839168 Network: Up: 61KiB Down: 675KiB (reSessionID-fda8edfc-6eef-4bf0-b268-0f8d2e666571) Loading targets. Remaining 0/1 1 dirs read, 2310 targets declared Analyzing targets. Remaining 0/345 284 actions, 329 artifacts declared Executing actions. Remaining 0/18334 8.0s exec time total Command: test. Finished 6 local Time elapsed: 1:15.5s Tests finished: Pass 2. Fail 0. Fatal 0. Skip 0. Build failure 0 ### local reproduce baseline: | Metric | Value | |:----------------------|:------------| | Batch size | 4096 | | GPU type | H100 | | Latency | 196.06 ms | | Model size | 1205.21 MB | | Flops | 7671.30 G | | Flops/example | 1.87 G | | TFLOPS/sec | 39.13 | | MFU | 4.89% | | Activation/example | 1.51 MB | | CPU time total | 602.28 ms | | GPU time total | 798.60 ms | | Estimated avg BW | 234.62 GB/s | | Estimated avg BW util | 9.78% | Trace link: https://our.intern.facebook.com/intern/perfdoctor/trace_view?filepath=tree/traces/efficient_module_suite/fused_attention_mlp.Jun_09_22_12_38_trace.json.gz&bucket=pyper_traces with the pattern: | Metric | Value | |:----------------------|:------------| | Batch size | 4096 | | GPU type | H100 | | Latency | 184.94 ms | | Model size | 1205.21 MB | | Flops | 7671.30 G | | Flops/example | 1.87 G | | TFLOPS/sec | 41.48 | | MFU | 5.18% | | Activation/example | 1.15 MB | | CPU time total | 562.44 ms | | GPU time total | 754.36 ms | | Estimated avg BW | 201.40 GB/s | | Estimated avg BW util | 8.39% | Trace link: https://our.intern.facebook.com/intern/perfdoctor/trace_view?filepath=tree/traces/efficient_module_suite/fused_attention_mlp.Jun_10_22_03_34_trace.json.gz&bucket=pyper_traces ### E2E baseline: f713998364 with patter: Rollback Plan: Differential Revision: D76400889 Pull Request resolved: https://github.com/pytorch/pytorch/pull/155666 Approved by: https://github.com/Yuzhen11 --- test/inductor/test_kernel_optimization.py | 92 +++++++++++++++++++++++ torch/_inductor/fx_passes/split_cat.py | 63 ++++++++++++++++ torch/_inductor/pattern_matcher.py | 7 +- 3 files changed, 161 insertions(+), 1 deletion(-) create mode 100644 test/inductor/test_kernel_optimization.py diff --git a/test/inductor/test_kernel_optimization.py b/test/inductor/test_kernel_optimization.py new file mode 100644 index 0000000000000..aabc8e83a06d7 --- /dev/null +++ b/test/inductor/test_kernel_optimization.py @@ -0,0 +1,92 @@ +# Owner(s): ["module: inductor"] + +import torch +import torch._inductor +from torch._dynamo.utils import counters +from torch._inductor.test_case import run_tests, TestCase +from torch.testing._internal.inductor_utils import GPU_TYPE, requires_gpu + + +class TestEinsumtoPointwise(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward( + self, + input: torch.Tensor, + weights: torch.Tensor, + bias: torch.Tensor, + input2: torch.Tensor, + weights2: torch.Tensor, + bias2: torch.Tensor, + ) -> torch.Tensor: + output = torch.functional.einsum("bni, nio -> bno", input, weights) + add1 = output.add(bias) + output2 = torch.functional.einsum("bni, bnio -> bno", input2, weights2) + add2 = output2 + bias2 + return add1 + add2 + + +class TestKernelOptimization(TestCase): + def compare_dict_tensors(self, ref_dict, res_dict, rtol=1e-3, atol=1e-3): + if len(set(ref_dict.keys())) != len(set(res_dict.keys())): + return False + for key1 in ref_dict.keys(): + key2 = "_orig_mod." + key1 + assert key2 in res_dict, f"{key1} does not exist in traced module" + if not torch.allclose(ref_dict[key1], res_dict[key2], rtol=rtol, atol=atol): + return False + return True + + def compare_pred(self, module, traced, input, rtol=1e-3, atol=1e-3): + ref = module(*input) + res = traced(*input) + self.assertEqual(ref, res, rtol=rtol, atol=atol) + + def compare_parameters(self, module, traced, rtol=1e-3, atol=1e-3): + ref_params = dict(module.named_parameters()) + res_params = dict(traced.named_parameters()) + self.assertTrue(self.compare_dict_tensors(ref_params, res_params, rtol, atol)) + + def compare_gradients(self, module, traced, rtol=1e-3, atol=1e-3): + ref_grad = {key: param.grad for key, param in module.named_parameters()} + res_grad = {key: param.grad for key, param in traced.named_parameters()} + self.assertTrue( + self.compare_dict_tensors(ref_grad, res_grad, rtol=rtol, atol=atol) + ) + + @requires_gpu() + @torch._inductor.config.patch( + pre_grad_fusion_options={ + "einsum_to_pointwise_pass": {}, + }, + post_grad_fusion_options={}, + ) + def test_einsum_to_pointwise(self): + counters.clear() + module = TestEinsumtoPointwise().to(GPU_TYPE) + input = [ + torch.randn(4096, 9, 512, device=GPU_TYPE, requires_grad=True), + torch.randn(9, 512, 96, device=GPU_TYPE, requires_grad=True), + torch.randn(9, 96, device=GPU_TYPE, requires_grad=True), + torch.randn(4096, 9, 160, device=GPU_TYPE, requires_grad=True), + torch.randn(4096, 9, 160, 96, device=GPU_TYPE, requires_grad=True), + torch.randn(4096, 9, 96, device=GPU_TYPE, requires_grad=True), + ] + traced = torch.compile(module) + ref = module(*input) + res = traced(*input) + ref.sum().backward() + res.sum().backward() + self.compare_pred(module, traced, input) + self.compare_parameters(module, traced) + self.compare_gradients(module, traced) + self.assertEqual( + counters["inductor"]["einsum_to_pointwise_pass"], + 1, + ) + counters.clear() + + +if __name__ == "__main__": + run_tests() diff --git a/torch/_inductor/fx_passes/split_cat.py b/torch/_inductor/fx_passes/split_cat.py index ba0529a5fad91..098f69fd863e2 100644 --- a/torch/_inductor/fx_passes/split_cat.py +++ b/torch/_inductor/fx_passes/split_cat.py @@ -62,6 +62,7 @@ "split_stack_to_cats_pass", "unbind_stack_to_slices_pass", "move_reshape_out_of_split_stack_pass", + "einsum_to_pointwise_pass", ] post_grad_pass_names = [ @@ -2965,3 +2966,65 @@ def move_view_after_cat(match: Match, *args, **kwargs): view_node.meta.update(cat_node.meta) graph.erase_node(cat_node) counters["inductor"]["move_view_after_cat_aten_pass"] += 1 + + +def match_einsum_strings(s: str) -> bool: + """ + This function takes a string s as input, where s is in the format "3 letter string, + 4 letter string -> 3 letter string". + It checks if the strings match the rule and returns True if they do, False otherwise. + + The rule is: + - The three strings have the same first two characters. + - The first two strings have the same third character. + - The second and third strings have the same last character. + """ + + # Split the input string into parts + parts = s.replace("->", ",").split(",") + + # Strip leading/trailing whitespaces from each part + parts = [part.strip() for part in parts] + + # Check if we have exactly three parts + if len(parts) != 3: + return False + + # Extract the strings + s1, s2, s3 = parts + + # Check if the strings have the correct lengths + if len(s1) != 3 or len(s2) != 4 or len(s3) != 3: + return False + + # Check the rule + return s1[:2] == s2[:2] == s3[:2] and s1[2] == s2[2] and s2[3] == s3[2] + + +@register_graph_pattern( + CallFunctionVarArgs(torch.functional.einsum, users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("einsum_to_pointwise_pass"), +) +def replace_einsum_to_pointwise(match: Match, *args, **kwargs): + def repl(input, weights): + return (input.unsqueeze(-1) * weights).sum(-2) + + def should_replace_einsum(einsum_node) -> bool: + equation = get_arg_value(einsum_node, 0) + users = einsum_node.users.keys() + # for now, we only consider the case of two operands + return ( + len(einsum_node.args) == 3 + and is_node_meta_valid(input) + and is_node_meta_valid(weights) + and any( + user.target == "add" or user.target == operator.add for user in users + ) + and match_einsum_strings(equation) + ) + + einsum_node = match.nodes[0] + input, weights = get_arg_value(einsum_node, 1), get_arg_value(einsum_node, 2) + if should_replace_einsum(einsum_node): + match.replace_by_example(repl, [input, weights]) + counters["inductor"]["einsum_to_pointwise_pass"] += 1 diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py index 1da31586b0a18..b13a058324d41 100644 --- a/torch/_inductor/pattern_matcher.py +++ b/torch/_inductor/pattern_matcher.py @@ -318,7 +318,12 @@ def record(node: torch.fx.Node, val: Any) -> None: ] else: - example_vals = torch.fx.map_arg(args, lambda arg: arg.meta["val"]) + example_vals = torch.fx.map_arg( + args, + lambda arg: arg.meta["val"] + if "val" in arg.meta + else arg.meta["example_value"], + ) replacement = trace_fn(replacement_fn, example_vals) if len(self.nodes) == 1: for n in replacement.graph.nodes: From b26da7741be37693ab1cd21115f3fca15b1cdb6b Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 15 Jul 2025 18:06:52 +0000 Subject: [PATCH 055/457] Revert "[CI] Fixes CI for CUDA Version > 12.9 (#157385)" This reverts commit 6c5227ba00a2904365af566c24b4681cd01a041c. Reverted https://github.com/pytorch/pytorch/pull/157385 on behalf of https://github.com/clee2000 due to broke some slow tests test_cpp_extensions_jit.py::TestCppExtensionJIT::test_jit_cuda_archflags [GH job link](https://github.com/pytorch/pytorch/actions/runs/16286465717/job/45986677885) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/6c5227ba00a2904365af566c24b4681cd01a041c) ([comment](https://github.com/pytorch/pytorch/pull/157385#issuecomment-3074737541)) --- test/test_cpp_extensions_jit.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/test/test_cpp_extensions_jit.py b/test/test_cpp_extensions_jit.py index 84f5923697a2d..d671e3f874c96 100644 --- a/test/test_cpp_extensions_jit.py +++ b/test/test_cpp_extensions_jit.py @@ -322,15 +322,12 @@ def test_jit_cuda_archflags(self): [f"{capability[0]}{capability[1]}" for capability in capabilities], None, ), + "Maxwell+Tegra;6.1": (["53", "61"], None), + "Volta": (["70"], ["70"]), } archflags["7.5+PTX"] = (["75"], ["75"]) - major, minor = map(int, torch.version.cuda.split(".")[:2]) - if major < 12 or (major == 12 and minor <= 9): - # Compute capability <= 7.0 is only supported up to CUDA 12.9 - archflags["Maxwell+Tegra;6.1"] = (["53", "61"], None) - archflags["Volta"] = ((["70"], ["70"]),) - archflags["5.0;6.0+PTX;7.0;7.5"] = (["50", "60", "70", "75"], ["60"]) - if major < 12: + archflags["5.0;6.0+PTX;7.0;7.5"] = (["50", "60", "70", "75"], ["60"]) + if int(torch.version.cuda.split(".")[0]) < 12: # CUDA 12 drops compute capability < 5.0 archflags["Pascal 3.5"] = (["35", "60", "61"], None) From f2ecf6145fde55baa8a91e27b6b3489172f0e639 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 15 Jul 2025 18:17:43 +0000 Subject: [PATCH 056/457] Revert "Enable AcceleratorAllocatorConfig key check (#157908)" This reverts commit 65fcca4f8c97de82d35d51ad9b790d10433e9b91. Reverted https://github.com/pytorch/pytorch/pull/157908 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it is failing internally per https://github.com/pytorch/pytorch/pull/157908#discussion_r2208204782 ([comment](https://github.com/pytorch/pytorch/pull/157908#issuecomment-3074833696)) --- c10/core/AllocatorConfig.cpp | 9 ---- c10/core/AllocatorConfig.h | 59 ++++++-------------------- c10/test/core/AllocatorConfig_test.cpp | 12 ++---- 3 files changed, 16 insertions(+), 64 deletions(-) diff --git a/c10/core/AllocatorConfig.cpp b/c10/core/AllocatorConfig.cpp index f0a3134fae68c..9ceb40ccf6d74 100644 --- a/c10/core/AllocatorConfig.cpp +++ b/c10/core/AllocatorConfig.cpp @@ -221,15 +221,6 @@ void AcceleratorAllocatorConfig::parseArgs(const std::string& env) { } else if (key == "pinned_use_background_threads") { i = parsePinnedUseBackgroundThreads(tokenizer, i); } else { - // If a device-specific configuration parser hook is registered, it will - // check if the key is unrecognized. - if (device_config_parser_hook_) { - TORCH_CHECK( - keys_.find(key) != keys_.end(), - "Unrecognized key '", - key, - "' in Accelerator allocator config."); - } i = tokenizer.skipKey(i); } diff --git a/c10/core/AllocatorConfig.h b/c10/core/AllocatorConfig.h index eddaaa5ffc6cf..e19160ea5978e 100644 --- a/c10/core/AllocatorConfig.h +++ b/c10/core/AllocatorConfig.h @@ -7,7 +7,6 @@ #include #include #include -#include #include namespace c10::CachingAllocator { @@ -181,7 +180,7 @@ class C10_API AcceleratorAllocatorConfig { // Returns the vector of division factors used for rounding up allocation // sizes. These divisions apply to size intervals between 1MB and 64GB. - static const std::vector& roundup_power2_divisions() { + static std::vector roundup_power2_divisions() { return instance().roundup_power2_divisions_; } @@ -220,13 +219,6 @@ class C10_API AcceleratorAllocatorConfig { return instance().last_allocator_settings_; } - // Returns the set of valid keys for the allocator configuration. - // This set is used to validate the presence and correctness of keys in - // device-specific configuration parsers. - static const std::unordered_set& getKeys() { - return instance().keys_; - } - // Parses the environment variable `env` to update the allocator settings. // If the environment variable is not set, it does nothing. // The configuration string should be a comma-separated list of key-value @@ -235,24 +227,16 @@ class C10_API AcceleratorAllocatorConfig { // "max_split_size_mb:100,max_non_split_rounding_mb:20,garbage_collection_threshold:0.5,roundup_power2_divisions:[64:8,256:4,1024:4,>:1],expandable_segments:true,pinned_use_background_threads:true" void parseArgs(const std::string& env); - // Registers a device-specific configuration parser hook and its key. This - // allows backends to parse additional device-specific configuration options - // from the environment variable. The hook should be a function that takes a - // string (the environment variable value) and parses it to set - // device-specific configuration options. The hook will be called when the - // environment variable is parsed. If a hook is already registered, it will be - // replaced with the new one. + // Registers a device-specific configuration parser hook. This allows + // backends to parse additional device-specific configuration options from the + // environment variable. The hook should be a function that takes a string + // (the environment variable value) and parses it to set device-specific + // configuration options. + // The hook will be called when the environment variable is parsed. + // If a hook is already registered, it will be replaced with the new one. void registerDeviceConfigParserHook( - std::function&& hook, - const std::unordered_set& keys) { + std::function hook) { device_config_parser_hook_ = std::move(hook); - for (auto& key : keys) { - TORCH_CHECK( - keys_.insert(key).second, - "Duplicated key '", - key, - "' found in device-specific configuration parser hook registration"); - } } // Calls the registered device-specific configuration parser hook with the @@ -325,17 +309,6 @@ class C10_API AcceleratorAllocatorConfig { // This allows backends (e.g., CUDA, XPU) to register a custom parser for // their own environment configuration extensions. std::function device_config_parser_hook_{nullptr}; - - // A set of valid configuration keys, including both common and - // device-specific options. This set is used to validate the presence and - // correctness of keys during parsing. - std::unordered_set keys_{ - "max_split_size_mb", - "max_non_split_rounding_mb", - "garbage_collection_threshold", - "roundup_power2_divisions", - "expandable_segments", - "pinned_use_background_threads"}; }; C10_API inline void setAllocatorSettings(const std::string& env) { @@ -349,22 +322,16 @@ C10_API inline std::string getAllocatorSettings() { struct DeviceConfigParserHookRegistry { explicit DeviceConfigParserHookRegistry( - std::function&& hook, - const std::unordered_set& keys) { + std::function hook) { AcceleratorAllocatorConfig::instance().registerDeviceConfigParserHook( - std::move(hook), keys); + std::move(hook)); } }; -// Assume each config parser has `parseArgs` and `getKeys` methods -#define REGISTER_ALLOCATOR_CONFIG_PARSE_HOOK(parser_cls) \ +#define REGISTER_ALLOCATOR_CONFIG_PARSE_HOOK(hook) \ namespace { \ static at::CachingAllocator::DeviceConfigParserHookRegistry \ - g_device_config_parse_hook_registry_instance( \ - [](const std::string& env) { \ - parser_cls::instance().parseArgs(env); \ - }, \ - parser_cls::getKeys()); \ + g_device_config_parse_hook_registry_instance(hook); \ } } // namespace c10::CachingAllocator diff --git a/c10/test/core/AllocatorConfig_test.cpp b/c10/test/core/AllocatorConfig_test.cpp index c2c0e6261d7b6..c051cf4cd4a05 100644 --- a/c10/test/core/AllocatorConfig_test.cpp +++ b/c10/test/core/AllocatorConfig_test.cpp @@ -16,10 +16,6 @@ struct ExtendedAllocatorConfig { return instance().device_specific_option_; } - static const std::unordered_set& getKeys() { - return instance().keys_; - } - void parseArgs(const std::string& env) { // Parse device-specific options from the environment variable ConfigTokenizer tokenizer(env); @@ -41,10 +37,11 @@ struct ExtendedAllocatorConfig { private: // Device-specific option, e.g., memory limit for a specific device. std::atomic device_specific_option_{0}; - std::unordered_set keys_{"device_specific_option_mb"}; }; -REGISTER_ALLOCATOR_CONFIG_PARSE_HOOK(ExtendedAllocatorConfig) +REGISTER_ALLOCATOR_CONFIG_PARSE_HOOK([](const std::string& env) { + ExtendedAllocatorConfig::instance().parseArgs(env); +}) TEST(AllocatorConfigTest, allocator_config_test) { std::string env = @@ -123,7 +120,4 @@ TEST(AllocatorConfigTest, allocator_config_test) { c10::CachingAllocator::setAllocatorSettings(env); EXPECT_EQ(c10::CachingAllocator::getAllocatorSettings(), env); EXPECT_EQ(AcceleratorAllocatorConfig::pinned_use_background_threads(), false); - - env = "foo:123,bar:456"; - ASSERT_THROW(c10::CachingAllocator::setAllocatorSettings(env), c10::Error); } From ea5f88dca62b996cc8d081b14435d3d4392e043e Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 15 Jul 2025 18:24:36 +0000 Subject: [PATCH 057/457] Revert "Deprecate overleap functions in CUDAAllocatorConfig, use AcceleratorAllocatorConfig instead (#156165)" This reverts commit e40ade5182233f548b25f2732effe3719d16e9ad. Reverted https://github.com/pytorch/pytorch/pull/156165 on behalf of https://github.com/huydhn due to Sorry for reverting your change but because https://github.com/pytorch/pytorch/pull/157908 has been reverted + this PR caused issue earlier, I think it is better to revert the whole stack and reland it from scratch to be sure ([comment](https://github.com/pytorch/pytorch/pull/150312#issuecomment-3074897532)) --- aten/src/ATen/cuda/CachingHostAllocator.cpp | 2 +- c10/cuda/CUDAAllocatorConfig.h | 19 ++------- c10/cuda/CUDACachingAllocator.cpp | 47 ++++++++++----------- c10/xpu/XPUCachingAllocator.cpp | 3 +- torch/csrc/cuda/Module.cpp | 5 ++- 5 files changed, 32 insertions(+), 44 deletions(-) diff --git a/aten/src/ATen/cuda/CachingHostAllocator.cpp b/aten/src/ATen/cuda/CachingHostAllocator.cpp index b5e5f84cde13f..6a80342e10240 100644 --- a/aten/src/ATen/cuda/CachingHostAllocator.cpp +++ b/aten/src/ATen/cuda/CachingHostAllocator.cpp @@ -162,7 +162,7 @@ struct CUDACachingHostAllocatorImpl } bool pinned_use_background_threads() override { - return c10::CachingAllocator::AcceleratorAllocatorConfig:: + return c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig:: pinned_use_background_threads(); } diff --git a/c10/cuda/CUDAAllocatorConfig.h b/c10/cuda/CUDAAllocatorConfig.h index 6254f85cd5b86..f96ae5e56ba6c 100644 --- a/c10/cuda/CUDAAllocatorConfig.h +++ b/c10/cuda/CUDAAllocatorConfig.h @@ -3,7 +3,6 @@ #include #include #include -#include #include #include @@ -18,13 +17,9 @@ enum class Expandable_Segments_Handle_Type : int { // Environment config parser class C10_CUDA_API CUDAAllocatorConfig { public: - C10_DEPRECATED_MESSAGE( - "c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::max_split_size() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::max_split_size() instead.") static size_t max_split_size() { return c10::CachingAllocator::AcceleratorAllocatorConfig::max_split_size(); } - C10_DEPRECATED_MESSAGE( - "c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::garbage_collection_threshold() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::garbage_collection_threshold() instead.") static double garbage_collection_threshold() { return c10::CachingAllocator::AcceleratorAllocatorConfig:: garbage_collection_threshold(); @@ -65,8 +60,6 @@ class C10_CUDA_API CUDAAllocatorConfig { return instance().m_pinned_num_register_threads; } - C10_DEPRECATED_MESSAGE( - "c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::pinned_use_background_threads() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::pinned_use_background_threads() instead.") static bool pinned_use_background_threads() { return c10::CachingAllocator::AcceleratorAllocatorConfig:: pinned_use_background_threads(); @@ -79,29 +72,25 @@ class C10_CUDA_API CUDAAllocatorConfig { return 128; } - C10_DEPRECATED_MESSAGE( - "c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::roundup_power2_divisions() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::roundup_power2_divisions() instead.") + // This is used to round-up allocation size to nearest power of 2 divisions. + // More description below in function roundup_power2_next_division + // As an example, if we want 4 divisions between 2's power, this can be done + // using env variable: PYTORCH_CUDA_ALLOC_CONF=roundup_power2_divisions:4 static size_t roundup_power2_divisions(size_t size) { return c10::CachingAllocator::AcceleratorAllocatorConfig:: roundup_power2_divisions(size); } - C10_DEPRECATED_MESSAGE( - "c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::roundup_power2_divisions() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::roundup_power2_divisions() instead.") static std::vector roundup_power2_divisions() { return c10::CachingAllocator::AcceleratorAllocatorConfig:: roundup_power2_divisions(); } - C10_DEPRECATED_MESSAGE( - "c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::max_non_split_rounding_size() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::max_non_split_rounding_size() instead.") static size_t max_non_split_rounding_size() { return c10::CachingAllocator::AcceleratorAllocatorConfig:: max_non_split_rounding_size(); } - C10_DEPRECATED_MESSAGE( - "c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::last_allocator_settings() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::last_allocator_settings() instead.") static std::string last_allocator_settings() { return c10::CachingAllocator::getAllocatorSettings(); } diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp index 5ae04bcd3f53c..ed6914c350599 100644 --- a/c10/cuda/CUDACachingAllocator.cpp +++ b/c10/cuda/CUDACachingAllocator.cpp @@ -1226,7 +1226,7 @@ class DeviceCachingAllocator { DeviceCachingAllocator() : large_blocks(/*small=*/false), small_blocks(/*small=*/true) { stats.max_split_size = - static_cast(AcceleratorAllocatorConfig::max_split_size()); + static_cast(CUDAAllocatorConfig::max_split_size()); context_recorder_.store(nullptr); } @@ -1351,8 +1351,7 @@ class DeviceCachingAllocator { // Do garbage collection if the flag is set. if (C10_UNLIKELY( set_fraction && - AcceleratorAllocatorConfig::garbage_collection_threshold() > - 0.0)) { + CUDAAllocatorConfig::garbage_collection_threshold() > 0.0)) { garbage_collect_cached_blocks(context); } // Attempt allocate @@ -1604,7 +1603,7 @@ class DeviceCachingAllocator { stats.active_bytes[stat_type].increase(block->size); stats.requested_bytes[stat_type].increase(block->requested_size); }); - if (block->size >= AcceleratorAllocatorConfig::max_split_size()) + if (block->size >= CUDAAllocatorConfig::max_split_size()) stats.oversize_allocations.increase(1); auto allocated_bytes_gauge = @@ -1655,7 +1654,7 @@ class DeviceCachingAllocator { block->pool->owner_MempoolId(), context ? context : block->context_when_allocated); - if (block->size >= AcceleratorAllocatorConfig::max_split_size()) + if (block->size >= CUDAAllocatorConfig::max_split_size()) stats.oversize_allocations.decrease(1); if (!block->stream_uses.empty()) { @@ -2205,8 +2204,7 @@ class DeviceCachingAllocator { if (size < kMinBlockSize) { return kMinBlockSize; } else { - auto divisions = - AcceleratorAllocatorConfig::roundup_power2_divisions(size); + auto divisions = CUDAAllocatorConfig::roundup_power2_divisions(size); if (divisions > 1 && size > (kMinBlockSize * divisions)) { return roundup_power2_next_division(size, divisions); } else { @@ -2696,7 +2694,7 @@ class DeviceCachingAllocator { if (block->pool->is_small || CUDAAllocatorConfig::expandable_segments()) { return remaining >= kMinBlockSize; } else { - return (size < AcceleratorAllocatorConfig::max_split_size()) && + return (size < CUDAAllocatorConfig::max_split_size()) && (remaining > kSmallSize); } } @@ -2716,7 +2714,7 @@ class DeviceCachingAllocator { if (C10_UNLIKELY( set_fraction && - AcceleratorAllocatorConfig::garbage_collection_threshold() > 0.0)) { + CUDAAllocatorConfig::garbage_collection_threshold() > 0.0)) { // Track block reuse interval only when garbage collection is enabled. ++pool.get_free_blocks_call_count; } @@ -2758,13 +2756,13 @@ class DeviceCachingAllocator { } // Do not return an oversized block for a large request - if ((p.size() < AcceleratorAllocatorConfig::max_split_size()) && - ((*it)->size >= AcceleratorAllocatorConfig::max_split_size())) + if ((p.size() < CUDAAllocatorConfig::max_split_size()) && + ((*it)->size >= CUDAAllocatorConfig::max_split_size())) return false; // Allow oversized block size to be rounded up but within a limit - if ((p.size() >= AcceleratorAllocatorConfig::max_split_size()) && + if ((p.size() >= CUDAAllocatorConfig::max_split_size()) && ((*it)->size >= - p.size() + AcceleratorAllocatorConfig::max_non_split_rounding_size())) + p.size() + CUDAAllocatorConfig::max_non_split_rounding_size())) return false; p.block = *it; pool.blocks.erase(it); @@ -2787,7 +2785,7 @@ class DeviceCachingAllocator { // therefore should be of less overheads. size_t gc_threshold = static_cast( - AcceleratorAllocatorConfig::garbage_collection_threshold() * + CUDAAllocatorConfig::garbage_collection_threshold() * static_cast(allowed_memory_maximum)); // No need to trigger GC yet if (total_allocated_memory <= gc_threshold) { @@ -2935,7 +2933,7 @@ class DeviceCachingAllocator { stats.segment[stat_type].increase(1); stats.reserved_bytes[stat_type].increase(size); }); - if (size >= AcceleratorAllocatorConfig::max_split_size()) + if (size >= CUDAAllocatorConfig::max_split_size()) stats.oversize_segments.increase(1); auto reserved_bytes_gauge = STATIC_GAUGE(pytorch.CUDACachingAllocator.reserved_bytes); @@ -2964,7 +2962,7 @@ class DeviceCachingAllocator { bool release_available_cached_blocks( const AllocParams& p, const std::shared_ptr& context) { - if (AcceleratorAllocatorConfig::max_split_size() == + if (CUDAAllocatorConfig::max_split_size() == std::numeric_limits::max()) return false; BlockPool& pool = *p.pool; @@ -2972,8 +2970,8 @@ class DeviceCachingAllocator { // because of std::unique_ptr, block cannot be trivially copied // Use constructor for search key. Block key(p.search_key.device, p.search_key.stream, p.search_key.size); - key.size = (key.size < AcceleratorAllocatorConfig::max_split_size()) - ? AcceleratorAllocatorConfig::max_split_size() + key.size = (key.size < CUDAAllocatorConfig::max_split_size()) + ? CUDAAllocatorConfig::max_split_size() : key.size; auto it = pool.blocks.lower_bound(&key); if (it == pool.blocks.end() || (*it)->stream != p.stream() || @@ -2986,7 +2984,7 @@ class DeviceCachingAllocator { --it; // Back up one item. Now on the largest block for the correct // stream while ((totalReleased < key.size) && - ((*it)->size >= AcceleratorAllocatorConfig::max_split_size()) && + ((*it)->size >= CUDAAllocatorConfig::max_split_size()) && ((*it)->stream == p.stream())) { auto cur = it; bool is_first = cur == pool.blocks.begin(); @@ -3111,7 +3109,7 @@ class DeviceCachingAllocator { stats.reserved_bytes[static_cast(StatType::AGGREGATE)] .current); - if (block->size >= AcceleratorAllocatorConfig::max_split_size()) + if (block->size >= CUDAAllocatorConfig::max_split_size()) stats.oversize_segments.decrease(1); pool->blocks.erase(block); delete block; @@ -3738,8 +3736,8 @@ class NativeCachingAllocator : public CUDAAllocator { auto& md = result.config_metadata; md.garbage_collection_threshold = - AcceleratorAllocatorConfig::garbage_collection_threshold(); - md.max_split_size = AcceleratorAllocatorConfig::max_split_size(); + CUDAAllocatorConfig::garbage_collection_threshold(); + md.max_split_size = CUDAAllocatorConfig::max_split_size(); md.pinned_num_register_threads = CUDAAllocatorConfig::pinned_num_register_threads(); md.expandable_segments = CUDAAllocatorConfig::expandable_segments(); @@ -3747,10 +3745,9 @@ class NativeCachingAllocator : public CUDAAllocator { CUDAAllocatorConfig::release_lock_on_cudamalloc(); md.pinned_use_host_register = CUDAAllocatorConfig::pinned_use_cuda_host_register(); - md.last_allocator_settings = - AcceleratorAllocatorConfig::last_allocator_settings(); + md.last_allocator_settings = CUDAAllocatorConfig::last_allocator_settings(); md.roundup_power2_divisions = - AcceleratorAllocatorConfig::roundup_power2_divisions(); + CUDAAllocatorConfig::roundup_power2_divisions(); return result; } diff --git a/c10/xpu/XPUCachingAllocator.cpp b/c10/xpu/XPUCachingAllocator.cpp index afae32d92a4b4..543b48f081135 100644 --- a/c10/xpu/XPUCachingAllocator.cpp +++ b/c10/xpu/XPUCachingAllocator.cpp @@ -1,4 +1,3 @@ -#include #include #include #include @@ -21,6 +20,8 @@ constexpr size_t kMinBlockSize = 512; constexpr size_t kSmallSize = 1048576; // "small" allocations are packed in 2 MiB blocks constexpr size_t kSmallBuffer = 2097152; +// "large" allocations may be packed in 20 MiB blocks +constexpr size_t kLargeBuffer = 20971520; // allocations between 1 and 10 MiB may use kLargeBuffer constexpr size_t kMinLargeAlloc = 10485760; // round up large allocations to 2 MiB diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp index ead46337ff090..b44ce311ecd92 100644 --- a/torch/csrc/cuda/Module.cpp +++ b/torch/csrc/cuda/Module.cpp @@ -20,8 +20,8 @@ #include #include #include -#include #include +#include #include #include #include @@ -426,7 +426,8 @@ PyObject* THCPModule_cudaCachingAllocator_set_allocator_settings( PyObject* _unused, PyObject* env) { HANDLE_TH_ERRORS - c10::CachingAllocator::setAllocatorSettings(THPUtils_unpackString(env)); + c10::cuda::CUDACachingAllocator::setAllocatorSettings( + THPUtils_unpackString(env)); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } From 41971335c98b0881e0784085096eceace575d563 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 15 Jul 2025 18:24:36 +0000 Subject: [PATCH 058/457] Revert "Refactor CUDAAllocatorConfig to reuse AcceleratorAllocatorConfig (#150312)" This reverts commit e241a07e6b88aa49d604803bc5a6562f0d9f94d2. Reverted https://github.com/pytorch/pytorch/pull/150312 on behalf of https://github.com/huydhn due to Sorry for reverting your change but because https://github.com/pytorch/pytorch/pull/157908 has been reverted + this PR caused issue earlier, I think it is better to revert the whole stack and reland it from scratch to be sure ([comment](https://github.com/pytorch/pytorch/pull/150312#issuecomment-3074897532)) --- c10/cuda/CUDAAllocatorConfig.cpp | 469 ++++++++++++++++++++++++------ c10/cuda/CUDAAllocatorConfig.h | 130 ++++----- c10/cuda/CUDACachingAllocator.cpp | 50 +++- c10/cuda/CUDACachingAllocator.h | 4 +- 4 files changed, 495 insertions(+), 158 deletions(-) diff --git a/c10/cuda/CUDAAllocatorConfig.cpp b/c10/cuda/CUDAAllocatorConfig.cpp index 49fa2e1e95ed3..d2efb8c593e44 100644 --- a/c10/cuda/CUDAAllocatorConfig.cpp +++ b/c10/cuda/CUDAAllocatorConfig.cpp @@ -1,119 +1,389 @@ #include +#include +#include #if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) #include #endif -#include - namespace c10::cuda::CUDACachingAllocator { -size_t CUDAAllocatorConfig::parseAllocatorConfig( - const c10::CachingAllocator::ConfigTokenizer& tokenizer, +constexpr size_t kRoundUpPowerOfTwoIntervals = 16; + +CUDAAllocatorConfig::CUDAAllocatorConfig() + : m_max_split_size(std::numeric_limits::max()), + m_max_non_split_rounding_size(kLargeBuffer), + m_garbage_collection_threshold(0), + m_pinned_num_register_threads(1), + m_expandable_segments(false), +#if CUDA_VERSION >= 12030 + m_expandable_segments_handle_type( + Expandable_Segments_Handle_Type::UNSPECIFIED), +#else + m_expandable_segments_handle_type( + Expandable_Segments_Handle_Type::POSIX_FD), +#endif + m_release_lock_on_cudamalloc(false), + m_pinned_use_cuda_host_register(false), + m_pinned_use_background_threads(false) { + m_roundup_power2_divisions.assign(kRoundUpPowerOfTwoIntervals, 0); +} + +size_t CUDAAllocatorConfig::roundup_power2_divisions(size_t size) { + size_t log_size = (63 - llvm::countLeadingZeros(size)); + + // Our intervals start at 1MB and end at 64GB + const size_t interval_start = + 63 - llvm::countLeadingZeros(static_cast(1048576)); + const size_t interval_end = + 63 - llvm::countLeadingZeros(static_cast(68719476736)); + TORCH_CHECK( + (interval_end - interval_start == kRoundUpPowerOfTwoIntervals), + "kRoundUpPowerOfTwoIntervals mismatch"); + + int index = static_cast(log_size) - static_cast(interval_start); + + index = std::max(0, index); + index = std::min(index, static_cast(kRoundUpPowerOfTwoIntervals) - 1); + return instance().m_roundup_power2_divisions[index]; +} + +void CUDAAllocatorConfig::lexArgs( + const std::string& env, + std::vector& config) { + std::vector buf; + + for (char ch : env) { + if (ch == ',' || ch == ':' || ch == '[' || ch == ']') { + if (!buf.empty()) { + config.emplace_back(buf.begin(), buf.end()); + buf.clear(); + } + config.emplace_back(1, ch); + } else if (ch != ' ') { + buf.emplace_back(ch); + } + } + if (!buf.empty()) { + config.emplace_back(buf.begin(), buf.end()); + } +} + +void CUDAAllocatorConfig::consumeToken( + const std::vector& config, + size_t i, + const char c) { + TORCH_CHECK( + i < config.size() && config[i] == std::string(1, c), + "Error parsing CachingAllocator settings, expected ", + c, + ""); +} + +size_t CUDAAllocatorConfig::parseMaxSplitSize( + const std::vector& config, + size_t i) { + consumeToken(config, ++i, ':'); + constexpr int mb = 1024 * 1024; + if (++i < config.size()) { + size_t val1 = stoi(config[i]); + TORCH_CHECK( + val1 > kLargeBuffer / mb, + "CachingAllocator option max_split_size_mb too small, must be > ", + kLargeBuffer / mb, + ""); + val1 = std::max(val1, kLargeBuffer / mb); + val1 = std::min(val1, (std::numeric_limits::max() / mb)); + m_max_split_size = val1 * 1024 * 1024; + } else { + TORCH_CHECK(false, "Error, expecting max_split_size_mb value", ""); + } + return i; +} + +size_t CUDAAllocatorConfig::parseMaxNonSplitRoundingSize( + const std::vector& config, + size_t i) { + consumeToken(config, ++i, ':'); + constexpr int mb = 1024 * 1024; + if (++i < config.size()) { + size_t val1 = stoi(config[i]); + TORCH_CHECK( + val1 > kLargeBuffer / mb, + "CachingAllocator option max_non_split_rounding_mb too small, must be > ", + kLargeBuffer / mb, + ""); + val1 = std::max(val1, kLargeBuffer / mb); + val1 = std::min(val1, (std::numeric_limits::max() / mb)); + m_max_non_split_rounding_size = val1 * 1024 * 1024; + } else { + TORCH_CHECK(false, "Error, expecting max_non_split_rounding_mb value", ""); + } + return i; +} + +size_t CUDAAllocatorConfig::parseGarbageCollectionThreshold( + const std::vector& config, + size_t i) { + consumeToken(config, ++i, ':'); + if (++i < config.size()) { + double val1 = stod(config[i]); + TORCH_CHECK( + val1 > 0, "garbage_collect_threshold too small, set it 0.0~1.0", ""); + TORCH_CHECK( + val1 < 1.0, "garbage_collect_threshold too big, set it 0.0~1.0", ""); + m_garbage_collection_threshold = val1; + } else { + TORCH_CHECK( + false, "Error, expecting garbage_collection_threshold value", ""); + } + return i; +} + +size_t CUDAAllocatorConfig::parseRoundUpPower2Divisions( + const std::vector& config, size_t i) { + consumeToken(config, ++i, ':'); + bool first_value = true; + + if (++i < config.size()) { + if (std::string_view(config[i]) == "[") { + size_t last_index = 0; + // NOLINTNEXTLINE(bugprone-inc-dec-in-conditions) + while (++i < config.size() && std::string_view(config[i]) != "]") { + const std::string& val1 = config[i]; + size_t val2 = 0; + + consumeToken(config, ++i, ':'); + if (++i < config.size()) { + val2 = stoi(config[i]); + } else { + TORCH_CHECK( + false, "Error parsing roundup_power2_divisions value", ""); + } + TORCH_CHECK( + val2 == 0 || llvm::isPowerOf2_64(val2), + "For roundups, the divisions has to be power of 2 or 0 to disable roundup ", + ""); + + if (std::string_view(val1) == ">") { + std::fill( + std::next( + m_roundup_power2_divisions.begin(), + static_cast::difference_type>( + last_index)), + m_roundup_power2_divisions.end(), + val2); + } else { + size_t val1_long = stoul(val1); + TORCH_CHECK( + llvm::isPowerOf2_64(val1_long), + "For roundups, the intervals have to be power of 2 ", + ""); + + size_t index = 63 - llvm::countLeadingZeros(val1_long); + index = std::max((size_t)0, index); + index = std::min(index, m_roundup_power2_divisions.size() - 1); + + if (first_value) { + std::fill( + m_roundup_power2_divisions.begin(), + std::next( + m_roundup_power2_divisions.begin(), + static_cast::difference_type>( + index)), + val2); + first_value = false; + } + if (index < m_roundup_power2_divisions.size()) { + m_roundup_power2_divisions[index] = val2; + } + last_index = index; + } + + if (std::string_view(config[i + 1]) != "]") { + consumeToken(config, ++i, ','); + } + } + } else { // Keep this for backwards compatibility + size_t val1 = stoi(config[i]); + TORCH_CHECK( + llvm::isPowerOf2_64(val1), + "For roundups, the divisions has to be power of 2 ", + ""); + std::fill( + m_roundup_power2_divisions.begin(), + m_roundup_power2_divisions.end(), + val1); + } + } else { + TORCH_CHECK(false, "Error, expecting roundup_power2_divisions value", ""); + } + return i; +} + +size_t CUDAAllocatorConfig::parseAllocatorConfig( + const std::vector& config, + size_t i, + bool& used_cudaMallocAsync) { // For ease of maintenance and understanding, the CUDA and ROCm // implementations of this function are separated. This avoids having many // #ifdef's throughout. +#ifdef USE_ROCM // Ease burden on ROCm users by allowing either cuda or hip tokens. // cuda token is broken up to prevent hipify matching it. #define PYTORCH_TOKEN1 \ "cud" \ "aMallocAsync" #define PYTORCH_TOKEN2 "hipMallocAsync" - tokenizer.checkToken(++i, ":"); - i++; // Move to the value after the colon - TORCH_CHECK( - ((tokenizer[i] == "native") || (tokenizer[i] == PYTORCH_TOKEN1) || - (tokenizer[i] == PYTORCH_TOKEN2)), - "Unknown allocator backend, " - "options are native, " PYTORCH_TOKEN1 ", and " PYTORCH_TOKEN2); - if (m_is_allocator_loaded) { - bool aync_allocator_at_runtime = (tokenizer[i] != "native"); + consumeToken(config, ++i, ':'); + if (++i < config.size()) { TORCH_CHECK( - aync_allocator_at_runtime == m_use_async_allocator, - "Allocator async backend parsed at runtime != allocator async backend parsed at load time, ", - aync_allocator_at_runtime, + ((config[i] == "native") || (config[i] == PYTORCH_TOKEN1) || + (config[i] == PYTORCH_TOKEN2)), + "Unknown allocator backend, " + "options are native, " PYTORCH_TOKEN1 ", and " PYTORCH_TOKEN2); + used_cudaMallocAsync = + (config[i] == PYTORCH_TOKEN1 || config[i] == PYTORCH_TOKEN2); + TORCH_INTERNAL_ASSERT( + config[i] == get()->name() || + (config[i] == PYTORCH_TOKEN1 && get()->name() == PYTORCH_TOKEN2), + "Allocator backend parsed at runtime != " + "allocator backend parsed at load time, ", + config[i], " != ", - m_use_async_allocator); + get()->name()); + } else { + TORCH_CHECK(false, "Error parsing backend value", ""); } - m_use_async_allocator = - (tokenizer[i] == PYTORCH_TOKEN1 || tokenizer[i] == PYTORCH_TOKEN2); - // CUDA allocator is always loaded at the start of the program - m_is_allocator_loaded = true; - -#if defined(CUDA_VERSION) - if (m_use_async_allocator) { -#if CUDA_VERSION >= 11040 - int version = 0; - C10_CUDA_CHECK(cudaDriverGetVersion(&version)); + return i; +#undef PYTORCH_TOKEN1 +#undef PYTORCH_TOKEN2 +#else // USE_ROCM + consumeToken(config, ++i, ':'); + if (++i < config.size()) { TORCH_CHECK( - version >= 11040, - "backend:cudaMallocAsync requires CUDA runtime " - "11.4 or newer, but cudaDriverGetVersion returned ", - version); + ((config[i] == "native") || (config[i] == "cudaMallocAsync")), + "Unknown allocator backend, " + "options are native and cudaMallocAsync"); + used_cudaMallocAsync = (config[i] == "cudaMallocAsync"); + if (used_cudaMallocAsync) { +#if CUDA_VERSION >= 11040 + int version = 0; + C10_CUDA_CHECK(cudaDriverGetVersion(&version)); + TORCH_CHECK( + version >= 11040, + "backend:cudaMallocAsync requires CUDA runtime " + "11.4 or newer, but cudaDriverGetVersion returned ", + version); #else - TORCH_CHECK( - false, - "backend:cudaMallocAsync requires PyTorch to be built with " - "CUDA 11.4 or newer, but CUDA_VERSION is ", - CUDA_VERSION); + TORCH_CHECK( + false, + "backend:cudaMallocAsync requires PyTorch to be built with " + "CUDA 11.4 or newer, but CUDA_VERSION is ", + CUDA_VERSION); #endif + } + TORCH_INTERNAL_ASSERT( + config[i] == get()->name(), + "Allocator backend parsed at runtime != " + "allocator backend parsed at load time"); + } else { + TORCH_CHECK(false, "Error parsing backend value", ""); } -#endif - return i; -#undef PYTORCH_TOKEN1 -#undef PYTORCH_TOKEN2 +#endif // USE_ROCM } -void CUDAAllocatorConfig::parseArgs(const std::string& env) { +void CUDAAllocatorConfig::parseArgs(const std::optional& env) { // If empty, set the default values + m_max_split_size = std::numeric_limits::max(); + m_roundup_power2_divisions.assign(kRoundUpPowerOfTwoIntervals, 0); + m_garbage_collection_threshold = 0; + bool used_cudaMallocAsync = false; bool used_native_specific_option = false; - c10::CachingAllocator::ConfigTokenizer tokenizer(env); - for (size_t i = 0; i < tokenizer.size(); i++) { - const auto& key = tokenizer[i]; - if (key == "backend") { - i = parseAllocatorConfig(tokenizer, i); + if (!env.has_value()) { + return; + } + { + std::lock_guard lock(m_last_allocator_settings_mutex); + m_last_allocator_settings = env.value(); + } + + std::vector config; + lexArgs(env.value(), config); + + for (size_t i = 0; i < config.size(); i++) { + std::string_view config_item_view(config[i]); + if (config_item_view == "max_split_size_mb") { + i = parseMaxSplitSize(config, i); + used_native_specific_option = true; + } else if (config_item_view == "max_non_split_rounding_mb") { + i = parseMaxNonSplitRoundingSize(config, i); + used_native_specific_option = true; + } else if (config_item_view == "garbage_collection_threshold") { + i = parseGarbageCollectionThreshold(config, i); + used_native_specific_option = true; + } else if (config_item_view == "roundup_power2_divisions") { + i = parseRoundUpPower2Divisions(config, i); + used_native_specific_option = true; + } else if (config_item_view == "backend") { + i = parseAllocatorConfig(config, i, used_cudaMallocAsync); + } else if (config_item_view == "expandable_segments") { + used_native_specific_option = true; + consumeToken(config, ++i, ':'); + ++i; + TORCH_CHECK( + i < config.size() && + (std::string_view(config[i]) == "True" || + std::string_view(config[i]) == "False"), + "Expected a single True/False argument for expandable_segments"); + config_item_view = config[i]; + m_expandable_segments = (config_item_view == "True"); } else if ( // ROCm build's hipify step will change "cuda" to "hip", but for ease of // use, accept both. We must break up the string to prevent hipify here. - key == "release_lock_on_hipmalloc" || - key == + config_item_view == "release_lock_on_hipmalloc" || + config_item_view == "release_lock_on_c" "udamalloc") { used_native_specific_option = true; - tokenizer.checkToken(++i, ":"); - m_release_lock_on_cudamalloc = tokenizer.toBool(++i); + consumeToken(config, ++i, ':'); + ++i; + TORCH_CHECK( + i < config.size() && + (std::string_view(config[i]) == "True" || + std::string_view(config[i]) == "False"), + "Expected a single True/False argument for release_lock_on_cudamalloc"); + config_item_view = config[i]; + m_release_lock_on_cudamalloc = (config_item_view == "True"); } else if ( // ROCm build's hipify step will change "cuda" to "hip", but for ease of // use, accept both. We must break up the string to prevent hipify here. - key == "pinned_use_hip_host_register" || - key == + config_item_view == "pinned_use_hip_host_register" || + config_item_view == "pinned_use_c" "uda_host_register") { - i = parsePinnedUseCudaHostRegister(tokenizer, i); + i = parsePinnedUseCudaHostRegister(config, i); used_native_specific_option = true; - } else if (key == "pinned_num_register_threads") { - i = parsePinnedNumRegisterThreads(tokenizer, i); + } else if (config_item_view == "pinned_num_register_threads") { + i = parsePinnedNumRegisterThreads(config, i); + used_native_specific_option = true; + } else if (config_item_view == "pinned_use_background_threads") { + i = parsePinnedUseBackgroundThreads(config, i); used_native_specific_option = true; } else { - const auto& keys = - c10::CachingAllocator::AcceleratorAllocatorConfig::getKeys(); TORCH_CHECK( - keys.find(key) != keys.end(), - "Unrecognized key '", - key, - "' in Accelerator allocator config."); - i = tokenizer.skipKey(i); + false, "Unrecognized CachingAllocator option: ", config_item_view); } - if (i + 1 < tokenizer.size()) { - tokenizer.checkToken(++i, ","); + if (i + 1 < config.size()) { + consumeToken(config, ++i, ','); } } - if (m_use_async_allocator && used_native_specific_option) { + if (used_cudaMallocAsync && used_native_specific_option) { TORCH_WARN( "backend:cudaMallocAsync ignores max_split_size_mb," "roundup_power2_divisions, and garbage_collect_threshold."); @@ -121,33 +391,64 @@ void CUDAAllocatorConfig::parseArgs(const std::string& env) { } size_t CUDAAllocatorConfig::parsePinnedUseCudaHostRegister( - const c10::CachingAllocator::ConfigTokenizer& tokenizer, + const std::vector& config, size_t i) { - tokenizer.checkToken(++i, ":"); - m_pinned_use_cuda_host_register = tokenizer.toBool(++i); - + consumeToken(config, ++i, ':'); + if (++i < config.size()) { + TORCH_CHECK( + (config[i] == "True" || config[i] == "False"), + "Expected a single True/False argument for pinned_use_cuda_host_register"); + m_pinned_use_cuda_host_register = (config[i] == "True"); + } else { + TORCH_CHECK( + false, "Error, expecting pinned_use_cuda_host_register value", ""); + } return i; } size_t CUDAAllocatorConfig::parsePinnedNumRegisterThreads( - const c10::CachingAllocator::ConfigTokenizer& tokenizer, + const std::vector& config, size_t i) { - tokenizer.checkToken(++i, ":"); - size_t val2 = tokenizer.toSizeT(++i); - TORCH_CHECK( - llvm::isPowerOf2_64(val2), - "Number of register threads has to be power of 2 ", - ""); - auto maxThreads = CUDAAllocatorConfig::pinned_max_register_threads(); - TORCH_CHECK( - val2 <= maxThreads, - "Number of register threads should be less than or equal to " + - std::to_string(maxThreads), - ""); - m_pinned_num_register_threads = val2; + consumeToken(config, ++i, ':'); + if (++i < config.size()) { + size_t val2 = stoi(config[i]); + TORCH_CHECK( + llvm::isPowerOf2_64(val2), + "Number of register threads has to be power of 2 ", + ""); + auto maxThreads = CUDAAllocatorConfig::pinned_max_register_threads(); + TORCH_CHECK( + val2 <= maxThreads, + "Number of register threads should be less than or equal to " + + std::to_string(maxThreads), + ""); + m_pinned_num_register_threads = val2; + } else { + TORCH_CHECK( + false, "Error, expecting pinned_num_register_threads value", ""); + } + return i; +} + +size_t CUDAAllocatorConfig::parsePinnedUseBackgroundThreads( + const std::vector& config, + size_t i) { + consumeToken(config, ++i, ':'); + if (++i < config.size()) { + TORCH_CHECK( + (config[i] == "True" || config[i] == "False"), + "Expected a single True/False argument for pinned_use_background_threads"); + m_pinned_use_background_threads = (config[i] == "True"); + } else { + TORCH_CHECK( + false, "Error, expecting pinned_use_background_threads value", ""); + } return i; } -REGISTER_ALLOCATOR_CONFIG_PARSE_HOOK(CUDAAllocatorConfig) +// General caching allocator utilities +void setAllocatorSettings(const std::string& env) { + CUDACachingAllocator::CUDAAllocatorConfig::instance().parseArgs(env.c_str()); +} } // namespace c10::cuda::CUDACachingAllocator diff --git a/c10/cuda/CUDAAllocatorConfig.h b/c10/cuda/CUDAAllocatorConfig.h index f96ae5e56ba6c..fda3cc02e5d0a 100644 --- a/c10/cuda/CUDAAllocatorConfig.h +++ b/c10/cuda/CUDAAllocatorConfig.h @@ -1,11 +1,16 @@ #pragma once -#include -#include #include #include #include +#include +#include +#include +#include +#include +#include + namespace c10::cuda::CUDACachingAllocator { enum class Expandable_Segments_Handle_Type : int { @@ -18,23 +23,20 @@ enum class Expandable_Segments_Handle_Type : int { class C10_CUDA_API CUDAAllocatorConfig { public: static size_t max_split_size() { - return c10::CachingAllocator::AcceleratorAllocatorConfig::max_split_size(); + return instance().m_max_split_size; } static double garbage_collection_threshold() { - return c10::CachingAllocator::AcceleratorAllocatorConfig:: - garbage_collection_threshold(); + return instance().m_garbage_collection_threshold; } static bool expandable_segments() { - bool enabled = c10::CachingAllocator::AcceleratorAllocatorConfig:: - use_expandable_segments(); #ifndef PYTORCH_C10_DRIVER_API_SUPPORTED - if (enabled) { + if (instance().m_expandable_segments) { TORCH_WARN_ONCE("expandable_segments not supported on this platform") } return false; #else - return enabled; + return instance().m_expandable_segments; #endif } @@ -61,8 +63,7 @@ class C10_CUDA_API CUDAAllocatorConfig { } static bool pinned_use_background_threads() { - return c10::CachingAllocator::AcceleratorAllocatorConfig:: - pinned_use_background_threads(); + return instance().m_pinned_use_background_threads; } static size_t pinned_max_register_threads() { @@ -76,97 +77,88 @@ class C10_CUDA_API CUDAAllocatorConfig { // More description below in function roundup_power2_next_division // As an example, if we want 4 divisions between 2's power, this can be done // using env variable: PYTORCH_CUDA_ALLOC_CONF=roundup_power2_divisions:4 - static size_t roundup_power2_divisions(size_t size) { - return c10::CachingAllocator::AcceleratorAllocatorConfig:: - roundup_power2_divisions(size); - } + static size_t roundup_power2_divisions(size_t size); static std::vector roundup_power2_divisions() { - return c10::CachingAllocator::AcceleratorAllocatorConfig:: - roundup_power2_divisions(); + return instance().m_roundup_power2_divisions; } static size_t max_non_split_rounding_size() { - return c10::CachingAllocator::AcceleratorAllocatorConfig:: - max_non_split_rounding_size(); + return instance().m_max_non_split_rounding_size; } static std::string last_allocator_settings() { - return c10::CachingAllocator::getAllocatorSettings(); - } - - static bool use_async_allocator() { - return instance().m_use_async_allocator; - } - - static const std::unordered_set& getKeys() { - return instance().keys_; + std::lock_guard lock( + instance().m_last_allocator_settings_mutex); + return instance().m_last_allocator_settings; } static CUDAAllocatorConfig& instance() { static CUDAAllocatorConfig* s_instance = ([]() { auto inst = new CUDAAllocatorConfig(); - auto env = c10::utils::get_env("PYTORCH_ALLOC_CONF"); - if (!env.has_value()) { - // For backward compatibility, check for the old environment variable - // PYTORCH_CUDA_ALLOC_CONF. - env = c10::utils::get_env("PYTORCH_CUDA_ALLOC_CONF"); - } + auto env = c10::utils::get_env("PYTORCH_CUDA_ALLOC_CONF"); #ifdef USE_ROCM // convenience for ROCm users, allow alternative HIP token if (!env.has_value()) { env = c10::utils::get_env("PYTORCH_HIP_ALLOC_CONF"); } #endif - if (env.has_value()) { - inst->parseArgs(env.value()); - } + inst->parseArgs(env); return inst; })(); return *s_instance; } - void parseArgs(const std::string& env); + void parseArgs(const std::optional& env); private: - CUDAAllocatorConfig() = default; - - size_t parseAllocatorConfig( - const c10::CachingAllocator::ConfigTokenizer& tokenizer, + CUDAAllocatorConfig(); + + static void lexArgs(const std::string& env, std::vector& config); + static void consumeToken( + const std::vector& config, + size_t i, + const char c); + size_t parseMaxSplitSize(const std::vector& config, size_t i); + size_t parseMaxNonSplitRoundingSize( + const std::vector& config, + size_t i); + size_t parseGarbageCollectionThreshold( + const std::vector& config, + size_t i); + size_t parseRoundUpPower2Divisions( + const std::vector& config, size_t i); + size_t parseAllocatorConfig( + const std::vector& config, + size_t i, + bool& used_cudaMallocAsync); size_t parsePinnedUseCudaHostRegister( - const c10::CachingAllocator::ConfigTokenizer& tokenizer, + const std::vector& config, size_t i); size_t parsePinnedNumRegisterThreads( - const c10::CachingAllocator::ConfigTokenizer& tokenizer, + const std::vector& config, + size_t i); + size_t parsePinnedUseBackgroundThreads( + const std::vector& config, size_t i); - std::atomic m_pinned_num_register_threads{1}; - std::atomic m_expandable_segments_handle_type -#if CUDA_VERSION >= 12030 - {Expandable_Segments_Handle_Type::UNSPECIFIED}; -#else - {Expandable_Segments_Handle_Type::POSIX_FD}; -#endif - std::atomic m_release_lock_on_cudamalloc{false}; - std::atomic m_pinned_use_cuda_host_register{false}; - std::atomic m_use_async_allocator{false}; - std::atomic m_is_allocator_loaded{false}; - std::unordered_set keys_{ - "backend", - // keep BC for Rocm: `cuda` -> `cud` `a`, to avoid hipify issues - // NOLINTBEGIN(bugprone-suspicious-missing-comma,-warnings-as-errors) - "release_lock_on_cud" - "amalloc", - "pinned_use_cud" - "a_host_register", - // NOLINTEND(bugprone-suspicious-missing-comma,-warnings-as-errors) - "release_lock_on_hipmalloc", - "pinned_use_hip_host_register", - "pinned_num_register_threads"}; + std::atomic m_max_split_size; + std::atomic m_max_non_split_rounding_size; + std::vector m_roundup_power2_divisions; + std::atomic m_garbage_collection_threshold; + std::atomic m_pinned_num_register_threads; + std::atomic m_expandable_segments; + std::atomic + m_expandable_segments_handle_type; + std::atomic m_release_lock_on_cudamalloc; + std::atomic m_pinned_use_cuda_host_register; + std::atomic m_pinned_use_background_threads; + std::string m_last_allocator_settings; + std::mutex m_last_allocator_settings_mutex; }; -// Keep this for backwards compatibility -using c10::CachingAllocator::setAllocatorSettings; +// General caching allocator utilities +C10_CUDA_API void setAllocatorSettings(const std::string& env); } // namespace c10::cuda::CUDACachingAllocator diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp index ed6914c350599..4d58c11c5c9bc 100644 --- a/c10/cuda/CUDACachingAllocator.cpp +++ b/c10/cuda/CUDACachingAllocator.cpp @@ -1,6 +1,7 @@ #include #include +#include #include #include #include @@ -63,6 +64,10 @@ namespace cuda::CUDACachingAllocator { using namespace c10::CachingAllocator; using namespace c10::CachingDeviceAllocator; +// Included here as this is externally used in CUDAAllocatorConfig +const size_t kLargeBuffer = + 20971520; // "large" allocations may be packed in 20 MiB blocks + namespace Native { // @@ -4125,10 +4130,49 @@ CUDAAllocator* allocator(); } // namespace CudaMallocAsync struct BackendStaticInitializer { + // Parses env for backend at load time, duplicating some logic from + // CUDAAllocatorConfig. CUDAAllocatorConfig double-checks it later (at + // runtime). Defers verbose exceptions and error checks, including Cuda + // version checks, to CUDAAllocatorConfig's runtime doublecheck. If this + // works, maybe we should move all of CUDAAllocatorConfig here? CUDAAllocator* parseEnvForBackend() { - // If the environment variable is set, we use the CudaMallocAsync allocator. - if (CUDAAllocatorConfig::use_async_allocator()) { - return CudaMallocAsync::allocator(); + auto val = c10::utils::get_env("PYTORCH_CUDA_ALLOC_CONF"); +#ifdef USE_ROCM + // convenience for ROCm users to allow either CUDA or HIP env var + if (!val.has_value()) { + val = c10::utils::get_env("PYTORCH_HIP_ALLOC_CONF"); + } +#endif + if (val.has_value()) { + const std::string& config = val.value(); + + std::regex exp("[\\s,]+"); + std::sregex_token_iterator it(config.begin(), config.end(), exp, -1); + std::sregex_token_iterator end; + std::vector options(it, end); + + for (auto option : options) { + std::regex exp2("[:]+"); + std::sregex_token_iterator it2(option.begin(), option.end(), exp2, -1); + std::sregex_token_iterator end2; + std::vector kv(it2, end2); + if (kv.size() >= 2) { + if (kv[0] == "backend") { +#ifdef USE_ROCM + // convenience for ROCm users to allow either CUDA or HIP env var + if (kv[1] == + "cud" + "aMallocAsync" || + kv[1] == "hipMallocAsync") +#else + if (kv[1] == "cudaMallocAsync") +#endif + return CudaMallocAsync::allocator(); + if (kv[1] == "native") + return &Native::allocator; + } + } + } } return &Native::allocator; } diff --git a/c10/cuda/CUDACachingAllocator.h b/c10/cuda/CUDACachingAllocator.h index 956411fe22827..a6fa61110d675 100644 --- a/c10/cuda/CUDACachingAllocator.h +++ b/c10/cuda/CUDACachingAllocator.h @@ -1,7 +1,6 @@ #pragma once #include -#include #include #include #include @@ -50,9 +49,10 @@ namespace c10::cuda::CUDACachingAllocator { // Preserved only for BC reasons // NOLINTNEXTLINE(misc-unused-using-decls) -using c10::CachingAllocator::kLargeBuffer; using c10::CachingDeviceAllocator::DeviceStats; +extern const size_t kLargeBuffer; + typedef std::shared_ptr (*CreateContextFn)(); // Struct containing info of an allocation block (i.e. a fractional part of a From 8c3f206457a1b9d75bc95a6c30a101135fcee329 Mon Sep 17 00:00:00 2001 From: Robert Hardwick Date: Tue, 15 Jul 2025 18:26:34 +0000 Subject: [PATCH 059/457] Fix AArch64 segfaults by disabling strict-aliasing in GridSamplerKernel for GCC 12 and above (#158117) This PR disables `strict-aliasing` GCC C++ optimization flag on all AArch64 cpus for GCC versions 12 and above. Pull Request #152825 upgraded gcc version from 11 to 13 in manywheel which caused several segmentation faults in unit tests ( not visible in CI workflows because the jammy gcc version has not been updated yet ). We Identified the problem also exists in GCC12 hence the ` __GNUC__ >= 12` Fixes #157626 fixes these tests failures when pytorch is built in GCC12 and above ``` test_ops.py::TestCommonCPU::test_noncontiguous_samples_grid_sampler_2d_cpu_float32 Fatal Python error: Segmentation fault test_ops.py::TestCommonCPU::test_dtypes_grid_sampler_2d_cpu Fatal Python error: Segmentation fault test_ops.py::TestMathBitsCPU::test_neg_view_nn_functional_grid_sample_cpu_float64 free(): invalid next size (fast) test_ops.py::TestCompositeComplianceCPU::test_backward_grid_sampler_2d_cpu_float32 Fatal Python error: Segmentation fault test_ops.py::TestCommonCPU::test_dtypes_nn_functional_grid_sample_cpu Fatal Python error: Segmentation fault ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/158117 Approved by: https://github.com/malfet --- aten/src/ATen/native/cpu/GridSamplerKernel.cpp | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/aten/src/ATen/native/cpu/GridSamplerKernel.cpp b/aten/src/ATen/native/cpu/GridSamplerKernel.cpp index 521a65c7cd948..9450b7eca9b37 100644 --- a/aten/src/ATen/native/cpu/GridSamplerKernel.cpp +++ b/aten/src/ATen/native/cpu/GridSamplerKernel.cpp @@ -14,6 +14,12 @@ namespace at::native { namespace { +// fixes segfaults for GCC >= 12 on some AArch64 cpus https://github.com/pytorch/pytorch/issues/157626 +#if defined(__GNUC__) && __GNUC__ >= 12 && defined(__aarch64__) +#pragma GCC push_options +#pragma GCC optimize ("no-strict-aliasing") +#endif + /** NOTE [ Grid Sample CPU Kernels ] * * Implementation of vectorized grid sample CPU kernels is divided into three @@ -1014,6 +1020,10 @@ struct ApplyGridSample= 12 && defined(__aarch64__) +#pragma GCC pop_options +#endif + // ~~~~~~~~~~~~~~~~~~ grid_sample_2d_grid_slice_iterator ~~~~~~~~~~~~~~~~~~~~~~ // Function to apply a vectorized function on a grid slice tensor (without batch // dimension). From 46915b13614dbac90724d0f1802b8e0db037c9e4 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 15 Jul 2025 18:40:59 +0000 Subject: [PATCH 060/457] Revert "Introduce AcceleratorAllocatorConfig as the common class (#149601)" This reverts commit 1e8e9f745e43fa38bbfc7b67b30bc66c0e7ebbd6. Reverted https://github.com/pytorch/pytorch/pull/149601 on behalf of https://github.com/huydhn due to See https://github.com/pytorch/pytorch/pull/149601#discussion_r2208325379 ([comment](https://github.com/pytorch/pytorch/pull/149601#issuecomment-3074965720)) --- c10/core/AllocatorConfig.cpp | 233 ----------------- c10/core/AllocatorConfig.h | 337 ------------------------- c10/test/core/AllocatorConfig_test.cpp | 123 --------- 3 files changed, 693 deletions(-) delete mode 100644 c10/core/AllocatorConfig.cpp delete mode 100644 c10/core/AllocatorConfig.h delete mode 100644 c10/test/core/AllocatorConfig_test.cpp diff --git a/c10/core/AllocatorConfig.cpp b/c10/core/AllocatorConfig.cpp deleted file mode 100644 index 9ceb40ccf6d74..0000000000000 --- a/c10/core/AllocatorConfig.cpp +++ /dev/null @@ -1,233 +0,0 @@ -#include -#include -#include -#include - -namespace c10::CachingAllocator { - -namespace { -constexpr size_t kRoundUpPowerOfTwoIntervals = 16; -constexpr size_t kMB = 1024 * 1024ul; -constexpr size_t kRoundUpPowerOfTwoStart = 1 * kMB; // 1MB -constexpr size_t kRoundUpPowerOfTwoEnd = 64 * 1024ul * kMB; // 64GB -} // anonymous namespace - -AcceleratorAllocatorConfig& AcceleratorAllocatorConfig::instance() { - static AcceleratorAllocatorConfig instance; -#define C10_ALLOCATOR_CONFIG_PARSE_ENV(env, deprecated) \ - auto env##_name = c10::utils::get_env(#env); \ - if (env##_name.has_value()) { \ - if (deprecated) { \ - TORCH_WARN_ONCE(#env " is deprecated, use PYTORCH_ALLOC_CONF instead"); \ - } \ - instance.parseArgs(env##_name.value()); \ - return true; \ - } - static bool env_flag [[maybe_unused]] = []() { - C10_ALLOCATOR_CONFIG_PARSE_ENV(PYTORCH_ALLOC_CONF, false) - // Keep this for backwards compatibility - C10_ALLOCATOR_CONFIG_PARSE_ENV(PYTORCH_CUDA_ALLOC_CONF, /*deprecated=*/true) - C10_ALLOCATOR_CONFIG_PARSE_ENV(PYTORCH_HIP_ALLOC_CONF, /*deprecated=*/true) - return false; - }(); -#undef C10_ALLOCATOR_CONFIG_PARSE_ENV - return instance; -} - -AcceleratorAllocatorConfig::AcceleratorAllocatorConfig() { - roundup_power2_divisions_.assign(kRoundUpPowerOfTwoIntervals, 0); -} - -size_t AcceleratorAllocatorConfig::roundup_power2_divisions(size_t size) { - size_t log_size = (63 - llvm::countLeadingZeros(size)); - - // Our intervals start at 1MB and end at 64GB - const size_t interval_start = - 63 - llvm::countLeadingZeros(kRoundUpPowerOfTwoStart); - const size_t interval_end = - 63 - llvm::countLeadingZeros(kRoundUpPowerOfTwoEnd); - TORCH_CHECK( - interval_end - interval_start == kRoundUpPowerOfTwoIntervals, - "kRoundUpPowerOfTwoIntervals mismatch"); - - size_t index = - (log_size > interval_start) ? (log_size - interval_start) : 0ul; - index = std::min(index, kRoundUpPowerOfTwoIntervals - 1); - return instance().roundup_power2_divisions_[index]; -} - -size_t AcceleratorAllocatorConfig::parseMaxSplitSize( - const ConfigTokenizer& tokenizer, - size_t i) { - tokenizer.checkToken(++i, ":"); - constexpr size_t min_allowed_split_size_mb = kLargeBuffer / kMB; - constexpr size_t max_allowed_split_size_mb = - std::numeric_limits::max() / kMB; - - size_t val_env = tokenizer.toSizeT(++i); - TORCH_CHECK( - val_env >= min_allowed_split_size_mb, - "CachingAllocator option max_split_size_mb too small, must be >= ", - min_allowed_split_size_mb); - val_env = std::min(val_env, max_allowed_split_size_mb); - max_split_size_ = val_env * kMB; - - return i; -} - -size_t AcceleratorAllocatorConfig::parseMaxNonSplitRoundingSize( - const ConfigTokenizer& tokenizer, - size_t i) { - tokenizer.checkToken(++i, ":"); - constexpr size_t min_allowed_split_size_mb = kLargeBuffer / kMB; - constexpr size_t max_allowed_split_size_mb = - std::numeric_limits::max() / kMB; - - size_t val_env = tokenizer.toSizeT(++i); - TORCH_CHECK( - val_env >= min_allowed_split_size_mb, - "CachingAllocator option max_non_split_rounding_mb too small, must be >= ", - min_allowed_split_size_mb); - val_env = std::min(val_env, max_allowed_split_size_mb); - max_non_split_rounding_size_ = val_env * kMB; - - return i; -} - -size_t AcceleratorAllocatorConfig::parseGarbageCollectionThreshold( - const ConfigTokenizer& tokenizer, - size_t i) { - tokenizer.checkToken(++i, ":"); - double val_env = tokenizer.toDouble(++i); - TORCH_CHECK( - val_env > 0 && val_env < 1.0, - "garbage_collect_threshold is invalid, set it in (0.0, 1.0)"); - garbage_collection_threshold_ = val_env; - - return i; -} - -size_t AcceleratorAllocatorConfig::parseRoundUpPower2Divisions( - const ConfigTokenizer& tokenizer, - size_t i) { - tokenizer.checkToken(++i, ":"); - bool first_value = true; - - if (tokenizer[++i] == "[") { - size_t last_index = 0; - // NOLINTNEXTLINE(bugprone-inc-dec-in-conditions) - while (++i < tokenizer.size() && tokenizer[i] != "]") { - size_t value_index = i; - tokenizer.checkToken(++i, ":"); - size_t value = tokenizer.toSizeT(++i); - TORCH_CHECK( - value == 0 || llvm::isPowerOf2_64(value), - "For roundups, the divisions has to be power of 2 or 0 to disable roundup "); - - if (tokenizer[value_index] == ">") { - std::fill( - std::next( - roundup_power2_divisions_.begin(), - static_cast::difference_type>( - last_index + 1)), - roundup_power2_divisions_.end(), - value); - } else { - size_t boundary = tokenizer.toSizeT(value_index); - TORCH_CHECK( - llvm::isPowerOf2_64(boundary), - "For roundups, the intervals have to be power of 2 "); - - size_t index = 63 - llvm::countLeadingZeros(boundary); - index = - std::clamp(index, size_t{0}, roundup_power2_divisions_.size() - 1); - - if (first_value) { - std::fill( - roundup_power2_divisions_.begin(), - std::next( - roundup_power2_divisions_.begin(), - static_cast::difference_type>(index)), - value); - first_value = false; - } - roundup_power2_divisions_[index] = value; - last_index = index; - } - - if (tokenizer[i + 1] != "]") { - tokenizer.checkToken(++i, ","); - } - } - TORCH_INTERNAL_ASSERT( - i < tokenizer.size(), - "Expected closing bracket ']' in ConfigTokenizer but reached end of config"); - } else { // Keep this for backwards compatibility - size_t value = tokenizer.toSizeT(i); - TORCH_CHECK( - llvm::isPowerOf2_64(value), - "For roundups, the divisions has to be power of 2 "); - std::fill( - roundup_power2_divisions_.begin(), - roundup_power2_divisions_.end(), - value); - } - return i; -} - -size_t AcceleratorAllocatorConfig::parseExpandableSegments( - const ConfigTokenizer& tokenizer, - size_t i) { - tokenizer.checkToken(++i, ":"); - use_expandable_segments_ = tokenizer.toBool(++i); - - return i; -} - -size_t AcceleratorAllocatorConfig::parsePinnedUseBackgroundThreads( - const ConfigTokenizer& tokenizer, - size_t i) { - tokenizer.checkToken(++i, ":"); - pinned_use_background_threads_ = tokenizer.toBool(++i); - - return i; -} - -void AcceleratorAllocatorConfig::parseArgs(const std::string& env) { - // The following option will be reset to its default value if not explicitly - // set each time. - max_split_size_ = std::numeric_limits::max(); - roundup_power2_divisions_.assign(kRoundUpPowerOfTwoIntervals, 0); - garbage_collection_threshold_ = 0; - - { - std::lock_guard lock(last_allocator_settings_mutex_); - last_allocator_settings_ = env; - } - - ConfigTokenizer tokenizer(env); - for (size_t i = 0; i < tokenizer.size(); i++) { - const auto& key = tokenizer[i]; - if (key == "max_split_size_mb") { - i = parseMaxSplitSize(tokenizer, i); - } else if (key == "max_non_split_rounding_mb") { - i = parseMaxNonSplitRoundingSize(tokenizer, i); - } else if (key == "garbage_collection_threshold") { - i = parseGarbageCollectionThreshold(tokenizer, i); - } else if (key == "roundup_power2_divisions") { - i = parseRoundUpPower2Divisions(tokenizer, i); - } else if (key == "expandable_segments") { - i = parseExpandableSegments(tokenizer, i); - } else if (key == "pinned_use_background_threads") { - i = parsePinnedUseBackgroundThreads(tokenizer, i); - } else { - i = tokenizer.skipKey(i); - } - - if (i + 1 < tokenizer.size()) { - tokenizer.checkToken(++i, ","); - } - } -} - -} // namespace c10::CachingAllocator diff --git a/c10/core/AllocatorConfig.h b/c10/core/AllocatorConfig.h deleted file mode 100644 index e19160ea5978e..0000000000000 --- a/c10/core/AllocatorConfig.h +++ /dev/null @@ -1,337 +0,0 @@ -#pragma once - -#include -#include -#include - -#include -#include -#include -#include - -namespace c10::CachingAllocator { - -// "large" allocations may be packed in 20 MiB blocks -const size_t kLargeBuffer = 20971520; - -// A utility class for tokenizing allocator configuration strings into discrete -// parts. For example, the config string: -// "key1:val1,key2:[val2,val3]" -// is tokenized into: -// "key1", ":", "val1", ",", "key2", ":", "[", "val2", ",", "val3", "]", -// -// Tokens include keys, values, and special characters (':', ',', '[', ']'). -// Whitespace is ignored. -class ConfigTokenizer { - public: - explicit ConfigTokenizer(const std::string& env) { - std::string buffer; - for (char ch : env) { - if (ch == ',' || ch == ':' || ch == '[' || ch == ']') { - if (!buffer.empty()) { - config_.emplace_back(std::move(buffer)); - buffer.clear(); - } - config_.emplace_back(1, ch); - } else if (!std::isspace(static_cast(ch))) { - buffer += ch; - } - } - if (!buffer.empty()) { - config_.emplace_back(std::move(buffer)); - } - } - - const std::string& operator[](size_t i) const { - TORCH_INTERNAL_ASSERT( - i < config_.size(), "Index out of bounds in ConfigTokenizer"); - return config_[i]; - } - - size_t size() const { - return config_.size(); - } - - bool checkToken(size_t i, const std::string& token) const { - checkIndex(i); - return config_[i] == token; - } - - size_t toSizeT(size_t i) const { - checkIndex(i); - return std::stoull(config_[i]); - } - - double toDouble(size_t i) const { - checkIndex(i); - return std::stod(config_[i]); - } - - bool toBool(size_t i) const { - checkIndex(i); - const auto& token = config_[i]; - if (token == "True") { - return true; - } else if (token == "False") { - return false; - } else { - TORCH_CHECK( - false, - "Expected 'True' or 'False' at index ", - i, - " in ConfigTokenizer but got '", - token, - "'"); - } - } - - // Skips the current token group and returns the index of the value token. - // Assumes the current index `i` points to a key name in a key-value pair. - size_t skipKey(size_t i) const { - // Expect a colon after the key - checkToken(++i, ":"); - - ++i; // Move to the value - checkIndex(i); - if (config_[i] != "[") { - // Value is a single token (not a list) -> return its index - return i; - } - - // Skip tokens inside the list until matching ']' - // NOLINTNEXTLINE(bugprone-inc-dec-in-conditions) - while (++i < config_.size() && config_[i] != "]") { - } - - TORCH_INTERNAL_ASSERT( - i < config_.size(), - "Expected closing bracket ']' in ConfigTokenizer but reached end of config"); - - return i; // Return the index of the closing ']' - } - - private: - void checkIndex(size_t i) const { - TORCH_INTERNAL_ASSERT( - i < config_.size(), "Index out of bounds in ConfigTokenizer"); - } - - std::vector config_; -}; - -/** - * Note [AcceleratorAllocatorConfig design] - * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - * This class configures memory allocation for both device and host memory. A - * single `AcceleratorAllocatorConfig` instance is shared across all accelerator - * backends, such as CUDA and XPU, under the assumption that relevant - * environment variables apply uniformly to all accelerators. Device-specific - * configuration extensions are supported via hooks (see - * `registerDeviceConfigParserHook`). - * - * Recommended design: - * - Place common configurations in `AcceleratorAllocatorConfig`. - * - Extend backend-specific configurations in corresponding device-specific - * classes, such as `CUDAAllocatorConfig`, etc. - * - * Scope: - * - Configuration options must be environment-variable driven. - * - * Naming Convention: - * - Public API names in `AcceleratorAllocatorConfig` should be device-generic. - * - Members prefixed with `pinned_` are specific to the host/pinned allocator. - * - Environment variable names should be generic across backends. - * - Comma-separated key-value pairs in the format: `key:value`. Use square - * brackets `[]` for list values Example: `key1:123, key2:[val1,val2]` - * - * Environment Variables: - * - The primary environment variable for configuration is `PYTORCH_ALLOC_CONF`. - * - For backward compatibility, `PYTORCH_CUDA_ALLOC_CONF` is also supported - * with lower priority. - */ - -class C10_API AcceleratorAllocatorConfig { - public: - static AcceleratorAllocatorConfig& instance(); - - C10_DISABLE_COPY_AND_ASSIGN(AcceleratorAllocatorConfig); - AcceleratorAllocatorConfig(AcceleratorAllocatorConfig&&) = delete; - AcceleratorAllocatorConfig& operator=(AcceleratorAllocatorConfig&&) = delete; - ~AcceleratorAllocatorConfig() = default; - - /* Device allocator settings */ - - // Returns the maximum block size (in MB) that is allowed to be split. The - // default is unlimited (all blocks can be split). - static size_t max_split_size() { - return instance().max_split_size_; - } - - // Returns the maximum block size (in MB) that is allowed to be rounded up - // without requiring splitting when searching for a free block. The default is - // 20 MiB. - static size_t max_non_split_rounding_size() { - return instance().max_non_split_rounding_size_; - } - - // Return the number of divisions used when rounding up allocation sizes (in - // MB) to the nearest power-of-2 boundary. - static size_t roundup_power2_divisions(size_t size); - - // Returns the vector of division factors used for rounding up allocation - // sizes. These divisions apply to size intervals between 1MB and 64GB. - static std::vector roundup_power2_divisions() { - return instance().roundup_power2_divisions_; - } - - // Returns the threshold that triggers garbage collection when the ratio of - // used memory to maximum allowed memory exceeds this value. The default is 0, - // meaning no garbage collection is triggered. The value should be in the - // range (0.0, 1.0). - static double garbage_collection_threshold() { - return instance().garbage_collection_threshold_; - } - - // Returns whether the expandable segment feature is enabled. This allows the - // allocator to start with one segment that grows as needed, rather than - // creating a new segment for each allocation. Default is false (expandable - // segments disabled). - static bool use_expandable_segments() { - return instance().use_expandable_segments_; - } - - /* Host allocator settings */ - - // Returns whether the pinned host allocator uses background threads for - // processing events. This is useful for improving performance in scenarios - // where many small allocations are made. Default is false (background threads - // disabled). - static bool pinned_use_background_threads() { - return instance().pinned_use_background_threads_; - } - - /* Settings for both device and host allocator */ - - // Returns the current allocator settings as a string. This string is useful - // to expand device-specific allocator configurations - static std::string last_allocator_settings() { - std::lock_guard lock(instance().last_allocator_settings_mutex_); - return instance().last_allocator_settings_; - } - - // Parses the environment variable `env` to update the allocator settings. - // If the environment variable is not set, it does nothing. - // The configuration string should be a comma-separated list of key-value - // pairs, where each key is a configuration option and the value is the - // corresponding setting. For example: - // "max_split_size_mb:100,max_non_split_rounding_mb:20,garbage_collection_threshold:0.5,roundup_power2_divisions:[64:8,256:4,1024:4,>:1],expandable_segments:true,pinned_use_background_threads:true" - void parseArgs(const std::string& env); - - // Registers a device-specific configuration parser hook. This allows - // backends to parse additional device-specific configuration options from the - // environment variable. The hook should be a function that takes a string - // (the environment variable value) and parses it to set device-specific - // configuration options. - // The hook will be called when the environment variable is parsed. - // If a hook is already registered, it will be replaced with the new one. - void registerDeviceConfigParserHook( - std::function hook) { - device_config_parser_hook_ = std::move(hook); - } - - // Calls the registered device-specific configuration parser hook with the - // provided environment string. This allows backends to parse additional - // device-specific configuration options from the environment variable. - // If no hook is registered, this function does nothing. - void callDeviceConfigParserHook(const std::string& env) const { - if (device_config_parser_hook_) { - device_config_parser_hook_(env); - } - } - - private: - AcceleratorAllocatorConfig(); - - /* Internal functions for device allocator */ - - // Parse `max_split_size_mb` from environment variable. - size_t parseMaxSplitSize(const ConfigTokenizer& tokenizer, size_t i); - // Parse `max_non_split_rounding_mb` from environment variable. - size_t parseMaxNonSplitRoundingSize( - const ConfigTokenizer& tokenizer, - size_t i); - // Parse `garbage_collection_threshold` from environment variable. - size_t parseGarbageCollectionThreshold( - const ConfigTokenizer& tokenizer, - size_t i); - // Parse `roundup_power2_divisions` from environment variable. - size_t parseRoundUpPower2Divisions( - const ConfigTokenizer& tokenizer, - size_t i); - // Parse `expandable_segments` from environment variable. - size_t parseExpandableSegments(const ConfigTokenizer& tokenizer, size_t i); - - /* Internal functions for host allocator */ - - // Parse `pinned_use_background_threads` from environment variable. - size_t parsePinnedUseBackgroundThreads( - const ConfigTokenizer& tokenizer, - size_t i); - - /* The following members are specifically used for the device allocator. */ - - // The maximum block size that is allowed to be split. - std::atomic max_split_size_{std::numeric_limits::max()}; - // The maximum allowable extra size of a memory block without requiring - // splitting when searching for a free block. - std::atomic max_non_split_rounding_size_{kLargeBuffer}; - // Used to store how memory allocations of different sizes should be rounded - // up to the nearest power of 2 divisions. - std::vector roundup_power2_divisions_; - // The threshold that triggers garbage collection when the ratio of used - // memory to maximum allowed memory exceeds this value. - std::atomic garbage_collection_threshold_{0}; - // A flag to enable expandable segments feature. - std::atomic use_expandable_segments_{false}; - - /* The following members are specifically used for the host allocator. */ - - // A flag to enable background thread for processing events. - std::atomic pinned_use_background_threads_{false}; - - /* The following members are used for both device and host allocator. */ - - // Record the last allocator config environment setting. - std::mutex last_allocator_settings_mutex_; - std::string last_allocator_settings_; - - // Optional hook for parsing additional device-specific allocator settings. - // This allows backends (e.g., CUDA, XPU) to register a custom parser for - // their own environment configuration extensions. - std::function device_config_parser_hook_{nullptr}; -}; - -C10_API inline void setAllocatorSettings(const std::string& env) { - AcceleratorAllocatorConfig::instance().parseArgs(env); - AcceleratorAllocatorConfig::instance().callDeviceConfigParserHook(env); -} - -C10_API inline std::string getAllocatorSettings() { - return AcceleratorAllocatorConfig::instance().last_allocator_settings(); -} - -struct DeviceConfigParserHookRegistry { - explicit DeviceConfigParserHookRegistry( - std::function hook) { - AcceleratorAllocatorConfig::instance().registerDeviceConfigParserHook( - std::move(hook)); - } -}; - -#define REGISTER_ALLOCATOR_CONFIG_PARSE_HOOK(hook) \ - namespace { \ - static at::CachingAllocator::DeviceConfigParserHookRegistry \ - g_device_config_parse_hook_registry_instance(hook); \ - } - -} // namespace c10::CachingAllocator diff --git a/c10/test/core/AllocatorConfig_test.cpp b/c10/test/core/AllocatorConfig_test.cpp deleted file mode 100644 index c051cf4cd4a05..0000000000000 --- a/c10/test/core/AllocatorConfig_test.cpp +++ /dev/null @@ -1,123 +0,0 @@ -#include - -#include - -using namespace c10::CachingAllocator; -constexpr size_t kMB = 1024 * 1024ul; - -struct ExtendedAllocatorConfig { - static ExtendedAllocatorConfig& instance() { - static ExtendedAllocatorConfig instance; - return instance; - } - - // Returns the device-specific option value in bytes. - static size_t device_specific_option() { - return instance().device_specific_option_; - } - - void parseArgs(const std::string& env) { - // Parse device-specific options from the environment variable - ConfigTokenizer tokenizer(env); - for (size_t i = 0; i < tokenizer.size(); i++) { - const auto& key = tokenizer[i]; - if (key == "device_specific_option_mb") { - tokenizer.checkToken(++i, ":"); - device_specific_option_ = tokenizer.toSizeT(++i) * kMB; - } else { - i = tokenizer.skipKey(i); - } - - if (i + 1 < tokenizer.size()) { - tokenizer.checkToken(++i, ","); - } - } - } - - private: - // Device-specific option, e.g., memory limit for a specific device. - std::atomic device_specific_option_{0}; -}; - -REGISTER_ALLOCATOR_CONFIG_PARSE_HOOK([](const std::string& env) { - ExtendedAllocatorConfig::instance().parseArgs(env); -}) - -TEST(AllocatorConfigTest, allocator_config_test) { - std::string env = - "max_split_size_mb:40," - "max_non_split_rounding_mb:30," - "garbage_collection_threshold:0.5," - "roundup_power2_divisions:[64:8,128:2,256:4,512:2,1024:4,>:1]," - "expandable_segments:True," - "pinned_use_background_threads:True," - "device_specific_option_mb:64"; - c10::CachingAllocator::setAllocatorSettings(env); - EXPECT_EQ(c10::CachingAllocator::getAllocatorSettings(), env); - EXPECT_EQ(AcceleratorAllocatorConfig::max_split_size(), 40 * kMB); - EXPECT_EQ( - AcceleratorAllocatorConfig::max_non_split_rounding_size(), 30 * kMB); - EXPECT_EQ(AcceleratorAllocatorConfig::garbage_collection_threshold(), 0.5); - EXPECT_EQ(AcceleratorAllocatorConfig::roundup_power2_divisions(32 * kMB), 8); - EXPECT_EQ(AcceleratorAllocatorConfig::roundup_power2_divisions(64 * kMB), 8); - EXPECT_EQ(AcceleratorAllocatorConfig::roundup_power2_divisions(128 * kMB), 2); - EXPECT_EQ(AcceleratorAllocatorConfig::roundup_power2_divisions(256 * kMB), 4); - EXPECT_EQ(AcceleratorAllocatorConfig::roundup_power2_divisions(512 * kMB), 2); - EXPECT_EQ( - AcceleratorAllocatorConfig::roundup_power2_divisions(1024 * kMB), 4); - EXPECT_EQ( - AcceleratorAllocatorConfig::roundup_power2_divisions(2048 * kMB), 1); - EXPECT_EQ( - AcceleratorAllocatorConfig::roundup_power2_divisions(4096 * kMB), 1); - EXPECT_EQ( - AcceleratorAllocatorConfig::roundup_power2_divisions(8192 * kMB), 1); - EXPECT_EQ(AcceleratorAllocatorConfig::use_expandable_segments(), true); - EXPECT_EQ(AcceleratorAllocatorConfig::pinned_use_background_threads(), true); - EXPECT_EQ(ExtendedAllocatorConfig::device_specific_option(), 64 * kMB); - - env = - "max_split_size_mb:20," - "max_non_split_rounding_mb:40," - "garbage_collection_threshold:0.8"; - c10::CachingAllocator::setAllocatorSettings(env); - EXPECT_EQ(c10::CachingAllocator::getAllocatorSettings(), env); - EXPECT_EQ(AcceleratorAllocatorConfig::max_split_size(), 20 * kMB); - EXPECT_EQ( - AcceleratorAllocatorConfig::max_non_split_rounding_size(), 40 * kMB); - EXPECT_EQ(AcceleratorAllocatorConfig::garbage_collection_threshold(), 0.8); - - // roundup_power2_divisions knob array syntax - env = "roundup_power2_divisions:[128:8,256:16,512:1,2048:8,>:2]"; - c10::CachingAllocator::setAllocatorSettings(env); - EXPECT_EQ(c10::CachingAllocator::getAllocatorSettings(), env); - EXPECT_EQ(AcceleratorAllocatorConfig::roundup_power2_divisions(64 * kMB), 8); - EXPECT_EQ(AcceleratorAllocatorConfig::roundup_power2_divisions(128 * kMB), 8); - EXPECT_EQ( - AcceleratorAllocatorConfig::roundup_power2_divisions(256 * kMB), 16); - EXPECT_EQ(AcceleratorAllocatorConfig::roundup_power2_divisions(512 * kMB), 1); - EXPECT_EQ( - AcceleratorAllocatorConfig::roundup_power2_divisions(1024 * kMB), 0); - EXPECT_EQ( - AcceleratorAllocatorConfig::roundup_power2_divisions(2048 * kMB), 8); - EXPECT_EQ( - AcceleratorAllocatorConfig::roundup_power2_divisions(4096 * kMB), 2); - - // roundup_power2_divisions single value syntax for backward compatibility - env = "roundup_power2_divisions:4"; - c10::CachingAllocator::setAllocatorSettings(env); - EXPECT_EQ(c10::CachingAllocator::getAllocatorSettings(), env); - EXPECT_EQ(AcceleratorAllocatorConfig::roundup_power2_divisions(64 * kMB), 4); - EXPECT_EQ(AcceleratorAllocatorConfig::roundup_power2_divisions(256 * kMB), 4); - EXPECT_EQ( - AcceleratorAllocatorConfig::roundup_power2_divisions(2048 * kMB), 4); - - env = "expandable_segments:False,"; - c10::CachingAllocator::setAllocatorSettings(env); - EXPECT_EQ(c10::CachingAllocator::getAllocatorSettings(), env); - EXPECT_EQ(AcceleratorAllocatorConfig::use_expandable_segments(), false); - - env = "pinned_use_background_threads:False"; - c10::CachingAllocator::setAllocatorSettings(env); - EXPECT_EQ(c10::CachingAllocator::getAllocatorSettings(), env); - EXPECT_EQ(AcceleratorAllocatorConfig::pinned_use_background_threads(), false); -} From cf3247b74aaeb956b3c2b31d05e965a0aca5a8b4 Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Tue, 15 Jul 2025 18:47:52 +0000 Subject: [PATCH 061/457] Standalone compile API in _Exporter (#158139) Given an `package: _ExportPackage`, users can get a ready-to-use workspace in `tmp_dir` by calling: ```python package._compiled_and_package( tmp_dir + "/pt2_pacakge_name.pt2", True, package_example_inputs = True ) ``` `tmp_dir` will contains: - `main.cpp` (an example cpp file that create the models, if package_example_inputs is True, it'll also load the example inputs and run the models) - `CMakeLists.txt` - `pt2_pacakge_name/` (this is where the models are) - `pt2_pacakge_name.pt2` - `inputs.pt` files if package_example_inputs is True Remaining TODOs - support loading contants/weights - the `package_example_inputs = True` option only supports a list of Tensors for now - eventually we should remove the `torch` dependency, and use `SlimTensor`/`StableIValue` instead. Test Plan: ``` python test/inductor/test_aot_inductor_package.py -k test_compile_with_exporter ``` Example generated `main.cpp`: ```cpp #include #include #include #include #include #include #include #include "package/data/aotinductor/Plus__default/Plus__default.h" #include "package/data/aotinductor/Minus__default/Minus__default.h" using torch::aot_inductor::AOTInductorModelPlus__default; using torch::aot_inductor::AOTInductorModelMinus__default; using torch::aot_inductor::ConstantHandle; using torch::aot_inductor::ConstantMap; int main(int argc, char* argv[]) { std::string device_str = "cpu"; try { c10::Device device(device_str); // Load input tensors for model Plus__default std::vector input_tensors1; for (int j = 0; j < 2; ++j) { std::string filename = "Plus__default_input_" + std::to_string(j) + ".pt"; std::ifstream in(filename, std::ios::binary); if (!in.is_open()) { std::cerr << "Failed to open file: " << filename << std::endl; return 1; } std::vector buffer((std::istreambuf_iterator(in)), std::istreambuf_iterator()); torch::IValue ivalue = torch::pickle_load(buffer); input_tensors1.push_back(ivalue.toTensor().to(device)); } // Load input tensors for model Minus__default std::vector input_tensors2; for (int j = 0; j < 2; ++j) { std::string filename = "Minus__default_input_" + std::to_string(j) + ".pt"; std::ifstream in(filename, std::ios::binary); if (!in.is_open()) { std::cerr << "Failed to open file: " << filename << std::endl; return 1; } std::vector buffer((std::istreambuf_iterator(in)), std::istreambuf_iterator()); torch::IValue ivalue = torch::pickle_load(buffer); input_tensors2.push_back(ivalue.toTensor().to(device)); } // Create array of input handles auto input_handles1 = torch::aot_inductor::unsafe_alloc_new_handles_from_tensors(input_tensors1); auto input_handles2 = torch::aot_inductor::unsafe_alloc_new_handles_from_tensors(input_tensors2); // Create array for output handles AtenTensorHandle output_handle1; AtenTensorHandle output_handle2; // Create and load models auto constants_map1 = std::make_shared(); auto constants_array1 = std::make_shared>(); auto model1 = AOTInductorModelPlus__default::Create( constants_map1, constants_array1, device_str, "package/data/aotinductor/Plus__default/"); model1->load_constants(); auto constants_map2 = std::make_shared(); auto constants_array2 = std::make_shared>(); auto model2 = AOTInductorModelMinus__default::Create( constants_map2, constants_array2, device_str, "package/data/aotinductor/Minus__default/"); model2->load_constants(); // Run the models torch::aot_inductor::DeviceStreamType stream1 = nullptr; model1->run(&input_handles1[0], &output_handle1, stream1, nullptr); torch::aot_inductor::DeviceStreamType stream2 = nullptr; model2->run(&input_handles2[0], &output_handle2, stream2, nullptr); // Convert output handles to tensors auto output_tensor1 = torch::aot_inductor::alloc_tensors_by_stealing_from_handles(&output_handle1, 1); auto output_tensor2 = torch::aot_inductor::alloc_tensors_by_stealing_from_handles(&output_handle2, 1); // Validate outputs std::cout << "output_tensor1" << output_tensor1 << std::endl; std::cout << "output_tensor2" << output_tensor2 << std::endl; return 0; } catch (const std::exception &e) { std::cerr << "Error: " << e.what() << std::endl; return 1; } } ``` Rollback Plan: Differential Revision: D78124705 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158139 Approved by: https://github.com/desertfire --- test/inductor/test_aot_inductor_package.py | 124 ++++++------- test/inductor/test_static_linkage_utils.py | 157 ---------------- torch/export/experimental/__init__.py | 83 ++++++++- torch/export/experimental/_utils.py | 206 +++++++++++++++++++++ 4 files changed, 339 insertions(+), 231 deletions(-) delete mode 100644 test/inductor/test_static_linkage_utils.py create mode 100644 torch/export/experimental/_utils.py diff --git a/test/inductor/test_aot_inductor_package.py b/test/inductor/test_aot_inductor_package.py index a607c4f33e7d3..94fe620a9f18b 100644 --- a/test/inductor/test_aot_inductor_package.py +++ b/test/inductor/test_aot_inductor_package.py @@ -20,6 +20,7 @@ from torch._inductor.test_case import TestCase from torch._inductor.utils import fresh_cache from torch.export import Dim +from torch.export.experimental import _ExportPackage from torch.export.pt2_archive._package import load_pt2, load_weights_to_pt2_contents from torch.testing._internal.common_cuda import _get_torch_cuda_version from torch.testing._internal.common_utils import ( @@ -31,20 +32,6 @@ from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU -try: - from test_static_linkage_utils import ( - get_static_linkage_main_cpp_file, - get_static_linkage_makelist_file_cpu, - get_static_linkage_makelist_file_cuda, - ) -except ImportError: - from .test_static_linkage_utils import ( - get_static_linkage_main_cpp_file, - get_static_linkage_makelist_file_cpu, - get_static_linkage_makelist_file_cuda, - ) - - def skipif(predicate: Callable[[str, bool], bool], reason: str): def decorator(func): @functools.wraps(func) @@ -153,6 +140,28 @@ def check_package_cpp_only(self: TestCase) -> None: if shutil.which("make") is None: raise unittest.SkipTest("make is not available") + def cmake_compile_and_run(self, base_dir): + custom_env = os.environ.copy() + custom_env["CMAKE_PREFIX_PATH"] = str(Path(torch.__file__).parent) + build_path = Path(base_dir) / "build" + build_path.mkdir() + subprocess.run( + ["cmake", ".."], + cwd=build_path, + env=custom_env, + check=True, + ) + subprocess.run(["make"], cwd=build_path, check=True) + result = subprocess.run( + ["./build/main"], + cwd=base_dir, + check=True, + capture_output=True, + text=True, + ) + + return result + def cmake_compile(self, model, example_inputs, options, tmp_dir): """ Exports model, compiles it using AOTInductor, extracts the @@ -412,7 +421,7 @@ def forward(self, x, y): @unittest.skipIf(IS_FBCODE, "cmake won't work in fbcode") @skipIfRocm # doesn't support multi-arch binary @skipIfXpu # doesn't support multi-arch binary - def test_run_static_linkage_model(self): + def test_compile_with_exporter(self): self.check_package_cpp_only() class Model1(torch.nn.Module): @@ -423,64 +432,45 @@ class Model2(torch.nn.Module): def forward(self, x, y): return x - y + def default(*args, **kwargs): + return None + example_inputs = ( - torch.randn(10, 10, device=self.device), - torch.randn(10, 10, device=self.device), + torch.ones(3, 3).to(self.device), + torch.ones(3, 3).to(self.device), ) - model1 = Model1().to(self.device) - model2 = Model2().to(self.device) + package = _ExportPackage() + m1 = Model1() + m2 = Model2() + exporter1 = package._exporter("Plus", m1)._define_overload("default", default) + exporter2 = package._exporter("Minus", m2)._define_overload("default", default) + exporter1(*example_inputs) + exporter2(*example_inputs) - models = [model1, model2] - - i = 0 - model_names = ["Plus", "Minus"] - with ( - tempfile.TemporaryDirectory() as tmp_dir, - ): - for i in range(2): - model = models[i] - # TODO: should be done through _ExportPackage - ep = torch.export.export(model, example_inputs) - - package_path = torch._inductor.aoti_compile_and_package( - ep, - inductor_configs={ - "aot_inductor.compile_standalone": True, - "always_keep_tensor_constants": True, - "aot_inductor.model_name_for_generated_files": model_names[i], - }, + for package_example_inputs in [True, False]: + with ( + tempfile.TemporaryDirectory() as tmp_dir, + ): + package._compiled_and_package( + tmp_dir + "/package.pt2", True, package_example_inputs ) - with ( - zipfile.ZipFile(package_path, "r") as zip_ref, - ): - zip_ref.extractall(tmp_dir) - - file_str = get_static_linkage_main_cpp_file() - with open(Path(tmp_dir) / "main.cpp", "w") as f: - f.write(file_str) - - if self.device == GPU_TYPE: - cmake_file_str = get_static_linkage_makelist_file_cuda() - else: - cmake_file_str = get_static_linkage_makelist_file_cpu() - with open(Path(tmp_dir) / "CMakeLists.txt", "w") as f: - f.write(cmake_file_str) - - build_path = Path(tmp_dir) / "build" - build_path.mkdir() - custom_env = os.environ.copy() - custom_env["CMAKE_PREFIX_PATH"] = str(Path(torch.__file__).parent) - subprocess.run( - ["cmake", ".."], - cwd=build_path, - env=custom_env, - ) - subprocess.run(["make"], cwd=build_path, check=True) - subprocess.run( - ["./main", f"{tmp_dir}/", self.device], cwd=build_path, check=True - ) + # Test compiling generated files + result = self.cmake_compile_and_run(tmp_dir) + if package_example_inputs: + if self.device == GPU_TYPE: + self.assertEqual( + result.stdout, + "output_tensor1 2 2 2\n 2 2 2\n 2 2 2\n[ CUDAFloatType{3,3} ]\noutput_tensor2 0 0 0\n" + " 0 0 0\n 0 0 0\n[ CUDAFloatType{3,3} ]\n", + ) + else: + self.assertEqual( + result.stdout, + "output_tensor1 2 2 2\n 2 2 2\n 2 2 2\n[ CPUFloatType{3,3} ]\noutput_tensor2 0 0 0\n" + " 0 0 0\n 0 0 0\n[ CPUFloatType{3,3} ]\n", + ) def test_metadata(self): class Model(torch.nn.Module): diff --git a/test/inductor/test_static_linkage_utils.py b/test/inductor/test_static_linkage_utils.py deleted file mode 100644 index 0a728c1e66df7..0000000000000 --- a/test/inductor/test_static_linkage_utils.py +++ /dev/null @@ -1,157 +0,0 @@ -# Owner(s): ["module: inductor"] -from torch.testing._internal.common_utils import run_tests - - -def get_static_linkage_main_cpp_file(): - return """ -#include -#include -#include -#include -#include - -#include -#include -// Include the AOTInductor headers -#include "Minus.wrapper/data/aotinductor/model/Minus.h" -#include "Plus.wrapper/data/aotinductor/model/Plus.h" -#include -#include - -using torch::aot_inductor::AOTInductorModelMinus; -using torch::aot_inductor::AOTInductorModelPlus; -using torch::aot_inductor::ConstantHandle; -using torch::aot_inductor::ConstantMap; - - -int main(int argc, char* argv[]) { - if (argc < 2) { - std::cerr - << "Usage: ./main " - << std::endl; - return 1; - } - std::string path = argv[1]; - std::string device_str = argv[2]; - try { - torch::Device device(device_str); - - // Create two input tensors (10x10) - auto tensor1 = torch::ones({10, 10}, device); - auto tensor2 = torch::ones({10, 10}, device); - // Create two input tensors (10x10) - auto tensor3 = torch::ones({10, 10}, device); - auto tensor4 = torch::ones({10, 10}, device); - - std::vector input_tensors = {tensor1, tensor2}; - std::vector input_tensors2 = {tensor3, tensor4}; - - // Create array of input handles - auto input_handles1 = - torch::aot_inductor::unsafe_alloc_new_handles_from_tensors( - input_tensors); - auto input_handles2 = - torch::aot_inductor::unsafe_alloc_new_handles_from_tensors( - input_tensors2); - - // Create array for output handle - AtenTensorHandle output_handle1; - AtenTensorHandle output_handle2; - - auto constants_map = std::make_shared(); - auto constants_array = std::make_shared>(); - auto model1 = AOTInductorModelPlus::Create( - constants_map, constants_array, device_str, - path + "Plus.wrapper/data/" - "aotinductor/model/"); - model1->load_constants(); - - auto constants_map2 = std::make_shared(); - auto constants_array2 = std::make_shared>(); - auto model2 = AOTInductorModelMinus::Create( - constants_map2, constants_array2, device_str, - path + "Minus.wrapper/data/" - "aotinductor/model/"); - model2->load_constants(); - - // Run the model - torch::aot_inductor::DeviceStreamType stream1 = nullptr; - torch::aot_inductor::DeviceStreamType stream2 = nullptr; - model1->run(&input_handles1[0], &output_handle1, stream1, nullptr); - model2->run(&input_handles2[0], &output_handle2, stream2, nullptr); - - // Convert output handle to tensor - auto output_tensor1 = - torch::aot_inductor::alloc_tensors_by_stealing_from_handles( - &output_handle1, 1); - auto output_tensor2 = - torch::aot_inductor::alloc_tensors_by_stealing_from_handles( - &output_handle2, 1); - - if (!(torch::all(output_tensor1[0] == 2).item())){ - std::cout << "Wrong Output for Plus Model: " << output_tensor1 << std::endl; - throw std::runtime_error("Tensor does not contain only the expected value 2."); - } - if (!(torch::all(output_tensor2[0] == 0).item())){ - std::cout << "Wrong Output for Minus Model: " << output_tensor1 << std::endl; - throw std::runtime_error("Tensor does not contain only the expected value 0."); - } - - return 0; - } catch (const std::exception &e) { - std::cerr << "Error: " << e.what() << std::endl; - return 1; - } -} - -""" - - -def get_static_linkage_makelist_file_cuda(): - return """ -cmake_minimum_required(VERSION 3.10) -project(TestProject) - -set(CMAKE_CXX_STANDARD 17) - -find_package(Torch REQUIRED) -find_package(CUDA REQUIRED) - -add_subdirectory(Plus.wrapper/data/aotinductor/model/) -add_subdirectory(Minus.wrapper/data/aotinductor/model/) - -# Create executable -add_executable(main main.cpp) - -target_compile_definitions(main PRIVATE USE_CUDA) - -target_link_libraries(main PRIVATE torch cuda - ${CUDA_LIBRARIES} - Plus - Minus) -""" - - -def get_static_linkage_makelist_file_cpu(): - return """ -cmake_minimum_required(VERSION 3.10) -project(TestProject) - -set(CMAKE_CXX_STANDARD 17) - -find_package(Torch REQUIRED) - -add_subdirectory(Plus.wrapper/data/aotinductor/model/) -add_subdirectory(Minus.wrapper/data/aotinductor/model/) - -# Create executable -add_executable(main main.cpp) - -target_link_libraries(main PRIVATE torch - Plus - Minus) -""" - - -if __name__ == "__main__": - run_tests() diff --git a/torch/export/experimental/__init__.py b/torch/export/experimental/__init__.py index 7f2303f663942..b1c86abc69ce7 100644 --- a/torch/export/experimental/__init__.py +++ b/torch/export/experimental/__init__.py @@ -1,14 +1,22 @@ import copy import dataclasses import functools +import os +import tempfile import types import typing import typing_extensions +import zipfile +from pathlib import Path import torch +from torch.export.experimental._utils import _get_main_cpp_file, _get_make_file from torch.export.exported_program import _decompose_exported_program +__all__ = [] # type: ignore[var-annotated] + + def _copy_graph_module_and_signature( ep: torch.fx.GraphModule, ) -> tuple[torch.fx.GraphModule, torch.export.graph_signature.ExportGraphSignature]: @@ -333,18 +341,79 @@ def _method_overloads( for overload, ep in method_data.overloads.items(): yield f"{method}:{overload}", ep - def _compiled_and_package(self, f: torch.types.FileLike) -> None: - options = { + def _compiled_and_package( + self, + f: torch.types.FileLike, + standalone: bool = False, + package_example_inputs: bool = False, + ) -> None: + options: dict[str, typing.Any] = { "aot_inductor.package": True, "aot_inductor.package_cpp_only": True, "always_keep_tensor_constants": True, "aot_inductor.package_constants_in_so": False, + "aot_inductor.compile_standalone": standalone, } - weights_map = {} + aoti_files_map = {} + model_names = [] for name, ep in self._method_overloads: - weights = torch._inductor.aot_compile(ep.module(), (), options=options) # type: ignore[arg-type] - weights_map[name] = weights - torch._inductor.package.package.package_aoti( + name = name.replace(":", "__") + model_names.append(name) + options["aot_inductor.model_name_for_generated_files"] = name + aoti_files = torch._inductor.aot_compile( + ep.module(), # type: ignore[arg-type] + ep.example_inputs[0], + kwargs=ep.example_inputs[1], + options=options, + ) + aoti_files_map[name] = aoti_files + + from torch._inductor.package import package + + pt2_path = package.package_aoti( f, - weights_map, # type: ignore[arg-type] + aoti_files_map, # type: ignore[arg-type] + ) + + if not standalone: + return + + assert isinstance(pt2_path, str) + base_directory = os.path.dirname(pt2_path) + package_name = os.path.basename(pt2_path)[:-4] + with ( + zipfile.ZipFile(pt2_path, "r") as zip_ref, + ): + zip_ref.extractall(base_directory) + + example_inputs_map: typing.Optional[dict[str, int]] = ( + {} if package_example_inputs else None + ) + use_cuda = False + for name, ep in self._method_overloads: + name = name.replace(":", "__") + # TODO: also dump kwargs + # TODO: currently only support list of Tensors and they need to be on the same device + if not ep.example_inputs: + continue + for inp in ep.example_inputs[0]: + if isinstance(inp, torch.Tensor) and inp.device.type == "cuda": + # TODO: more carefully determine the device type + use_cuda = True + if package_example_inputs: + assert example_inputs_map is not None + example_inputs_map[name] = len(ep.example_inputs[0]) + for i, t in enumerate(ep.example_inputs[0]): + path = Path(base_directory) / f"{name}_input_{i}.pt" + torch.save(t, path) + + cmake_file_str = _get_make_file(package_name, model_names, use_cuda) + + with open(Path(base_directory) / "CMakeLists.txt", "w") as file: + file.write(cmake_file_str) + + main_file_str = _get_main_cpp_file( + package_name, model_names, use_cuda, example_inputs_map ) + with open(Path(base_directory) / "main.cpp", "w") as file: + file.write(main_file_str) diff --git a/torch/export/experimental/_utils.py b/torch/export/experimental/_utils.py new file mode 100644 index 0000000000000..b91dfbb0db802 --- /dev/null +++ b/torch/export/experimental/_utils.py @@ -0,0 +1,206 @@ +import typing + +from torch._inductor.utils import IndentedBuffer + + +__all__ = [] # type: ignore[var-annotated] + + +def _get_main_cpp_file( + package_name: str, + model_names: list[str], + cuda: bool, + example_inputs_map: typing.Optional[dict[str, int]], +) -> str: + """ + Generates a main.cpp file for AOTInductor standalone models in the specified package. + + Args: + package_name (str): Name of the package containing the models. + model_names (List[str]): List of model names to include in the generated main.cpp. + cuda (bool): Whether to generate code with CUDA support. + example_inputs_map (Optional[Dict[str, List[Tensor]]]): A mapping from model name to + its list of example input tensors. If provided, the generated main.cpp will + load and run these inputs. + + Returns: + str: The contents of the generated main.cpp file as a string. + """ + + ib = IndentedBuffer() + + ib.writelines( + [ + "#include ", + "#include ", + "#include ", + "#include ", + "#include ", + "#include ", + "#include ", + ] + ) + if cuda: + ib.writelines( + [ + "#include ", + "#include ", + ] + ) + + for model_name in model_names: + ib.writeline( + f'#include "{package_name}/data/aotinductor/{model_name}/{model_name}.h"' + ) + + ib.newline() + for model_name in model_names: + ib.writeline(f"using torch::aot_inductor::AOTInductorModel{model_name};") + + ib.writelines( + [ + "using torch::aot_inductor::ConstantHandle;", + "using torch::aot_inductor::ConstantMap;", + "", + "int main(int argc, char* argv[]) {", + ] + ) + + with ib.indent(): + ib.writeline(f'std::string device_str = "{"cuda" if cuda else "cpu"}";') + ib.writeline("try {") + + with ib.indent(): + ib.writeline("c10::Device device(device_str);") + + if example_inputs_map is not None: + # TODO: add device + for i, model_name in enumerate(model_names): + num_inputs = example_inputs_map[model_name] + + ib.writeline(f"// Load input tensors for model {model_name}") + ib.writeline(f"std::vector input_tensors{i + 1};") + ib.writeline(f"for (int j = 0; j < {num_inputs}; ++j) {{") + with ib.indent(): + ib.writeline( + f'std::string filename = "{model_name}_input_" + std::to_string(j) + ".pt";' + ) + ib.writeline("std::ifstream in(filename, std::ios::binary);") + ib.writeline("if (!in.is_open()) {") + with ib.indent(): + ib.writeline( + 'std::cerr << "Failed to open file: " << filename << std::endl;' + ) + ib.writeline("return 1;") + ib.writeline("}") + ib.writeline( + "std::vector buffer((std::istreambuf_iterator(in)), std::istreambuf_iterator());" + ) + ib.writeline( + "torch::IValue ivalue = torch::pickle_load(buffer);" + ) + ib.writeline( + f"input_tensors{i + 1}.push_back(ivalue.toTensor().to(device));" + ) + ib.writeline("}") + ib.newline() + + ib.newline() + ib.writeline("\n// Create array of input handles") + for i in range(len(model_names)): + ib.writelines( + [ + f"auto input_handles{i + 1} =", + f" torch::aot_inductor::unsafe_alloc_new_handles_from_tensors(input_tensors{i + 1});", + ] + ) + + ib.writeline("\n// Create array for output handles") + for i in range(len(model_names)): + ib.writeline(f"AtenTensorHandle output_handle{i + 1};") + + ib.writeline("\n// Create and load models") + for i, model_name in enumerate(model_names): + ib.writelines( + [ + f"auto constants_map{i + 1} = std::make_shared();", + f"auto constants_array{i + 1} = std::make_shared>();", + f"auto model{i + 1} = AOTInductorModel{model_name}::Create(", + f" constants_map{i + 1}, constants_array{i + 1}, device_str,", + f' "{package_name}/data/aotinductor/{model_name}/");', + f"model{i + 1}->load_constants();", + ] + ) + + if example_inputs_map is not None: + ib.writeline("\n// Run the models") + for i in range(len(model_names)): + ib.writeline( + f"torch::aot_inductor::DeviceStreamType stream{i + 1} = nullptr;" + ) + ib.writeline( + f"model{i + 1}->run(&input_handles{i + 1}[0], &output_handle{i + 1}, stream{i + 1}, nullptr);" + ) + + ib.writeline("\n// Convert output handles to tensors") + for i in range(len(model_names)): + ib.writelines( + [ + f"auto output_tensor{i + 1} =", + f" torch::aot_inductor::alloc_tensors_by_stealing_from_handles(&output_handle{i + 1}, 1);", + ] + ) + + ib.writeline("\n// Validate outputs") + for i in range(len(model_names)): + ib.writeline( + f"""std::cout << "output_tensor{i + 1}" << output_tensor{i + 1} << std::endl;""" + ) + + ib.writeline("return 0;") + + ib.writelines( + [ + "} catch (const std::exception &e) {", + ] + ) + with ib.indent(): + ib.writeline('std::cerr << "Error: " << e.what() << std::endl;') + ib.writeline("return 1;") + + ib.writeline("}") + ib.writeline("}") + + return ib.getvalue() + + +def _get_make_file(package_name: str, model_names: list[str], cuda: bool) -> str: + ib = IndentedBuffer() + + ib.writelines( + [ + "cmake_minimum_required(VERSION 3.10)", + "project(TestProject)", + "", + "set(CMAKE_CXX_STANDARD 17)", + "", + "find_package(Torch REQUIRED)", + ] + ) + if cuda: + ib.writeline("find_package(CUDA REQUIRED)") + + ib.newline() + for model_name in model_names: + ib.writeline(f"add_subdirectory({package_name}/data/aotinductor/{model_name}/)") + + ib.writeline("\nadd_executable(main main.cpp)") + if cuda: + ib.writeline("target_compile_definitions(main PRIVATE USE_CUDA)") + + model_libs = " ".join(model_names) + ib.writeline(f"target_link_libraries(main PRIVATE torch {model_libs})") + if cuda: + ib.writeline("target_link_libraries(main PRIVATE cuda ${CUDA_LIBRARIES})") + + return ib.getvalue() From 3beb915004f4e26b1e7c5e7692e6e8ca9b75de46 Mon Sep 17 00:00:00 2001 From: albanD Date: Tue, 15 Jul 2025 19:06:14 +0000 Subject: [PATCH 062/457] Update CODEOWNERS for dataloading (#158348) Adding Scott Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/158348 Approved by: https://github.com/scotts, https://github.com/janeyx99 --- CODEOWNERS | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CODEOWNERS b/CODEOWNERS index 2982b405c3df4..9e01c96c4e9cf 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -136,7 +136,7 @@ torch/profiler/ @sraikund16 test/functorch/test_aotdispatch.py @ezyang @Chillee # Dataloader -torch/utils/data/ @divyanshk @ramanishsingh +torch/utils/data/ @divyanshk @ramanishsingh @scotts # hipify torch/utils/hipify/ @jeffdaily @jithunnair-amd From 148789ddd84f48c189500581f20309c4709e506e Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Tue, 15 Jul 2025 06:18:20 -0700 Subject: [PATCH 063/457] Avoid AOTAutogradCache.load in stack trace on cache miss path (#158149) The general context for the upcoming stack of commits is I am attempting to "pipeline" AOTAutograd. Instead of having function f call function g which is the next "stage" of compilation, instead f should return with its outputs, which are then piped to g for the next stage. This will make it easier to implement early exit / resume pipeline without forcing callback structure, which is good for export-style use cases. It also reduces the size of our stack traces, which makes tools like Perfetto happy. Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/158149 Approved by: https://github.com/jamesjwu --- .../_aot_autograd/autograd_cache.py | 7 ++-- torch/_functorch/aot_autograd.py | 39 ++++++++++--------- 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/torch/_functorch/_aot_autograd/autograd_cache.py b/torch/_functorch/_aot_autograd/autograd_cache.py index 7c06f22905b2e..e66ffefe0a00c 100644 --- a/torch/_functorch/_aot_autograd/autograd_cache.py +++ b/torch/_functorch/_aot_autograd/autograd_cache.py @@ -1080,8 +1080,7 @@ def clear(): pass @staticmethod - def load( - dispatch_and_compile: Callable, + def try_load( mod: Union[torch.fx.GraphModule, torch._dynamo.utils.GmWrapper], args, aot_config: AOTConfig, @@ -1089,7 +1088,7 @@ def load( boxed_forward_device_index: Optional[BoxedDeviceIndex], local: bool, remote: bool, - ) -> Callable: + ) -> Optional[Callable]: """ Load a result from the cache, and reconstruct a runtime wrapper around the object """ @@ -1198,7 +1197,6 @@ def load( time.time_ns(), forward_symints=symints, ) - compiled_fn = dispatch_and_compile() cache_info.update( { @@ -1232,6 +1230,7 @@ def load( }, payload_fn=lambda: json.dumps(cache_info), ) + return compiled_fn @classmethod diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py index 56367c0c4676a..9c7f0dc185ec9 100644 --- a/torch/_functorch/aot_autograd.py +++ b/torch/_functorch/aot_autograd.py @@ -1190,26 +1190,27 @@ def dispatch_and_compile(): ) return compiled_fn - # We only care if the forward will return an OutputCode. - if isinstance(fw_compiler, SerializableAOTDispatchCompiler): - local = should_use_local_autograd_cache() - remote = should_use_remote_autograd_cache() - if local or remote: - set_feature_use("aot_autograd_remote_cache", remote) - compiled_fn = AOTAutogradCache.load( - dispatch_and_compile, - mod, - fake_flat_args, - aot_config, - cudagraphs, - boxed_forward_device_index, - local, - remote, - ) - else: - compiled_fn = dispatch_and_compile() - else: + while True: + # We only care if the forward will return an OutputCode. + if isinstance(fw_compiler, SerializableAOTDispatchCompiler): + local = should_use_local_autograd_cache() + remote = should_use_remote_autograd_cache() + if local or remote: + set_feature_use("aot_autograd_remote_cache", remote) + compiled_fn = AOTAutogradCache.try_load( + mod, + fake_flat_args, + aot_config, + cudagraphs, + boxed_forward_device_index, + local, + remote, + ) + if compiled_fn is not None: + break + compiled_fn = dispatch_and_compile() + break if isinstance(mod, torch._dynamo.utils.GmWrapper): # This function is called by the flatten_graph_inputs wrapper, which boxes From 7afb834f939eccbb3262e646f0922eed070074a7 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Tue, 15 Jul 2025 06:18:20 -0700 Subject: [PATCH 064/457] Inline dispatch_and_compile into its call site. (#158150) Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/158150 Approved by: https://github.com/jamesjwu, https://github.com/wconstab ghstack dependencies: #158149 --- test/dynamo/test_structured_trace.py | 21 +++++------ torch/_functorch/aot_autograd.py | 52 ++++++++++++++-------------- 2 files changed, 37 insertions(+), 36 deletions(-) diff --git a/test/dynamo/test_structured_trace.py b/test/dynamo/test_structured_trace.py index 72fe71ace2da0..69f0203adf06f 100644 --- a/test/dynamo/test_structured_trace.py +++ b/test/dynamo/test_structured_trace.py @@ -247,6 +247,7 @@ def test_schedule(self): {"dynamo_output_graph": {"sizes": {"l_a_": [1000, 1000], "ones": [1000, 1000], "output": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "before_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "after_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -256,7 +257,6 @@ def test_schedule(self): {"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "triton_kernel_info", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"compilation_metrics_runtime": "METRICS", "frame_id": 0, "frame_compile_id": 0} @@ -279,6 +279,7 @@ def test_cudagraphs(self): {"dynamo_output_graph": {"sizes": {"l_a_": [1000, 1000], "ones": [1000, 1000], "output": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "before_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "after_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -288,7 +289,6 @@ def test_cudagraphs(self): {"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "triton_kernel_info", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"compilation_metrics_runtime": "METRICS", "frame_id": 0, "frame_compile_id": 0} @@ -319,6 +319,7 @@ def fn(x, y): {"dynamo_output_graph": {"sizes": {"l_y_": [1000, 1000], "l_x_": [1000, 1000], "add": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "before_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "after_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -327,7 +328,6 @@ def fn(x, y): {"artifact": {"name": "after_post_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"artifact": {"name": "recompile_reasons", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} @@ -339,6 +339,7 @@ def fn(x, y): {"dynamo_output_graph": {"sizes": {"l_x_": [1000, 1000], "add": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "before_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "after_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} @@ -347,7 +348,6 @@ def fn(x, y): {"artifact": {"name": "after_post_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} -{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 1, "attempt": 0} """, # noqa: B950 @@ -369,6 +369,7 @@ def test_example_fn(self): {"dynamo_output_graph": {"sizes": {"l_a_": [1000, 1000], "ones": [1000, 1000], "output": [1000, 1000], "ones_1": [1000, 1000], "output_1": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "before_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "after_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -377,7 +378,6 @@ def test_example_fn(self): {"artifact": {"name": "after_post_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0} """, # noqa: B950 @@ -424,6 +424,7 @@ def test_example_training_fn(self): {"dynamo_output_graph": {"sizes": {"l_stack0_": [1000, 1000], "ones": [1000, 1000], "output": [1000, 1000], "sum_1": []}}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"artifact": {"name": "before_pre_grad_graph", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"artifact": {"name": "after_pre_grad_graph", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"aot_joint_graph": {}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} @@ -434,7 +435,6 @@ def test_example_training_fn(self): {"artifact": {"name": "after_post_grad_graph", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} -{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"dynamo_cpp_guards_str": {}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 2, "frame_compile_id": 0, "attempt": 1} {"dynamo_start": {"stack": "STACK"}, "frame_id": 3, "frame_compile_id": 0, "attempt": 0} @@ -506,6 +506,7 @@ def throw(x): {"dynamo_output_graph": {"sizes": {"l_a_": [1000, 1000], "output": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "before_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "after_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_joint_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -656,6 +657,7 @@ def forward(self, x): {"describe_source": {"describer_id": "ID", "id": 2, "source": "L['self']._modules['layers']._modules['0']._parameters['bias']"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} {"artifact": {"name": "before_pre_grad_graph", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "after_pre_grad_graph", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "aotautograd_cache_bypass", "encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_joint_graph": {}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -666,7 +668,6 @@ def forward(self, x): {"artifact": {"name": "after_post_grad_graph", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"artifact": {"name": "aotautograd_cache_bypass", "encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"describe_storage": {"id": 16, "describer_id": "ID", "size": 4194304}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} {"describe_tensor": {"id": 29, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024, 1024], "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1024, 1], "storage": 16, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} {"describe_source": {"describer_id": "ID", "id": 29, "source": "L['self']._modules['layers']._modules['1']._parameters['weight']"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} @@ -675,6 +676,7 @@ def forward(self, x): {"describe_source": {"describer_id": "ID", "id": 30, "source": "L['self']._modules['layers']._modules['1']._parameters['bias']"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} {"artifact": {"name": "before_pre_grad_graph", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "after_pre_grad_graph", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "aotautograd_cache_bypass", "encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_joint_graph": {}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -685,7 +687,6 @@ def forward(self, x): {"artifact": {"name": "after_post_grad_graph", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"artifact": {"name": "aotautograd_cache_bypass", "encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"dynamo_cpp_guards_str": {}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} """, # noqa: B950 @@ -716,6 +717,7 @@ def fn(x): {"dynamo_output_graph": {"sizes": {"l_x_": [1], "add": [1]}}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "before_pre_grad_graph", "encoding": "string"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "after_pre_grad_graph", "encoding": "string"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_inference_graph": {}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -724,7 +726,6 @@ def fn(x): {"artifact": {"name": "after_post_grad_graph", "encoding": "string"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"dynamo_cpp_guards_str": {}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 1, "frame_compile_id": 0, "attempt": 0} """, # noqa: B950 @@ -875,6 +876,7 @@ def fn(a): {"dynamo_output_graph": {"sizes": {"l_a_": [1], "sin": [1]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "before_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "after_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -883,7 +885,6 @@ def fn(a): {"artifact": {"name": "after_post_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py index 9c7f0dc185ec9..3803333948dfe 100644 --- a/torch/_functorch/aot_autograd.py +++ b/torch/_functorch/aot_autograd.py @@ -1,5 +1,6 @@ # mypy: ignore-errors +import contextlib import itertools from collections.abc import KeysView, Sequence from contextlib import contextmanager, nullcontext @@ -1178,9 +1179,30 @@ def aot_module_simplified( full_args, aot_config, fake_mode, shape_env, ignore_shape_env ) - def dispatch_and_compile(): - functional_call = create_functional_call(mod, params_spec, params_len) - with compiled_autograd._disable(): + with contextlib.ExitStack() as stack: + while True: + # We only care if the forward will return an OutputCode. + if isinstance(fw_compiler, SerializableAOTDispatchCompiler): + local = should_use_local_autograd_cache() + remote = should_use_remote_autograd_cache() + if local or remote: + set_feature_use("aot_autograd_remote_cache", remote) + compiled_fn = AOTAutogradCache.try_load( + mod, + fake_flat_args, + aot_config, + cudagraphs, + boxed_forward_device_index, + local, + remote, + ) + if compiled_fn is not None: + break + + functional_call = create_functional_call(mod, params_spec, params_len) + + stack.enter_context(compiled_autograd._disable()) + compiled_fn, _ = create_aot_dispatcher_function( functional_call, fake_flat_args, @@ -1188,29 +1210,7 @@ def dispatch_and_compile(): fake_mode, shape_env, ) - return compiled_fn - - while True: - # We only care if the forward will return an OutputCode. - if isinstance(fw_compiler, SerializableAOTDispatchCompiler): - local = should_use_local_autograd_cache() - remote = should_use_remote_autograd_cache() - if local or remote: - set_feature_use("aot_autograd_remote_cache", remote) - compiled_fn = AOTAutogradCache.try_load( - mod, - fake_flat_args, - aot_config, - cudagraphs, - boxed_forward_device_index, - local, - remote, - ) - if compiled_fn is not None: - break - - compiled_fn = dispatch_and_compile() - break + break if isinstance(mod, torch._dynamo.utils.GmWrapper): # This function is called by the flatten_graph_inputs wrapper, which boxes From 5606c516fd87e5c3594177e4ca64c3cac7fdafd5 Mon Sep 17 00:00:00 2001 From: Ti-Tai Wang Date: Tue, 15 Jul 2025 15:37:12 +0000 Subject: [PATCH 065/457] [ONNX] Remove legacy Dort (#158258) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158258 Approved by: https://github.com/justinchuby, https://github.com/malfet --- docs/source/onnx.md | 1 - .../source/onnx_dynamo_onnxruntime_backend.md | 11 - docs/source/torch.compiler.md | 2 - test/dynamo/test_backends.py | 5 - torch/_dynamo/backends/onnxrt.py | 67 +- torch/onnx/__init__.py | 11 - torch/onnx/_internal/onnxruntime.py | 1257 ----------------- 7 files changed, 35 insertions(+), 1319 deletions(-) delete mode 100644 docs/source/onnx_dynamo_onnxruntime_backend.md delete mode 100644 torch/onnx/_internal/onnxruntime.py diff --git a/docs/source/onnx.md b/docs/source/onnx.md index ad436748022be..184dc8740c791 100644 --- a/docs/source/onnx.md +++ b/docs/source/onnx.md @@ -87,7 +87,6 @@ also be interested in reading our [development wiki](https://github.com/pytorch/ onnx_dynamo onnx_ops onnx_verification - onnx_dynamo_onnxruntime_backend onnx_torchscript ``` diff --git a/docs/source/onnx_dynamo_onnxruntime_backend.md b/docs/source/onnx_dynamo_onnxruntime_backend.md deleted file mode 100644 index a59cd4ab919cd..0000000000000 --- a/docs/source/onnx_dynamo_onnxruntime_backend.md +++ /dev/null @@ -1,11 +0,0 @@ -# ONNX Backend for TorchDynamo - -For a quick overview of `torch.compiler`, see {ref}`torch.compiler_overview`. - -```{warning} - The ONNX backend for torch.compile is a rapidly evolving beta technology. -``` - -```{eval-rst} -.. autofunction:: torch.onnx.is_onnxrt_backend_supported -``` \ No newline at end of file diff --git a/docs/source/torch.compiler.md b/docs/source/torch.compiler.md index 5f12670f5e1de..4175da896ccf2 100644 --- a/docs/source/torch.compiler.md +++ b/docs/source/torch.compiler.md @@ -56,8 +56,6 @@ Some of the most commonly used backends include: - CUDA graphs with AOT Autograd. `Read more `__ * - ``torch.compile(m, backend="ipex")`` - Uses IPEX on CPU. `Read more `__ - * - ``torch.compile(m, backend="onnxrt")`` - - Uses ONNX Runtime for training on CPU/GPU. :doc:`Read more ` ``` **Inference-only backends** diff --git a/test/dynamo/test_backends.py b/test/dynamo/test_backends.py index 7c4402edeca6b..9d61bbf31acb1 100644 --- a/test/dynamo/test_backends.py +++ b/test/dynamo/test_backends.py @@ -8,7 +8,6 @@ import torch._dynamo.backends import torch._dynamo.test_case from torch._dynamo.backends.debugging import ExplainWithBackend -from torch._dynamo.backends.onnxrt import has_onnxruntime from torch._dynamo.backends.tvm import has_tvm from torch._dynamo.testing import same from torch.fx._lazy_graph_module import _force_skip_lazy_graph_module @@ -138,10 +137,6 @@ def test_aot_ts(self, device): def test_aot_cudagraphs(self, device): self._check_backend_works("cudagraphs", device) - @unittest.skipIf(not has_onnxruntime(), "requires onnxruntime") - def test_onnxrt(self, device): - self._check_backend_works("onnxrt", device) - @unittest.skipIf(not has_tvm(), "requires tvm") def test_tvm(self, device): self._check_backend_works("tvm", device) diff --git a/torch/_dynamo/backends/onnxrt.py b/torch/_dynamo/backends/onnxrt.py index 6830c0409620b..71c5e1765810f 100644 --- a/torch/_dynamo/backends/onnxrt.py +++ b/torch/_dynamo/backends/onnxrt.py @@ -4,35 +4,38 @@ # to the right people, please tag related GitHub issues with `module: onnx`. # # Maintainers' Github IDs: wschin, xadupre -from torch.onnx._internal.onnxruntime import ( - is_onnxrt_backend_supported, - torch_compile_backend, -) - -from .registry import register_backend - - -def has_onnxruntime(): - # FIXME: update test/dynamo/test_backends.py to call is_onnxrt_backend_supported() - return is_onnxrt_backend_supported() - - -if is_onnxrt_backend_supported(): - register_backend(name="onnxrt", compiler_fn=torch_compile_backend) -else: - - def information_displaying_backend(*args, **kwargs): - raise ImportError( - "onnxrt is not registered as a backend. " - "Please make sure all dependencies such as " - "numpy, onnx, onnxscript, and onnxruntime-training are installed. " - "Suggested procedure to fix dependency problem:\n" - " (1) pip or conda install numpy onnx onnxscript onnxruntime-training.\n" - " (2) Open a new python terminal.\n" - " (3) Call the API `torch.onnx.is_onnxrt_backend_supported()`:\n" - " (4) If it returns `True`, then you can use `onnxrt` backend.\n" - " (5) If it returns `False`, please execute the package importing section in " - "torch/onnx/_internal/onnxruntime.py under pdb line-by-line to see which import fails." - ) - - register_backend(name="onnxrt", compiler_fn=information_displaying_backend) +# from torch.onnx._internal.onnxruntime import ( +# is_onnxrt_backend_supported, +# torch_compile_backend, +# ) + +# from .registry import register_backend + +""" +Placeholder for onnxruntime backend for dynamo +""" + +# def has_onnxruntime(): +# # FIXME: update test/dynamo/test_backends.py to call is_onnxrt_backend_supported() +# return is_onnxrt_backend_supported() + + +# if is_onnxrt_backend_supported(): +# register_backend(name="onnxrt", compiler_fn=torch_compile_backend) +# else: + +# def information_displaying_backend(*args, **kwargs): +# raise ImportError( +# "onnxrt is not registered as a backend. " +# "Please make sure all dependencies such as " +# "numpy, onnx, onnxscript, and onnxruntime-training are installed. " +# "Suggested procedure to fix dependency problem:\n" +# " (1) pip or conda install numpy onnx onnxscript onnxruntime-training.\n" +# " (2) Open a new python terminal.\n" +# " (3) Call the API `torch.onnx.is_onnxrt_backend_supported()`:\n" +# " (4) If it returns `True`, then you can use `onnxrt` backend.\n" +# " (5) If it returns `False`, please execute the package importing section in " +# "torch/onnx/_internal/onnxruntime.py under pdb line-by-line to see which import fails." +# ) + +# register_backend(name="onnxrt", compiler_fn=information_displaying_backend) diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py index 7db778ef08e60..6c301ef294eb1 100644 --- a/torch/onnx/__init__.py +++ b/torch/onnx/__init__.py @@ -38,8 +38,6 @@ "OnnxExporterError", "ONNXProgram", "enable_fake_mode", - # DORT / torch.compile - "is_onnxrt_backend_supported", ] from typing import Any, Callable, TYPE_CHECKING @@ -51,12 +49,6 @@ from ._internal._exporter_legacy import enable_fake_mode from ._internal.exporter._onnx_program import ONNXProgram -from ._internal.onnxruntime import ( - is_onnxrt_backend_supported, - OrtBackend as _OrtBackend, - OrtBackendOptions as _OrtBackendOptions, - OrtExecutionProvider as _OrtExecutionProvider, -) from ._type_utils import JitScalarType from .errors import OnnxExporterError from .utils import ( @@ -98,10 +90,7 @@ JitScalarType.__module__ = "torch.onnx" ONNXProgram.__module__ = "torch.onnx" OnnxExporterError.__module__ = "torch.onnx" -_OrtBackend.__module__ = "torch.onnx" -_OrtBackendOptions.__module__ = "torch.onnx" enable_fake_mode.__module__ = "torch.onnx" -is_onnxrt_backend_supported.__module__ = "torch.onnx" producer_name = "pytorch" producer_version = _C_onnx.PRODUCER_VERSION diff --git a/torch/onnx/_internal/onnxruntime.py b/torch/onnx/_internal/onnxruntime.py deleted file mode 100644 index f9550d031fdc3..0000000000000 --- a/torch/onnx/_internal/onnxruntime.py +++ /dev/null @@ -1,1257 +0,0 @@ -# mypy: allow-untyped-defs -import dataclasses -import importlib -import logging -import os -from collections.abc import Mapping, Sequence -from typing import Any, Callable, Final, Optional, TYPE_CHECKING, Union -from typing_extensions import TypeAlias - -import torch -import torch._C -import torch._ops -import torch._prims.executor -import torch.fx -import torch.onnx._internal._lazy_import -from torch._subclasses.fake_tensor import FakeTensor -from torch.fx._compatibility import compatibility -from torch.fx.passes.fake_tensor_prop import FakeTensorProp -from torch.fx.passes.operator_support import OperatorSupport -from torch.fx.passes.tools_common import CALLABLE_NODE_OPS -from torch.utils import _pytree - - -if TYPE_CHECKING: - import onnx - import onnxruntime - from onnxruntime.capi import _pybind_state as ORTC - - import torch.onnx - import torch.onnx._internal - import torch.onnx._internal._exporter_legacy - import torch.onnx._internal.fx.decomposition_table - import torch.onnx._internal.fx.passes # noqa: TCH004 - - -_SUPPORT_ONNXRT: Optional[bool] = None - -__all__ = [ - "is_onnxrt_backend_supported", - "torch_compile_backend", - "OrtExecutionProvider", - "OrtBackendOptions", - "OrtBackend", -] - - -def is_onnxrt_backend_supported() -> bool: - """Returns ``True`` if ONNX Runtime dependencies are installed and usable - to support TorchDynamo backend integration; ``False`` otherwise. - - Example:: - - # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX) - >>> import torch - >>> if torch.onnx.is_onnxrt_backend_supported(): - ... @torch.compile(backend="onnxrt") - ... def f(x): - ... return x * x - ... print(f(torch.randn(10))) - ... else: - ... print("pip install onnx onnxscript onnxruntime") - ... - """ - global _SUPPORT_ONNXRT - - if _SUPPORT_ONNXRT is None: - # `onnxruntime` might import a lot of other runtime packages, - # e.g. apex, deepspeed, transformers. - # So lazy-importing onnxruntime to avoid possible circular import. - try: - importlib.import_module("onnxruntime") - importlib.import_module("onnxruntime.capi._pybind_state") - - # This is not use directly in DORT but needed by underlying exporter, - # so we still need to check if it exists. - importlib.import_module("onnxscript") - - import torch.onnx # noqa: F401 - import torch.onnx._internal # noqa: F401 - import torch.onnx._internal._exporter_legacy # noqa: F401 - from torch.onnx._internal.fx import ( # noqa: F401 - decomposition_table, - fx_onnx_interpreter, - passes, - type_utils, - ) - - _SUPPORT_ONNXRT = True - except ImportError: - _SUPPORT_ONNXRT = False - - return _SUPPORT_ONNXRT - - -_dumped_onnx_model: dict[str, int] = {} - - -def _dump_onnx_model( - model_string: bytes, graph_module: Optional[torch.fx.GraphModule] = None -) -> str: - """Stores the onnx model into a file. - The name is "{ONNXRT_DUMP_PATH}{N}.onnx" - where *N* is the number of files already stored with - this prefix. - If graph_module is not None, the graph is stored as a string with - the same filename except the extension (.txt). - """ - prefix = os.environ.get("ONNXRT_DUMP_PATH", None) - if not prefix: - return "" - n = _dumped_onnx_model.get(prefix, -1) + 1 - filename = f"{prefix}{n}.onnx" - with open(filename, "wb") as f: - f.write(model_string) - _dumped_onnx_model[prefix] = n - if graph_module is not None: - filename_txt = f"{prefix}{n}.txt" - with open(filename_txt, "w", encoding="utf-8") as f: - f.write(str(graph_module.graph)) - return filename - - -def _infer_default_eps() -> Sequence[str]: - # TODO: select a good default based on the capabilities of the host - # e.g. DML on Windows, etc. - return ["CPUExecutionProvider"] - - -def _nvtx_range_push(name: str): - """If PyTorch is installed with CUDA support, this starts NVTX range. - - Check torch.cuda.nvtx.range_push's document for more details. - """ - if torch.cuda.is_available(): - torch.cuda.nvtx.range_push(name) - - -def _nvtx_range_pop(): - """If PyTorch is installed with CUDA support, this terminates NVTX range. - - Check torch.cuda.nvtx.range_pop's document for more details. - """ - if torch.cuda.is_available(): - torch.cuda.nvtx.range_pop() - - -def _get_ort_device_type(device_type: str): - from onnxruntime.capi import _pybind_state as ORTC - - if device_type == "cuda": - return ORTC.OrtDevice.cuda() - if device_type == "cpu": - return ORTC.OrtDevice.cpu() - # ort pytorch device is mapped to NPU OrtDevice type - if device_type == "maia": - return ORTC.OrtDevice.npu() - raise ValueError("Unsupported device type: " + device_type) - - -logger = logging.getLogger(__name__) -# Uncomment the following lines to print out development info. -# logging.basicConfig(level=logging.WARNING) -# logger.setLevel(logging.WARNING) - - -class OrtOperatorSupport(OperatorSupport): - """Operator support for ONNXRuntime backend. - - It has two-level of support decision. One is via support_dict and the other one - is via extra_support_dict. The logic of using support_dict is implemented in - OrtOperatorSupport and extra_support_dict is used by OperatorSupport.is_node_supported. - """ - - def __init__(self, support_dict: set[Any], extra_support_dict: dict[str, Any]): - # Use extra_support_dict[op_name] = None to indicate - # we support op_name with all input types. Otherwise, - # see support_dict (type: SupportDict) in operator_support.py - # for specifying supported types. - super().__init__(extra_support_dict) - self._onnx_support_dict = support_dict - - def is_node_supported( - self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node - ) -> bool: - # OperatorSupport.is_node_supported returns True for non-callable nodes. - # Since ORT can't execute them, we return False here to override the base - # behavior. - if node.op not in CALLABLE_NODE_OPS: - return False - # This is the and the only place to decide if aten op is supported. - if node.op == "call_function" and node.target in self._onnx_support_dict: - logger.info( - "support_dict supports node.target: %s (type: %s)", - node.target, - type(node.target), - ) - return True - # If node.target is not in support_dict, we still want to check if torch.jit.script - # can convert it to ONNX equivalence. Let's use base mechanism to do this. - # See extra_support_dict for supported ops. - if super().is_node_supported(submodules, node): - logger.info( - "extra_support_dict supports node.target: %s (type: %s)", - node.target, - type(node.target), - ) - return True - logger.warning( - "support_dict and extra_support_dict don't support node.target: %s (type: %s)", - node.target, - type(node.target), - ) - return False - - -def _move_placeholder_to_front(graph_module: torch.fx.GraphModule) -> None: - """ - In torch.fx.Graph, placeholder is a special assignment node. If it's not - executed in the beginning, it could overwrite values computed by upstream - nodes. - """ - - graph = graph_module.graph - placeholders = [] - first_not_placeholder = None - for node in graph.nodes: - if node.op == "placeholder": - placeholders.append(node) - if first_not_placeholder is None and node.op != "placeholder": - first_not_placeholder = node - if first_not_placeholder is None: - return - for placeholder in placeholders: - first_not_placeholder.prepend(placeholder) - - -def _infer_ep_from_device(*args) -> tuple[str, ...]: - """Return the first valid device (i.e., GPU or CPU) in argument list.""" - eps = [] - for arg in args: - if hasattr(arg, "device"): - device = arg.device - if device.type == "cuda": - eps.append("CUDAExecutionProvider") - elif device.type == "cpu": - eps.append("CPUExecutionProvider") - return tuple(eps) - - -def _extract_graph_module_inputs(graph_module: torch.fx.GraphModule) -> tuple[Any, ...]: - placeholders = [] - for node in graph_module.graph.nodes: - if node.op == "placeholder": - if hasattr(node, "meta") and "val" in node.meta: - assert isinstance(node.meta["val"], torch.Tensor) - placeholders.append(node) - return tuple(placeholders) - - -def _extract_graph_module_outputs(graph_module: torch.fx.GraphModule) -> Any: - """Collect "val" fields from outputs metadata in this torch.fx.GraphModule.""" - for node in graph_module.graph.nodes: - if node.op == "output": - # Output node is unique. Let's retrieve output values from - # this node's input list. And then just return. - return node.args[0] - raise ValueError("No output node found in this torch.fx.GraphModule.") - - -def _infer_ep_from_graph_module(graph_module: torch.fx.GraphModule) -> tuple[str, ...]: - """Return the all valid devices (i.e., GPU or CPU) among outputs of this torch.fx.GraphModule.""" - flattened_output_args, _ = _pytree.tree_flatten( - _extract_graph_module_outputs(graph_module) - ) - # Output arguments with example value (type: torch.Tensor) in the `graph_module`. - selected_output_args = [ - output_arg.meta["val"] - for output_arg in flattened_output_args - # output_arg must have tensor for its device information. - # Otherwise, skip it. - if (hasattr(output_arg, "meta") and "val" in output_arg.meta) - ] - return _infer_ep_from_device(*selected_output_args) - - -def _sort_eps(eps: tuple[str, ...]) -> tuple[str, ...]: - """Sort execution providers in eps based on pre-set priority.""" - - def get_execution_provider_priority(ep: str) -> int: - if ep == "CPUExecutionProvider": - # Lowest priority. - return 2 - if ep == "CUDAExecutionProvider": - # Higher priority than CPU but lower than - # other specialized EPs. - return 1 - # Highest priority. - return 0 - - unique_eps = set(eps) - return tuple(sorted(unique_eps, key=get_execution_provider_priority, reverse=True)) - - -def _get_onnx_devices( - values: tuple[ - Union[ - torch.Tensor, torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool - ], - ..., - ], -) -> tuple["ORTC.OrtDevice", ...]: - from onnxruntime.capi import _pybind_state as ORTC - - def _device_id_or_zero(device_id: int) -> int: - return device_id or 0 - - def _map_tensor_or_sym_to_device( - value: Union[ - torch.Tensor, torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool - ], - ) -> int: - if isinstance(value, torch.Tensor): - return ORTC.OrtDevice( - _get_ort_device_type(value.device.type), - ORTC.OrtDevice.default_memory(), - _device_id_or_zero(value.device.index), - ) - elif isinstance( - value, (torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool) - ): - return ORTC.OrtDevice( - _get_ort_device_type("cpu"), ORTC.OrtDevice.default_memory(), 0 - ) - else: - raise ValueError("Unsupported value type: " + str(type(value))) - - if len(values) > 0: - ort_devices = tuple(_map_tensor_or_sym_to_device(value) for value in values) - return ort_devices - else: - return (_map_tensor_or_sym_to_device(1),) - - -def _get_ortvalues_from_torch_tensors( - tensors: tuple[torch.Tensor, ...], devices: tuple["ORTC.OrtDevice", ...] -) -> tuple[torch.Tensor, ...]: - # TODO(justinchuby): Refactor this function - import numpy as np - from onnxruntime.capi import _pybind_state as ORTC - - torch_dtype_to_numpy_dtype = { - torch.float16: np.float16, - torch.float32: np.float32, - torch.float64: np.float64, - torch.uint8: np.uint8, - torch.int8: np.int8, - torch.int16: np.int16, - torch.int32: np.int32, - torch.int64: np.longlong, - torch.bool: np.bool_, - } - ortvalues = ORTC.OrtValueVector() - ortvalues.reserve(len(tensors)) - dtypes = [] - shapes = [] - data_ptrs = [] - - for tensor in tensors: - dtypes.append(torch_dtype_to_numpy_dtype[tensor.dtype]) - shapes.append(tensor.size()) - data_ptrs.append(tensor.data_ptr()) - ortvalues.push_back_batch(tensors, data_ptrs, dtypes, shapes, devices) - return ortvalues - - -def _to_real_tensor(tensor: FakeTensor) -> torch.Tensor: - if tensor.is_sparse: - raise ValueError("sparse tensor is not yet supported.") - out = torch.empty(tensor.size(), dtype=tensor.dtype, device=tensor.device) - return out - - -def _adjust_scalar_from_fx_to_onnx( - dynamo_value: Union[ - torch.Tensor, - int, - float, - bool, - ], - value_info: "onnx.ValueInfoProto", # type: ignore[name-defined] -) -> torch.Tensor: - """Helper function to wrap PyTorch variables as torch.Tensor""" - if ( - isinstance(dynamo_value, torch.Tensor) - and len(value_info.type.tensor_type.shape.dim) == 0 - and dynamo_value.shape == (1,) - ): - # ONNX expect a scalar with empty shape. - # In contrast, PyTorch usually allows implicit - # conversion between shape=() and shape=(1,). - # - # Below, PyTorch's shape (1,) is reshaped to (). - return torch.squeeze(dynamo_value) - elif isinstance(dynamo_value, int): - return torch.tensor(dynamo_value, dtype=torch.int64) - elif isinstance(dynamo_value, float): - return torch.tensor(dynamo_value, dtype=torch.float32) - elif isinstance(dynamo_value, bool): - return torch.tensor(dynamo_value, dtype=torch.bool) - else: - assert isinstance(dynamo_value, torch.Tensor) - return dynamo_value.contiguous() - - -def _adjust_scalar_from_onnx_to_fx( - tensor: torch.Tensor, - prim_value: Union[ - torch.Tensor, - torch.SymInt, - int, - torch.SymFloat, - float, - torch.SymBool, - bool, - ], -) -> Union[ - torch.Tensor, - int, - float, - bool, -]: - """Helper function to wrap ORT-produced torch.Tensor as PyTorch variables""" - assert isinstance(tensor, torch.Tensor), "ORT's output must be tensor." - if isinstance( - prim_value, - (torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool), - ): - # Convert tensor back to scalar to match Dynamo's expectation. - return tensor.item() - return tensor - - -def _run_onnx_session_with_ortvaluevector( - sess: "onnxruntime.InferenceSession", - input_names: tuple[str, ...], - inputs: tuple[torch.Tensor, ...], - input_devices: tuple["ORTC.OrtDevice", ...], - output_names: tuple[str, ...], - outputs: tuple[torch.Tensor, ...], - output_devices: tuple["ORTC.OrtDevice", ...], - preallocate_output: bool, - input_value_infos: tuple["onnx.ValueInfoProto", ...], # type: ignore[name-defined] - normalized_prim_outputs: tuple[ - Union[ - torch.Tensor, torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool - ], - ..., - ], -) -> tuple[Union[torch.Tensor, int, float, bool], ...]: - import onnxruntime - from onnxruntime.capi import _pybind_state as ORTC - - _nvtx_range_push("contiguous") - inputs = tuple( - _adjust_scalar_from_fx_to_onnx(arg, value_info) - for arg, value_info in zip(inputs, input_value_infos) - ) - _nvtx_range_pop() - - _nvtx_range_push("push_back_batch") - ort_inputs = _get_ortvalues_from_torch_tensors(inputs, input_devices) - - # preallocate output pytorch Tensors and use the buffers affined to the torch device for the output ortvalue. - # Because the output ortvalue is not allocated and owned by ort, it does not need to convert the output ortvalue - # to torch Tensor transferring the ownership. - if preallocate_output: - pth_outputs = tuple( - _to_real_tensor(t) if isinstance(t, FakeTensor) else t for t in outputs - ) - ort_outputs = _get_ortvalues_from_torch_tensors(pth_outputs, output_devices) - else: - ort_outputs = ORTC.OrtValueVector() - _nvtx_range_pop() - - _nvtx_range_push("run_with_ortvaluevector") - run_options = onnxruntime.RunOptions() - run_options.add_run_config_entry("disable_synchronize_execution_providers", "1") - sess.run_with_ortvaluevector( - run_options, input_names, ort_inputs, output_names, ort_outputs, output_devices - ) - _nvtx_range_pop() - - # Post-processing step: - # wrap ORT's outputs to the schema represented by - # `prim_output` (obtained by running the original - # torch.fx.GraphModule). - if preallocate_output: - # Profile the ORT-to-PyTorch type cast below - _nvtx_range_push("after run_with_ortvaluevector") - # Outputs are stored on pre-allocated torch.Tensors' memory, - # so this case doesn't need to convert ORTValue to torch.Tensor. - pth_outputs = tuple( - _adjust_scalar_from_onnx_to_fx(onnx_output, prim_output) # type: ignore[misc] - for onnx_output, prim_output in zip(pth_outputs, normalized_prim_outputs) - ) - _nvtx_range_pop() - return pth_outputs - else: - import onnxruntime.training - - # Profile the two ORT-to-PyTorch type casts below - _nvtx_range_push("after run_with_ortvaluevector") - # Map ORTValue to torch.Tensor. - pth_outputs = onnxruntime.training.ortmodule._utils._ortvalues_to_torch_tensor( - ort_outputs - ) - # Change some torch.Tensor to int, float, bool. - pth_outputs = tuple( - _adjust_scalar_from_onnx_to_fx(onnx_output, prim_output) # type: ignore[misc] - for onnx_output, prim_output in zip(pth_outputs, normalized_prim_outputs) - ) - _nvtx_range_pop() - return pth_outputs - - -def _run_onnx_session_with_fetch( - sess: "onnxruntime.InferenceSession", - input_names: tuple[str, ...], - inputs: tuple[torch.Tensor, ...], - input_devices: tuple["ORTC.OrtDevice", ...], - output_names: tuple[str, ...], - outputs: tuple[torch.Tensor, ...], - output_devices: tuple["ORTC.OrtDevice", ...], - preallocate_output: bool, - input_value_infos: tuple["onnx.ValueInfoProto", ...], # type: ignore[name-defined] - normalized_prim_outputs: tuple[ - Union[ - torch.Tensor, torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool - ], - ..., - ], -) -> tuple[Union[torch.Tensor, int, float, bool], ...]: - import onnxruntime - - inputs = tuple( - _adjust_scalar_from_fx_to_onnx(arg, value_info) - for arg, value_info in zip(inputs, input_value_infos) - ) - feed = { - name: onnxruntime.OrtValue.ortvalue_from_numpy(tensor.cpu().numpy()) - for name, tensor in zip(input_names, inputs) - } - ort_outputs = sess.run(output_names, feed) - pth_outputs = tuple( - _adjust_scalar_from_onnx_to_fx( - torch.from_numpy(value), - prim_output, - ) - for value, prim_output in zip(ort_outputs, normalized_prim_outputs) - ) - return pth_outputs - - -def _from_python_type_to_onnx_tensor_element_type(type: type): - """ - Converts a Python type to the corresponding ONNX tensor element type. - For example, `_from_python_type_to_onnx_tensor_element_type(float)` returns - `onnx.TensorProto.FLOAT`. - - Args: - type (type): The Python type to convert. - - Returns: - int: The corresponding ONNX tensor element type. - - """ - import onnx - - _PYTHON_TYPE_TO_ONNX_TENSOR_ELEMENT_TYPE = { - float: onnx.TensorProto.FLOAT, # type: ignore[attr-defined] - int: onnx.TensorProto.INT64, # type: ignore[attr-defined] - bool: onnx.TensorProto.BOOL, # type: ignore[attr-defined] - } - return _PYTHON_TYPE_TO_ONNX_TENSOR_ELEMENT_TYPE.get(type) - - -class OrtExecutionInfoPerSession: - """Information required to execute torch.fx.GraphModule using onnxruntime.InferenceSession""" - - def __init__( - self, - session: "onnxruntime.InferenceSession", - input_names: tuple[str, ...], - input_value_infos: tuple["onnx.ValueInfoProto", ...], # type: ignore[name-defined] - output_names: tuple[str, ...], - output_value_infos: tuple["onnx.ValueInfoProto", ...], # type: ignore[name-defined] - input_devices: tuple["ORTC.OrtDevice", ...], - output_devices: tuple["ORTC.OrtDevice", ...], - example_outputs: Union[tuple[torch.Tensor, ...], torch.Tensor], - ): - # Carrier of ONNX model and its executor. - self.session: onnxruntime.InferenceSession = session - # For the ONNX model stored in self.session, self.input_names[i] is the - # name of the i-th positional input. - self.input_names: tuple[str, ...] = input_names - # self.input_name[i]'s type information is stored in self.input_value_infos[i]. - self.input_value_infos: tuple[onnx.ValueInfoProto, ...] = input_value_infos # type: ignore[name-defined] - # Similar to self.input_names, but for outputs. - self.output_names: tuple[str, ...] = output_names - # Similar to self.input_value_infos but for outputs. - self.output_value_infos: tuple[onnx.ValueInfoProto, ...] = output_value_infos # type: ignore[name-defined] - # For the ONNX model stored in self.session, self.input_devices[i] is the - # i-th positional input's device. - self.input_devices: tuple[ORTC.OrtDevice, ...] = input_devices - # Similar to self.input_devices, but for outputs. - self.output_devices: tuple[ORTC.OrtDevice, ...] = output_devices - # This is the outputs of executing the original torch.fx.GraphModule with example inputs - # (i.e., args passed into OrtBackend._ort_acclerated_call). - self.example_outputs: Union[tuple[torch.Tensor, ...], torch.Tensor] = ( - example_outputs - ) - - def is_supported(self, *args): - # TODO(justinchuby): Simplify - import onnx - - _onnx_tensor_element_type_to_torch_dtype = { - onnx.TensorProto.FLOAT: torch.float32, # type: ignore[attr-defined] - onnx.TensorProto.FLOAT16: torch.float16, # type: ignore[attr-defined] - onnx.TensorProto.FLOAT8E5M2: torch.float8_e5m2, # type: ignore[attr-defined] - onnx.TensorProto.FLOAT8E5M2FNUZ: torch.float8_e5m2fnuz, # type: ignore[attr-defined] - onnx.TensorProto.FLOAT8E4M3FN: torch.float8_e4m3fn, # type: ignore[attr-defined] - onnx.TensorProto.FLOAT8E4M3FNUZ: torch.float8_e4m3fnuz, # type: ignore[attr-defined] - onnx.TensorProto.DOUBLE: torch.float64, # type: ignore[attr-defined] - onnx.TensorProto.BOOL: torch.bool, # type: ignore[attr-defined] - onnx.TensorProto.UINT8: torch.uint8, # type: ignore[attr-defined] - onnx.TensorProto.INT8: torch.int8, # type: ignore[attr-defined] - onnx.TensorProto.INT16: torch.int16, # type: ignore[attr-defined] - onnx.TensorProto.INT32: torch.int32, # type: ignore[attr-defined] - onnx.TensorProto.INT64: torch.int64, # type: ignore[attr-defined] - } - _torch_dtype_to_onnx_tensor_element_type = { - value: key - for key, value in _onnx_tensor_element_type_to_torch_dtype.items() - } - - # Compare the args and the input schema in ONNX model and - # return the first match. - if len(args) != len(self.input_value_infos): - return False - for arg, value_info in zip(args, self.input_value_infos): - if not isinstance(arg, (torch.Tensor, float, int)): - return False - - # Check Python scalars such as int, float, and bool. - if isinstance(arg, (int, float, bool)): - # Map, e.g., float to onnx.TensorProto.FLOAT. - onnx_dtype = _from_python_type_to_onnx_tensor_element_type(type(arg)) - if onnx_dtype != value_info.type.tensor_type.elem_type: - return False - if len(value_info.type.tensor_type.shape.dim) != 0: - return False - continue - - # Check tensor. - onnx_dtype = _torch_dtype_to_onnx_tensor_element_type[arg.dtype] - if onnx_dtype != value_info.type.tensor_type.elem_type: - return False - for dim, onnx_dim in zip(arg.shape, value_info.type.tensor_type.shape.dim): - if isinstance(dim, int) and ( - onnx_dim.dim_value == dim or onnx_dim.dim_param - ): - continue - elif isinstance(dim, torch.SymInt) and onnx_dim.dim_param: - continue - else: - return False - return True - - -@dataclasses.dataclass -class OrtExecutionInfoForAllGraphModules: - def __init__(self) -> None: - # All sessions (and their related information) created by exporting the same GraphModule - # with different inputs. - self.execution_info_per_graph_module: dict[ - torch.fx.GraphModule, list[OrtExecutionInfoPerSession] - ] = {} - - def search_reusable_session_execution_info( - self, graph_module: torch.fx.GraphModule, *args - ): - if graph_module not in self.execution_info_per_graph_module: - return None - # All execution information for ONNX models exported from the same `graph_module` - # with different inputs. - candidates = self.execution_info_per_graph_module[graph_module] - - for candidate in candidates: - if candidate.is_supported(*args): - # Returns the first session that accepts this input schema. - return candidate - # No reusable session found. - return None - - def cache_session_execution_info( - self, graph_module: torch.fx.GraphModule, info: OrtExecutionInfoPerSession - ): - if graph_module not in self.execution_info_per_graph_module: - self.execution_info_per_graph_module[graph_module] = [info] - else: - self.execution_info_per_graph_module[graph_module].append(info) - - -OrtExecutionProvider: TypeAlias = Union[str, tuple[str, Mapping[str, Any]]] -"""Either the name of an ONNX Runtime execution provider as a string or -a 2-tuple of the name and a dictionary of execution provider options. - -Examples:: - - >>> "CPUExecutionProvider" - - >>> ("CUDAExecutionProvider", {"device_id": 3}) - -""" - - -@dataclasses.dataclass(frozen=True) -@compatibility(is_backward_compatible=False) -class OrtBackendOptions: - """Options for constructing an ``OrtBackend``, the ONNX Runtime - backend (``"onnxrt"``) for ``torch.compile``. - - Example:: - - >>> @torch.compile( - ... backend="onnxrt", - ... options=torch.onnx._OrtBackendOptions(...), - ... ) - ... def ort_function(x): - ... return x ** x - """ - - preferred_execution_providers: Optional[Sequence[OrtExecutionProvider]] = None - """An optional sequence of execution providers to be prioritized ahead of any - execution providers that may be inferred (see ``infer_execution_providers``). - """ - - infer_execution_providers: bool = True - """Whether to infer an execution provider from ``torch.device`` bound to inputs or found in the graph.""" - - default_execution_providers: Optional[Sequence[OrtExecutionProvider]] = None - """The default fallback execution providers. If not specified, one will be - be selected based on the host environment (most likely ``"CPUExecutionProvider"``). - """ - - # preallocate_output allows for allocating output torch Tensor buffers and feeding them to InferenceSession - # in order to avoid internal allocation of output buffers in InferenceSession. - # If output ortvalue returned from InferenceSession is allocated internally, - # it needs to be converted to torch Tensor for return, and the torch Tensor should hold the ownership. - # When a custom torch device is used with a custom aten allocator, the conversion from ortvalue to torch Tensor - # should be supported, which is currently done through dlpack. Note that dlpack might not support a custom torch device. - # It can be avoided by allowing for preallocation for output buffers allocated by a custom aten allocator, - # and use the preallocated output buffers for InferenceSession not holding any ownership for them. - # TODO(wschin): Make it to inference session level flag. - # See https://github.com/pytorch/pytorch/issues/106869. - preallocate_output: bool = False - """If ``True``, allocate memory for ONNX Runtime's outputs on the PyTorch side.""" - - use_aot_autograd: bool = True - """Whether to wrap the ``OrtBackend`` with TorchDynamo's aot_autograd backend - to support training (i.e., backward graphs are also sent to ``OrtBackend``). - - Symbolic execution is used to capture the forward pass and backward passes as a single graph. - Then, a selected graph partition algorithm (``min_cut_rematerialization_partition``) is used - to split the entire graph into forward sub-graph and backward sub-graph. Finally, both - sub-graphs are compiled by ``OrtBackend``. - """ - - ort_session_options: Optional["onnxruntime.SessionOptions"] = None - """Options for the ``onnxruntime.InferenceSession`` used by the ``OrtBackend``.""" - - pre_ort_model_transforms: Optional[ # type: ignore[name-defined] - Sequence[Callable[["onnx.ModelProto"], None]] - ] = None - """A list of graph transforms to be applied to the ONNX model before it - is fed to ONNXRuntime's InferenceSession.""" - - -@compatibility(is_backward_compatible=False) -class OrtBackend: - """A backend compiles (sub-)graphs in torch.fx.GraphModule to onnxruntime.InferenceSession calls. - - The compiler entry point is OrtBackend.compile, which - 1. partitions the original graph into supported sub-graphs (type: torch.fx.GraphModule) and unsupported - sub-graphs. - 2. For each supported sub-graph, it replaces its _wrapped_call function with _ort_accelerated_call. - 3. Inside _ort_accelerated_call, it creates onnxruntime.InferenceSession and calls it to execute the sub-graph. - """ - - def __init__(self, options: Optional[OrtBackendOptions] = None): - from onnxruntime.capi import _pybind_state as ORTC - - import torch.onnx - import torch.onnx._internal._exporter_legacy - import torch.onnx._internal.fx.decomposition_table - - self._options: Final = OrtBackendOptions() if options is None else options - - # options.export_options contains information shared between exporter and DORT. - # For example, they should use the same decomposition table when - # 1. capturing FX graph in torch.compile (see how we create aot_ort in register_backend.py) - # 2. call exporter's API to convert `torch.fx.GraphModule` to ONNX model - # (see onnxfunction_dispatcher passed to FxOnnxInterpreter.run below). - # - # Convert user-facing option to internal option used by ONNX exporter - # to access required information. - # Some useful fields: - # - Decomposition table for decomposing FX operators in exporter is - # self._resolved_onnx_exporter_options.decomposition_table. - # - self._resolved_onnx_exporter_options.onnx_registry records what - # aten/prim ops are supported by exporter and their exporters (type: callable). - self._resolved_onnx_exporter_options = ( - torch.onnx._internal._exporter_legacy.ResolvedExportOptions() - ) - - # Given DORT's computation flow: - # 1. OrtOperatorSupport uses support_dict and extra_support_dict to select operators - # and send them to DORT. - # 2. Then, DORT exports the selected sub-graphs into ONNX. - # 3. Finally DORT calls ORT to do the computation. - # OrtOperatorSupport and create_onnx_friendly_decomposition_table(...) - # must use the same support_dict. If the support_dict here contains something not - # supported by exporter, exporter will fails in step 2 since the selected graphs may - # contains unsupported operators such as aten::_who_you_are. - # This restriction is automatically done since DORT and exporter shares the same - # self._resolved_onnx_exporter_options. - support_dict = torch.onnx._internal.fx.decomposition_table._create_onnx_supports_op_overload_table( - self._resolved_onnx_exporter_options.onnx_registry - ) - - extra_support_dict: dict[str, Any] = { - "getattr": None, - # To send operator.getitem to ORT, add the corresponding string - # recognized by PyTorch's OperatorSupport class. - "_operator.getitem": None, - # To send operator.mul to ORT, add the corresponding string - # recognized by PyTorch's OperatorSupport class. - "_operator.mul": None, - "_operator.add": None, - "_operator.sub": None, - } - - self._supported_ops = OrtOperatorSupport(support_dict, extra_support_dict) - # TODO(wschin): this is a naive implementation of cache without proper guard - # See https://github.com/pytorch/pytorch/issues/106868. - self._partitioner_cache: dict[torch.fx.GraphModule, torch.fx.GraphModule] = {} - # Conceptually, this filed is a 2-layer dictionary - # GraphModule 0 - # ONNX Model 0 (with ORT InferenceSession and related information. type: OrtExecutionInfoPerSession) - # ONNX Model 1 - # ... - # GraphModule 1 - # ONNX Model 2 (with ORT InferenceSession and related information. type: OrtExecutionInfoPerSession) - # ONNX Model 3 - # ... - # ... - # , which caches all previous compilation result so that we can reuse them. - # ONNX Model 0 and 1 are exported from the same GraphModule 0 but with different inputs - # (e.g., tensors with different ranks). GraphModule 0 and GraphModule 1 are different - # graphs captured by Dynamo and sent to OrtBackend.compile. - self._all_ort_execution_info = OrtExecutionInfoForAllGraphModules() - - self._assert_allclose_to_baseline = False - - self.execution_count = 0 - - # Function which invokes ORT do to the real computation. - self.run = ( - _run_onnx_session_with_ortvaluevector - if hasattr(ORTC.OrtValueVector, "push_back_batch") - else _run_onnx_session_with_fetch - ) - - def _select_eps( - self, graph_module: torch.fx.GraphModule, *args - ) -> Sequence[tuple[str, Mapping[str, Any]]]: - inferred_eps: tuple[str, ...] = () - if self._options.infer_execution_providers: - if eps_from_args := _infer_ep_from_device(*args): - # If user feeds CUDA tensor as input argument, - # we want to use CUDA EP. - # Thus, `eps_from_args` (deduced from input arguments) - # has highest priority. - inferred_eps = eps_from_args - elif eps_from_graph_module := _infer_ep_from_graph_module(graph_module): - # If there is no EP in input arguments, we deduce EP from - # graph_module's outputs. Those outputs may come from - # FakeTensorProp or Dynamo's built-in symbolic shape inference. - inferred_eps = eps_from_graph_module - - selected_eps = [] - - for ep in ( - *(self._options.preferred_execution_providers or []), - *_sort_eps(inferred_eps), - *(self._options.default_execution_providers or _infer_default_eps()), - ): - if isinstance(ep, str): - ep = (ep, {}) - elif isinstance(ep, tuple) and ep[1] is None: - ep = (ep[0], {}) - if ep is not None and ep not in selected_eps: - selected_eps.append(ep) - - return selected_eps - - def _ort_acclerated_call(self, graph_module: torch.fx.GraphModule, *args, **kwargs): - """This function replaces GraphModule._wrapped_call in compiled model. - - The _wrapped_call is the underlying implementation of forward method. Replacing - it means we delegate the computation to _ort_acclerated_call and therefore - onnxruntime.InferenceSession. - """ - import onnxruntime - - from torch.onnx._internal.fx import fx_onnx_interpreter, passes - - cached_execution_info_per_session = ( - self._all_ort_execution_info.search_reusable_session_execution_info( - graph_module, *args - ) - ) - if cached_execution_info_per_session: - onnx_session = cached_execution_info_per_session.session - input_names = cached_execution_info_per_session.input_names - output_names = cached_execution_info_per_session.output_names - input_value_infos = cached_execution_info_per_session.input_value_infos - output_value_infos = cached_execution_info_per_session.output_value_infos - input_devices = cached_execution_info_per_session.input_devices - output_devices = cached_execution_info_per_session.output_devices - prim_outputs = cached_execution_info_per_session.example_outputs - else: - # It's first time seeing such as graph. Let's make a new session - # (type: onnxruntime.InferenceSession) for it. - - # Generate reference outputs. They are used to indicate output - # tensors' types and devices when calling ORT. - # - # WARNING: The downstream code should not change prim_outputs and - # this backend should always produces output with schema identical to prim_outputs'. - - if self._resolved_onnx_exporter_options.dynamic_shapes: - # No pre-allocation when dynamic shape is enabled. - self.preallocate_output = False - extracted_outputs = _extract_graph_module_outputs(graph_module) - - def maybe_map_to_meta_val(value): - if hasattr(value, "meta") and "val" in value.meta: - # Select outputs with "val" information. Without "val", - # it's not possible access output_arg.meta["val"].device. - return value.meta["val"] - else: - return value - - prim_outputs = _pytree.tree_map( - maybe_map_to_meta_val, extracted_outputs - ) - else: - try: - prim_outputs = FakeTensorProp(graph_module).propagate( - *args, **kwargs - ) - except Exception: - logger.warning("FakeTensorProb failed for %s", graph_module) - # When FakeTensorProp fails, it is not possible to preallocate output buffers - # because the output shapes are not inferred. - self.preallocate_output = False - - # rethrow FakeTensorProb failure because it is not yet currently handled. - raise - - # Create the object to iterate through the nodes in graph one-by-one - # and calls the corresponding ONNX exporter for each node. - fx_interpreter = fx_onnx_interpreter.FxOnnxInterpreter() - # Cast FX variables if they will result schema-mismatch when searching - # for ONNX operator. E.g., add(double_tensor, int_tensor) is fine in PyTorch, - # but ONNX expects add(double_tensor, double_tensor). - graph_module = passes.InsertTypePromotion(graph_module).run() - # Start the per-node exporting process. It's conceptually a for loop - # scanning through the nodes in the graph. - exported = fx_interpreter.run( - fx_graph_module=graph_module, - onnxfunction_dispatcher=self._resolved_onnx_exporter_options.onnxfunction_dispatcher, - ) - # Convert the exported result to ONNX ModelProto. - onnx_model = exported.to_model_proto( - opset_version=self._resolved_onnx_exporter_options.onnx_registry.opset_version, - ) - - # Modify ONNX model using pre-registered graph transforms. - # They are in-place modifications for avoiding unnecessary - # copy of ONNX initializers. - if self._options.pre_ort_model_transforms: - for transform in self._options.pre_ort_model_transforms: - transform(onnx_model) - - onnx_model_bytes = onnx_model.SerializeToString() - if os.environ.get("ONNXRT_DUMP_PATH", None): - # If not empty, environment variable ONNXRT_DUMP_PATH defined the path - # where generated onnx files should be stored. - # This module keeps a global variables keeping track of the - # stored models. - # If ONNXRT_DUMP_PATH="dumped/dumped_model_" - # The first file name will be 'dumped/dumped_model_0.onnx'. - # For every dumped model, a text file 'dumped/dumped_model_0.txt' - # is created as well to contain the string representing the graph_module. - _dump_onnx_model(onnx_model_bytes, graph_module=graph_module) - - # Initialize a ORT session to execute this ONNX model. - # Note that TorchDynamo assumes all inputs/outputs are on the - # same device, but it's subject to change (very likely with - # dynamic shape support), so we add execution providers - # based on the logic in _select_eps: (explicitly preferred EPs, - # EPs inferred from inputs or graph, and the fallback default EP)/ - # - # TODO(wschin): enable external allocators. - # See https://github.com/pytorch/pytorch/issues/106867 - onnx_session = onnxruntime.InferenceSession( - path_or_bytes=onnx_model_bytes, - sess_options=self._options.ort_session_options, - providers=self._select_eps(graph_module, *args), - ) - - # Cache ORT session. It's reused for the same "graph_module". - # Generate ONNX model and extract its input and output names. - input_names = tuple(input.name for input in onnx_model.graph.input) - output_names = tuple(output.name for output in onnx_model.graph.output) - input_devices = _get_onnx_devices(args) - # Cache devices for inputs and outputs. They are used to invoke - # ORT session. Output devices indicate where (e.g., GPU or CPU) - # to store outputs - if isinstance(prim_outputs, tuple): - output_devices = _get_onnx_devices(prim_outputs) - else: - output_devices = _get_onnx_devices((prim_outputs,)) - - input_value_infos = tuple(input for input in onnx_model.graph.input) - output_value_infos = tuple(output for output in onnx_model.graph.output) - - execution_info_per_session = OrtExecutionInfoPerSession( - session=onnx_session, - input_names=input_names, - input_value_infos=input_value_infos, - output_names=output_names, - output_value_infos=output_value_infos, - input_devices=input_devices, - output_devices=output_devices, - example_outputs=prim_outputs, - ) - - self._all_ort_execution_info.cache_session_execution_info( - graph_module, execution_info_per_session - ) - - self.execution_count += 1 - - # ORT always returns a tuple of outputs. If the original output is a tensor, - # ORT output's first element must be extracted and returned. Otherwise, type - # mismatch may happen in downstream computation. - is_single_tensor_output = isinstance(prim_outputs, torch.Tensor) - normalized_prim_outputs = ( - (prim_outputs,) if is_single_tensor_output else prim_outputs - ) - assert isinstance(normalized_prim_outputs, tuple) - assert all( - isinstance(elem, (torch.Tensor, torch.SymInt, int)) - for elem in normalized_prim_outputs - ) - - _nvtx_range_push("run_onnx_session_with_ortvaluevector") - onnx_outputs = self.run( - onnx_session, - input_names, - args, - input_devices, - output_names, - normalized_prim_outputs, - output_devices, - self._options.preallocate_output, - input_value_infos, - normalized_prim_outputs, - ) - _nvtx_range_pop() - - if self._assert_allclose_to_baseline: - # Compute baseline. - baseline_outputs = torch._prims.executor.execute( - graph_module, *args, executor="aten" - ) - normalized_baseline_ouptuts = ( - (baseline_outputs,) if is_single_tensor_output else baseline_outputs - ) - # Ensure every output tensor is close to the corresponding baseline. - for onnx_output, baseline_output in zip( - onnx_outputs, normalized_baseline_ouptuts - ): - torch.testing.assert_close(onnx_output, baseline_output) - return onnx_outputs[0] if is_single_tensor_output else onnx_outputs - - def compile(self, graph_module: torch.fx.GraphModule, args) -> torch.fx.GraphModule: - # Deferred import since CapabilityBasedPartitioner is not decorated with - # @compatibility; importing it at the module level will result in the test - # failing: pytest test/test_fx.py -k test_public_api_surface - # because this module is imported into torch.onnx. - from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner - - # FX graph based partitioning based on ONNX supported ops. - # Given a graph module - # GraphModule0 - # node_0 - # node_1 - # node_2 - # node_3 - # node_4 - # If only node_2 is not supported by ONNX, this graph module will be partitioned into - # GraphModule0 - # GraphModule1 - # node_0 - # node_1 - # node_2 - # GraphModule2 - # node_3 - # node_4 - # by calling CapabilityBasedPartitioner.partition_and_fuse. - # Then, GraphModule1's and GraphModule2's forward method (GraphModule._wrapped_call) - # will be replaced by OrtBackend._ort_accelerated_call to delegate computation to ORT. - if graph_module in self._partitioner_cache: - partitioned_prim_graph_module = self._partitioner_cache[graph_module] - else: - prim_graph_module = graph_module - partitioner = CapabilityBasedPartitioner( - prim_graph_module, - self._supported_ops, - allows_single_node_partition=True, - ) - partitioned_prim_graph_module = partitioner.partition_and_fuse() - self._partitioner_cache[graph_module] = partitioned_prim_graph_module - - # Overriding fused_module's __call__() function with ort_acclerated_call() - # This loop goes through all graph partitions (each of them is an ONNX-representable graph) - # and override their _wrapped_call function with _ort_accelerated_call. - # Inside _ort_accelerated_call, the partition's graph is exported into ONNX and executed by ORT. - for node in partitioned_prim_graph_module.graph.nodes: - # TODO(wschin): use a better way to identify fused submodule - # See https://github.com/pytorch/pytorch/issues/106872. - if node.op == "call_module" and "fused_" in node.name: - fused_module = getattr(partitioned_prim_graph_module, node.name) - # self.ort_acclerated_call is responsible for exporting graph to ONNX, - # creating ORT session, and running ORT session. - fused_module._wrapped_call = self._ort_acclerated_call - - return partitioned_prim_graph_module - - def __call__( - self, graph_module: torch.fx.GraphModule, args - ) -> torch.fx.GraphModule: - """If ``OrtBackendOptions.use_aot_autograd`` is ``True``, the `auto_autograd` compiler - will be invoked, wrapping this ``OrtBackend`` instance's ``compile`` method. Otherwise, - the ``compile`` method is invoked directly.""" - if self._options.use_aot_autograd: - from functorch.compile import min_cut_rematerialization_partition - from torch._dynamo.backends.common import aot_autograd - - return aot_autograd( - fw_compiler=self.compile, - partition_fn=min_cut_rematerialization_partition, - decompositions=self._resolved_onnx_exporter_options.decomposition_table, - )(graph_module, args) - - return self.compile(graph_module, args) - - __instance_cache_max_count: Final = 8 - __instance_cache: Final[list["OrtBackend"]] = [] - - @staticmethod - def get_cached_instance_for_options( - options: Optional[Union[OrtBackendOptions, Mapping[str, Any]]] = None, - ) -> "OrtBackend": - """Returns a possibly cached instance of an ``OrtBackend``. If an existing - backend was created previously through this function with the same options, - it will be returned. Otherwise a new backend will be created, cached, and - returned. - - Note: if ``options`` sets ``ort_session_options``, a new ``OrtBackend`` - will always be returned, since ``onnxruntime.SessionOptions`` cannot - participate in caching.""" - - def reusable(a: OrtBackendOptions, b: OrtBackendOptions): - if ( - a.preferred_execution_providers != b.preferred_execution_providers - or a.infer_execution_providers != b.infer_execution_providers - or a.default_execution_providers != b.default_execution_providers - or a.preallocate_output != b.preallocate_output - or a.use_aot_autograd != b.use_aot_autograd - or a.pre_ort_model_transforms != b.pre_ort_model_transforms - ): - return False - - # onnxruntime.SessionOptions is a pybind11 object, cannot be pickled, - # and holds too much potential state to reasonably check manually; - # ort_session_options is provided at all, the backend does not participate - # in caching. - if a.ort_session_options is not None or b.ort_session_options is not None: - return False - - return True - - if not isinstance(options, OrtBackendOptions): - options = OrtBackendOptions(**(options or {})) - - backend = next( - (b for b in OrtBackend.__instance_cache if reusable(b._options, options)), - None, - ) - - if backend is None: - assert ( - len(OrtBackend.__instance_cache) < OrtBackend.__instance_cache_max_count - ), ( - f"No more than {OrtBackend.__instance_cache_max_count} instances of " - f"{OrtBackend} allowed. Please instantiate `{OrtBackend}` explicitly " - "to pass to `torch.compile`. " - "See https://github.com/pytorch/pytorch/pull/107973#discussion_r1306144795 " - "for discussion." - ) - OrtBackend.__instance_cache.append(backend := OrtBackend(options)) - - return backend - - @staticmethod - def clear_cached_instances(): - OrtBackend.__instance_cache.clear() - - @staticmethod - def get_cached_instances(): - return tuple(OrtBackend.__instance_cache) - - -@compatibility(is_backward_compatible=False) -def torch_compile_backend( - graph_module: torch.fx.GraphModule, - args, - *, - options: Optional[Union[OrtBackendOptions, Mapping[str, Any]]] = None, -): - return OrtBackend.get_cached_instance_for_options(options)(graph_module, args) From 4657a84bc55b6ce12f21706de2b90e1d43784f57 Mon Sep 17 00:00:00 2001 From: Menglu Yu Date: Tue, 15 Jul 2025 19:22:26 +0000 Subject: [PATCH 066/457] [Optimus][fp8_activation_quantization] Only log when there's some node to be quantized (#158129) Summary: We add some extra check on whether there's some node has been marked as should quantize, otherwise we skip the quantizaton and tlparse log. Rollback Plan: Differential Revision: D78173788 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158129 Approved by: https://github.com/Skylion007, https://github.com/avicizhu --- torch/_functorch/partitioners.py | 87 ++++++++++++++++++-------------- 1 file changed, 49 insertions(+), 38 deletions(-) diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index 7b36092e09eb3..c666a924b468a 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -684,47 +684,11 @@ def quantize_activation_bw(graph: torch.fx.Graph) -> None: counters["inductor"]["activation_quantization_bwd_aten_pass"] += 1 -def enable_activation_quantization( - saved_values: list[fx.Node], +def perform_fp8_activation_quantization( fwd_module: fx.GraphModule, bwd_module: fx.GraphModule, - static_lifetime_input_nodes: Optional[OrderedSet[fx.Node]] = None, + bwd_module_inputs: dict[str, fx.Node], ) -> None: - if ( - inductor_config.post_grad_fusion_options.get( - "activation_quantization_aten_pass", None - ) - is None - ): - return - - static_input_names = ( - [node.name for node in static_lifetime_input_nodes] - if static_lifetime_input_nodes - else [] - ) - saved_values_names = {node.name: node for node in saved_values} - if torch._inductor.config.post_grad_fusion_options[ - "activation_quantization_aten_pass" - ].get("exclude_primals", False): - saved_values_names = { - node.name: node for node in saved_values if "primals" not in node.name - } - fwd_module_outputs = fwd_module.graph.find_nodes(op="output")[0].args[0] - bwd_module_inputs = { - node.name: node for node in bwd_module.graph.find_nodes(op="placeholder") - } - for node in fwd_module_outputs: - if node.name in saved_values_names and should_quantize(node): - if node.name in static_input_names: - log.debug("Skipping quantization of static input %s: ", node.name) - continue - node.meta["saved_for_quantization"] = True - node.meta["dequant_type"] = node.meta["val"].dtype - # some of the fwd outputs and bwd inputs are not share the same object - bwd_module_inputs[node.name].meta["saved_for_quantization"] = True - bwd_module_inputs[node.name].meta["dequant_type"] = node.meta["val"].dtype - trace_structured( "artifact", metadata_fn=lambda: { @@ -808,6 +772,53 @@ def enable_activation_quantization( ) +def enable_activation_quantization( + saved_values: list[fx.Node], + fwd_module: fx.GraphModule, + bwd_module: fx.GraphModule, + static_lifetime_input_nodes: Optional[OrderedSet[fx.Node]] = None, +) -> None: + if ( + inductor_config.post_grad_fusion_options.get( + "activation_quantization_aten_pass", None + ) + is None + ): + return + + static_input_names = ( + [node.name for node in static_lifetime_input_nodes] + if static_lifetime_input_nodes + else [] + ) + saved_values_names = {node.name: node for node in saved_values} + if torch._inductor.config.post_grad_fusion_options[ + "activation_quantization_aten_pass" + ].get("exclude_primals", False): + saved_values_names = { + node.name: node for node in saved_values if "primals" not in node.name + } + fwd_module_outputs = fwd_module.graph.find_nodes(op="output")[0].args[0] + bwd_module_inputs = { + node.name: node for node in bwd_module.graph.find_nodes(op="placeholder") + } + should_perform_fp8_quant = False + for node in fwd_module_outputs: + if node.name in saved_values_names and should_quantize(node): + if node.name in static_input_names: + log.debug("Skipping quantization of static input %s: ", node.name) + continue + node.meta["saved_for_quantization"] = True + node.meta["dequant_type"] = node.meta["val"].dtype + # some of the fwd outputs and bwd inputs are not share the same object + bwd_module_inputs[node.name].meta["saved_for_quantization"] = True + bwd_module_inputs[node.name].meta["dequant_type"] = node.meta["val"].dtype + should_perform_fp8_quant = True + + if should_perform_fp8_quant: + perform_fp8_activation_quantization(fwd_module, bwd_module, bwd_module_inputs) + + def _extract_fwd_bwd_modules( joint_module: fx.GraphModule, saved_values: list[fx.Node], From 011026205a9d4c38458130f8ca242028f6184bf0 Mon Sep 17 00:00:00 2001 From: Songhao Jia Date: Tue, 15 Jul 2025 19:31:00 +0000 Subject: [PATCH 067/457] make node source hashable (#158322) Summary: as title Test Plan: ci Rollback Plan: Reviewed By: yushangdi Differential Revision: D78296410 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158322 Approved by: https://github.com/yushangdi --- test/fx/test_fx_traceback.py | 38 ++++++++++++++++++++++++++++++++++++ torch/fx/traceback.py | 13 ++++++++++++ 2 files changed, 51 insertions(+) diff --git a/test/fx/test_fx_traceback.py b/test/fx/test_fx_traceback.py index e11ee19daaac4..f02bc5a2e1592 100644 --- a/test/fx/test_fx_traceback.py +++ b/test/fx/test_fx_traceback.py @@ -74,6 +74,9 @@ def test_node_source(self): ) self.assertEqual(node_source1, node_source2) + # Test hash function - equivalent objects should have same hash + self.assertEqual(hash(node_source1), hash(node_source2)) + # Test two node sources are not same node_source3 = NodeSource( node=None, pass_name="test_pass_1", action=NodeSourceAction.CREATE @@ -83,6 +86,41 @@ def test_node_source(self): ) self.assertNotEqual(node_source3, node_source4) + # Test hash function - different objects should have different hash + self.assertNotEqual(hash(node_source3), hash(node_source4)) + + # Test that equivalent NodeSource objects can be used in sets and dicts + node_set = {node_source1, node_source2} + self.assertEqual(len(node_set), 1) # Should only contain one unique element + + node_dict = {node_source1: "value1", node_source2: "value2"} + self.assertEqual(len(node_dict), 1) # Should only contain one key + self.assertEqual(node_dict[node_source1], "value2") # Last value should win + + # Test with more complex NodeSource objects + node_source_with_node = NodeSource( + node=node, pass_name="test_pass", action=NodeSourceAction.CREATE + ) + node_source_with_node_copy = NodeSource( + node=node, pass_name="test_pass", action=NodeSourceAction.CREATE + ) + + # These should be equal and have same hash + self.assertEqual(node_source_with_node, node_source_with_node_copy) + self.assertEqual(hash(node_source_with_node), hash(node_source_with_node_copy)) + + # Test with different actions + node_source_replace = NodeSource( + node=None, pass_name="test_pass", action=NodeSourceAction.REPLACE + ) + node_source_create = NodeSource( + node=None, pass_name="test_pass", action=NodeSourceAction.CREATE + ) + + # These should be different and have different hashes + self.assertNotEqual(node_source_replace, node_source_create) + self.assertNotEqual(hash(node_source_replace), hash(node_source_create)) + def test_graph_provenance(self): def check_node_source(node_source_dict, name, pass_name, action): self.assertEqual(node_source_dict["name"], name) diff --git a/torch/fx/traceback.py b/torch/fx/traceback.py index 9f316191a2302..97391d567aba8 100644 --- a/torch/fx/traceback.py +++ b/torch/fx/traceback.py @@ -128,6 +128,19 @@ def __eq__(self, other: object): return False return self.to_dict() == other.to_dict() + def __hash__(self): + # Create a hash based on the dictionary representation + # We need to convert the dict to a hashable form + def _make_hashable(obj): + if isinstance(obj, dict): + return tuple(sorted((k, _make_hashable(v)) for k, v in obj.items())) + elif isinstance(obj, list): + return tuple(_make_hashable(item) for item in obj) + else: + return obj + + return hash(_make_hashable(self.to_dict())) + @compatibility(is_backward_compatible=False) @contextmanager From 250ae2531c55dcc50f558ec739941324e3f9a4d4 Mon Sep 17 00:00:00 2001 From: Aaron Orenstein Date: Sun, 13 Jul 2025 11:51:16 -0700 Subject: [PATCH 068/457] Fix types in graphs.py (#158192) Added type annotations for torch/cuda/graphs.py Pull Request resolved: https://github.com/pytorch/pytorch/pull/158192 Approved by: https://github.com/oulgen --- torch/_C/__init__.pyi.in | 5 +- torch/_inductor/cudagraph_trees.py | 16 +-- torch/autograd/function.py | 24 ++-- torch/cuda/__init__.py | 5 +- torch/cuda/graphs.py | 163 +++++++++++++++++++-------- torch/onnx/_internal/_lazy_import.py | 2 +- 6 files changed, 145 insertions(+), 70 deletions(-) diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 69d90d4e7a1fb..a5c4d390ee36d 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -40,6 +40,7 @@ from torch._C import ( ) from torch._prims_common import DeviceLikeType from torch.autograd.graph import Node as _Node +from torch.cuda import _POOL_HANDLE from torch.fx.node import Node as FxNode from torch.package import PackageExporter from torch.storage import TypedStorage, UntypedStorage @@ -2289,7 +2290,7 @@ class _CUDAGraph: def __new__(cls, keep_graph: _bool = ...) -> Self: ... def capture_begin( self, - pool: tuple[_int, _int] | None = ..., + pool: _POOL_HANDLE | None = ..., capture_error_mode: str = "global", ) -> None: ... def capture_end(self) -> None: ... @@ -2297,7 +2298,7 @@ class _CUDAGraph: def register_generator_state(self, Generator) -> None: ... def replay(self) -> None: ... def reset(self) -> None: ... - def pool(self) -> tuple[_int, _int]: ... + def pool(self) -> _POOL_HANDLE: ... def enable_debug_mode(self) -> None: ... def debug_dump(self, debug_path: str) -> None: ... def raw_cuda_graph(self) -> _int: ... diff --git a/torch/_inductor/cudagraph_trees.py b/torch/_inductor/cudagraph_trees.py index bdc201803fb60..3b3dea909cd24 100644 --- a/torch/_inductor/cudagraph_trees.py +++ b/torch/_inductor/cudagraph_trees.py @@ -90,6 +90,7 @@ from torch._guards import CompileId from torch._inductor.utils import InputType + from torch.cuda import _POOL_HANDLE from torch.types import _bool StorageWeakRefPointer = int @@ -817,7 +818,7 @@ def __init__( id: GraphID, parent: Optional[CUDAGraphNode], inputs: list[InputType], - cuda_graphs_pool: tuple[int, int], + cuda_graphs_pool: _POOL_HANDLE, device_index: int, stack_traces: Optional[StackTraces], stream: torch.cuda.Stream, @@ -1228,6 +1229,7 @@ def all_outputs_are_dead(self) -> bool: def _record(self, model: ModelType, inputs: list[InputType]) -> OutputType: "Record the model" + assert self.graph is not None def static_input_iter() -> Generator[torch.Tensor, None, None]: for i in self.wrapped_function.static_input_idxs: @@ -1310,13 +1312,11 @@ def _add_first_outputs( self.output_storage_alias.append(UnaliasedStorage) continue - ( - torch._check( - o.is_cuda or o.untyped_storage().data_ptr() == 0, - lambda: ( - "Expected all cuda outputs in cuda graph recording. Non cuda output " - f"from {self.stack_traces[i] if self.stack_traces else '(unknown)'}" - ), + torch._check( + o.is_cuda or o.untyped_storage().data_ptr() == 0, + lambda: ( + "Expected all cuda outputs in cuda graph recording. Non cuda output " + f"from {self.stack_traces[i] if self.stack_traces else '(unknown)'}" ), ) diff --git a/torch/autograd/function.py b/torch/autograd/function.py index b8036a5235b91..ac3aad9f93b59 100644 --- a/torch/autograd/function.py +++ b/torch/autograd/function.py @@ -4,8 +4,8 @@ import itertools import warnings from collections import OrderedDict -from typing import Any, Optional -from typing_extensions import deprecated +from typing import Any, Callable, Optional, TypeVar +from typing_extensions import Concatenate, deprecated, ParamSpec import torch import torch._C as _C @@ -29,6 +29,10 @@ # This is incremented in FunctionMeta during class definition AUTOGRAD_FUNCTION_COUNTER = itertools.count() +_T = TypeVar("_T") +_R = TypeVar("_R") +_P = ParamSpec("_P") + # Formerly known as: _ContextMethodMixin class FunctionCtx: @@ -595,11 +599,13 @@ def _is_setup_context_defined(fn): return fn != _SingleLevelFunction.setup_context -def once_differentiable(fn): +def once_differentiable( + fn: Callable[Concatenate[_T, _P], _R], +) -> Callable[Concatenate[_T, _P], _R]: @functools.wraps(fn) - def wrapper(ctx, *args): + def wrapper(ctx: _T, *args: _P.args, **kwargs: _P.kwargs) -> _R: with torch.no_grad(): - outputs = fn(ctx, *args) + outputs = fn(ctx, *args, **kwargs) if not torch.is_grad_enabled(): return outputs @@ -620,12 +626,14 @@ def wrapper(ctx, *args): return outputs if not isinstance(outputs, tuple): - outputs = (outputs,) + outputs_ = (outputs,) + else: + outputs_ = outputs err_fn = _functions.DelayedError( b"trying to differentiate twice a function that was marked " b"with @once_differentiable", - len(outputs), + len(outputs_), ) # Create aliases of each output that has requires_grad=True. We need @@ -637,7 +645,7 @@ def fake_requires_grad(var): var.requires_grad = True return var - return err_fn(*[fake_requires_grad(v) for v in outputs]) + return err_fn(*[fake_requires_grad(v) for v in outputs_]) # type: ignore[return-value] return wrapper diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py index fc9d09ce63a67..5b85c91d2c208 100644 --- a/torch/cuda/__init__.py +++ b/torch/cuda/__init__.py @@ -18,7 +18,7 @@ import traceback import warnings from functools import lru_cache -from typing import Any, Callable, cast, Optional, TYPE_CHECKING, Union +from typing import Any, Callable, cast, NewType, Optional, TYPE_CHECKING, Union import torch import torch._C @@ -1777,6 +1777,9 @@ def _compile_kernel( from . import amp, jiterator, nvtx, profiler, sparse, tunable +_POOL_HANDLE = NewType("_POOL_HANDLE", tuple[int, int]) + + __all__ = [ # Typed storage and tensors "BFloat16Storage", diff --git a/torch/cuda/graphs.py b/torch/cuda/graphs.py index b58a7808593dc..b1d1e4f8c478a 100644 --- a/torch/cuda/graphs.py +++ b/torch/cuda/graphs.py @@ -1,12 +1,34 @@ -# mypy: allow-untyped-defs +from __future__ import annotations + import gc import typing +from typing import Callable, Optional, overload, TYPE_CHECKING, Union +from typing_extensions import ParamSpec, Self, TypeAlias, TypeVar import torch +from torch import Tensor + + +if TYPE_CHECKING: + # importing _POOL_HANDLE at runtime toplevel causes an import cycle + from torch.cuda import _POOL_HANDLE from .._utils import _dummy_type +__all__ = [ + "is_current_stream_capturing", + "graph_pool_handle", + "CUDAGraph", + "graph", + "make_graphed_callables", +] + + +_R = TypeVar("_R") +_P = ParamSpec("_P") + + if not hasattr(torch._C, "_CudaStreamBase"): # Define dummy base classes torch._C.__dict__["_CUDAGraph"] = _dummy_type("_CUDAGraph") @@ -22,7 +44,7 @@ ) -def is_current_stream_capturing(): +def is_current_stream_capturing() -> bool: r"""Return True if CUDA graph capture is underway on the current CUDA stream, False otherwise. If a CUDA context does not exist on the current device, returns False without initializing the context. @@ -31,7 +53,7 @@ def is_current_stream_capturing(): # Python shim helps Sphinx process docstrings more reliably. -def graph_pool_handle(): +def graph_pool_handle() -> _POOL_HANDLE: r"""Return an opaque token representing the id of a graph memory pool. See :ref:`Graph memory management`. @@ -39,7 +61,7 @@ def graph_pool_handle(): .. warning:: This API is in beta and may change in future releases. """ - return _graph_pool_handle() + return torch.cuda._POOL_HANDLE(_graph_pool_handle()) # Python shim helps Sphinx process docstrings more reliably. @@ -70,10 +92,12 @@ class CUDAGraph(torch._C._CUDAGraph): """ - def __new__(cls, keep_graph=False): + def __new__(cls, keep_graph: bool = False) -> Self: return super().__new__(cls, keep_graph) - def capture_begin(self, pool=None, capture_error_mode="global"): + def capture_begin( + self, pool: Optional[_POOL_HANDLE] = None, capture_error_mode: str = "global" + ) -> None: r"""Begin capturing CUDA work on the current stream. Typically, you shouldn't call ``capture_begin`` yourself. @@ -92,7 +116,7 @@ def capture_begin(self, pool=None, capture_error_mode="global"): """ # noqa: B950 super().capture_begin(pool=pool, capture_error_mode=capture_error_mode) - def capture_end(self): + def capture_end(self) -> None: r"""End CUDA graph capture on the current stream. After ``capture_end``, ``replay`` may be called on this instance. @@ -103,7 +127,7 @@ def capture_end(self): """ super().capture_end() - def instantiate(self): + def instantiate(self) -> None: r"""Instantiate the CUDA graph. Will be called by ``capture_end`` if ``keep_graph=False``, or by ``replay`` if ``keep_graph=True`` and ``instantiate`` has not already been @@ -112,15 +136,15 @@ def instantiate(self): """ super().instantiate() - def replay(self): + def replay(self) -> None: r"""Replay the CUDA work captured by this graph.""" super().replay() - def reset(self): + def reset(self) -> None: r"""Delete the graph currently held by this instance.""" super().reset() - def pool(self): + def pool(self) -> _POOL_HANDLE: r"""Return an opaque token representing the id of this graph's memory pool. This id can optionally be passed to another graph's ``capture_begin``, @@ -128,11 +152,11 @@ def pool(self): """ return super().pool() - def enable_debug_mode(self): + def enable_debug_mode(self) -> None: r"""Enable debugging mode for CUDAGraph.debug_dump.""" return super().enable_debug_mode() - def debug_dump(self, debug_path): + def debug_dump(self, debug_path: str) -> None: r""" Arguments: debug_path (required): Path to dump the graph to. @@ -142,7 +166,7 @@ def debug_dump(self, debug_path): """ return super().debug_dump(debug_path) - def raw_cuda_graph(self): + def raw_cuda_graph(self) -> int: r"""Returns the underlying cudaGraph_t. ``keep_graph`` must be True. See the following for APIs for how to manipulate this object: `Graph Managmement `_ and `cuda-python Graph Management bindings `_ @@ -180,13 +204,13 @@ class graph: https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85 """ # noqa: B950 - default_capture_stream: typing.Optional["torch.cuda.Stream"] = None + default_capture_stream: Optional[torch.cuda.Stream] = None def __init__( self, - cuda_graph, - pool=None, - stream=None, + cuda_graph: CUDAGraph, + pool: Optional[_POOL_HANDLE] = None, + stream: Optional[torch.cuda.Stream] = None, capture_error_mode: str = "global", ): # Lazy-init of default_capture_stream helps avoid circular-import errors. @@ -195,7 +219,9 @@ def __init__( if self.__class__.default_capture_stream is None: self.__class__.default_capture_stream = torch.cuda.Stream() - self.pool = () if pool is None else (pool,) + self.pool: Union[tuple[()], tuple[_POOL_HANDLE]] = ( + () if pool is None else (pool,) + ) self.capture_stream = ( stream if stream is not None else self.__class__.default_capture_stream ) @@ -204,7 +230,7 @@ def __init__( self.cuda_graph = cuda_graph self.capture_error_mode = capture_error_mode - def __enter__(self): + def __enter__(self) -> None: # Free as much memory as we can for the graph torch.cuda.synchronize() gc.collect() @@ -215,18 +241,47 @@ def __enter__(self): self.stream_ctx.__enter__() self.cuda_graph.capture_begin( - *self.pool, capture_error_mode=self.capture_error_mode + # type: ignore[misc] + *self.pool, + capture_error_mode=self.capture_error_mode, ) - def __exit__(self, exc_type, exc_value, traceback): + def __exit__(self, *args: object) -> None: self.cuda_graph.capture_end() - self.stream_ctx.__exit__(exc_type, exc_value, traceback) + self.stream_ctx.__exit__(*args) # returning None should propagate exceptions from either capture_end or stream_ctx.__exit__() +_ModuleOrCallable: TypeAlias = Union["torch.nn.Module", Callable[..., object]] + + +@overload def make_graphed_callables( - callables, sample_args, num_warmup_iters=3, allow_unused_input=False, pool=None -): + callables: _ModuleOrCallable, + sample_args: tuple[Tensor, ...], + num_warmup_iters: int = 3, + allow_unused_input: bool = False, + pool: Optional[_POOL_HANDLE] = None, +) -> _ModuleOrCallable: ... + + +@overload +def make_graphed_callables( + callables: tuple[_ModuleOrCallable, ...], + sample_args: tuple[tuple[Tensor, ...], ...], + num_warmup_iters: int = 3, + allow_unused_input: bool = False, + pool: Optional[_POOL_HANDLE] = None, +) -> tuple[_ModuleOrCallable, ...]: ... + + +def make_graphed_callables( + callables: Union[_ModuleOrCallable, tuple[_ModuleOrCallable, ...]], + sample_args: Union[tuple[Tensor, ...], tuple[tuple[Tensor, ...], ...]], + num_warmup_iters: int = 3, + allow_unused_input: bool = False, + pool: Optional[_POOL_HANDLE] = None, +) -> Union[_ModuleOrCallable, tuple[_ModuleOrCallable, ...]]: r"""Accept callables (functions or :class:`nn.Module`\ s) and returns graphed versions. Each graphed callable's forward pass runs its source callable's @@ -300,14 +355,17 @@ def make_graphed_callables( just_one_callable = False + _sample_args: tuple[tuple[Tensor, ...], ...] if not isinstance(callables, tuple): just_one_callable = True callables = (callables,) - sample_args = (sample_args,) + _sample_args = (typing.cast(tuple[Tensor, ...], sample_args),) + else: + _sample_args = typing.cast(tuple[tuple[Tensor, ...], ...], sample_args) flatten_sample_args = [] - for c, args in zip(callables, sample_args): + for c, args in zip(callables, _sample_args): if isinstance(c, torch.nn.Module): assert ( len(c._backward_hooks) == 0 @@ -352,7 +410,7 @@ def make_graphed_callables( torch.cuda.synchronize() with torch.cuda.stream(torch.cuda.Stream()): for func, args, static_input_surface in zip( - callables, sample_args, per_callable_static_input_surfaces + callables, _sample_args, per_callable_static_input_surfaces ): grad_inputs, outputs, outputs_grad = None, None, None for _ in range(num_warmup_iters): @@ -382,11 +440,11 @@ def make_graphed_callables( # Capture forward graphs per_callable_static_outputs = [] per_callable_output_unflatten_spec = [] - for func, args, fwd_graph in zip(callables, sample_args, fwd_graphs): + for func, args, fwd_graph in zip(callables, _sample_args, fwd_graphs): with torch.cuda.graph(fwd_graph, pool=mempool): - outputs = func(*args) + func_outputs = func(*args) - flatten_outputs, spec = torch.utils._pytree.tree_flatten(outputs) + flatten_outputs, spec = torch.utils._pytree.tree_flatten(func_outputs) per_callable_static_outputs.append(tuple(flatten_outputs)) per_callable_output_unflatten_spec.append(spec) @@ -438,19 +496,19 @@ def make_graphed_callables( # Now for every per_callable list, per_callable_*[i] holds the stuff for the ith callable. def make_graphed_autograd_function( - fwd_graph, - bwd_graph, - module_params, - len_user_args, - output_unflatten_spec, - static_input_surface, - static_outputs, - static_grad_outputs, - static_grad_inputs, - ): + fwd_graph: CUDAGraph, + bwd_graph: CUDAGraph, + module_params: tuple[torch.nn.Parameter, ...], + len_user_args: int, + output_unflatten_spec: torch.utils._pytree.TreeSpec, + static_input_surface: tuple[Tensor, ...], + static_outputs: tuple[Tensor, ...], + static_grad_outputs: tuple[Optional[Tensor], ...], + static_grad_inputs: tuple[Tensor, ...], + ) -> Callable[..., object]: class Graphed(torch.autograd.Function): @staticmethod - def forward(ctx, *inputs): + def forward(ctx: object, *inputs: Tensor) -> tuple[Tensor, ...]: # At this stage, only the user args may (potentially) be new tensors. for i in range(len_user_args): if static_input_surface[i].data_ptr() != inputs[i].data_ptr(): @@ -461,7 +519,7 @@ def forward(ctx, *inputs): @staticmethod @torch.autograd.function.once_differentiable - def backward(ctx, *grads): + def backward(ctx: object, *grads: Tensor) -> tuple[Tensor, ...]: assert len(grads) == len(static_grad_outputs) for g, grad in zip(static_grad_outputs, grads): if g is not None: @@ -477,7 +535,7 @@ def backward(ctx, *grads): b.detach() if b is not None else b for b in static_grad_inputs ) - def functionalized(*user_args): + def functionalized(*user_args: object) -> object: # Runs the autograd function with inputs == all inputs to the graph that might require grad # (explicit user args + module parameters) # Assumes module params didn't change since capture. @@ -488,7 +546,7 @@ def functionalized(*user_args): return functionalized # Put together the final graphed callables - ret = [] + ret: list[_ModuleOrCallable] = [] for i, func in enumerate(callables): graphed = make_graphed_autograd_function( fwd_graphs[i], @@ -504,20 +562,25 @@ def functionalized(*user_args): if isinstance(func, torch.nn.Module): - def make_graphed_forward(func, graph_training_state, graphed, orig_fwd): - def new_fwd(*user_args): + def make_graphed_forward( + func: torch.nn.Module, + graph_training_state: bool, + graphed: Callable[_P, _R], + orig_fwd: Callable[_P, _R], + ) -> Callable[_P, _R]: + def new_fwd(*user_args: _P.args, **user_kwargs: _P.kwargs) -> _R: # If the module's training-or-eval state matches what we graphed, # run the graph, otherwise run the original forward method if func.training == graph_training_state: - return graphed(*user_args) + return graphed(*user_args, **user_kwargs) else: - return orig_fwd(*user_args) + return orig_fwd(*user_args, **user_kwargs) return new_fwd func.forward = make_graphed_forward( func, func.training, graphed, func.forward - ) # type: ignore[assignment] + ) ret.append(func) else: ret.append(graphed) diff --git a/torch/onnx/_internal/_lazy_import.py b/torch/onnx/_internal/_lazy_import.py index 3557ef099309e..7cde0bd35177f 100644 --- a/torch/onnx/_internal/_lazy_import.py +++ b/torch/onnx/_internal/_lazy_import.py @@ -28,7 +28,7 @@ def __getattr__(self, attr: str) -> object: # NOTE: Add additional used imports here. if TYPE_CHECKING: import onnx - import onnx_ir # type: ignore[import-untyped] + import onnx_ir # type: ignore[import-untyped, import-not-found] import onnxscript import onnxscript._framework_apis.torch_2_8 as onnxscript_apis From 30587195d314eb5eb02ce63f39a9be4c943629ef Mon Sep 17 00:00:00 2001 From: Jane Xu Date: Tue, 15 Jul 2025 19:52:53 +0000 Subject: [PATCH 069/457] Migrate c10/macros/cmake_macros.h.in to torch/headeronly (#158035) Summary: As above, also changes a bunch of the build files to be better Test Plan: internal and external CI did run buck2 build fbcode//caffe2:torch and it succeeded Rollback Plan: Reviewed By: swolchok Differential Revision: D78016591 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158035 Approved by: https://github.com/swolchok --- .bazelrc | 2 +- CMakeLists.txt | 1 + c10/BUCK.oss | 2 - c10/CMakeLists.txt | 20 +++--- c10/macros/Export.h | 2 +- c10/macros/Macros.h | 2 +- c10/macros/build.bzl | 18 +----- c10/macros/cmake_macros.h | 5 ++ c10/ovrsource_defs.bzl | 52 +-------------- tools/bazel.bzl | 2 +- torch/CMakeLists.txt | 1 - torch/headeronly/BUCK.oss | 26 ++++++++ torch/headeronly/BUILD.bazel | 10 +-- torch/headeronly/CMakeLists.txt | 34 ++++++++++ torch/headeronly/build.bzl | 11 ++++ torch/headeronly/macros/BUILD.bazel | 4 ++ torch/headeronly/macros/build.bzl | 28 +++++++++ .../macros/cmake_configure_file.bzl | 0 .../headeronly}/macros/cmake_macros.h.in | 2 +- torch/headeronly/ovrsource_defs.bzl | 63 ++++++++++++++++--- 20 files changed, 183 insertions(+), 102 deletions(-) create mode 100644 c10/macros/cmake_macros.h create mode 100644 torch/headeronly/BUCK.oss create mode 100644 torch/headeronly/CMakeLists.txt create mode 100644 torch/headeronly/build.bzl create mode 100644 torch/headeronly/macros/BUILD.bazel create mode 100644 torch/headeronly/macros/build.bzl rename {c10 => torch/headeronly}/macros/cmake_configure_file.bzl (100%) rename {c10 => torch/headeronly}/macros/cmake_macros.h.in (80%) diff --git a/.bazelrc b/.bazelrc index 7581b52430211..fc2995dc838c5 100644 --- a/.bazelrc +++ b/.bazelrc @@ -2,7 +2,7 @@ build --cxxopt=--std=c++17 build --copt=-I. # Bazel does not support including its cc_library targets as system # headers. We work around this for generated code -# (e.g. c10/macros/cmake_macros.h) by making the generated directory a +# (e.g. torch/headeronly/macros/cmake_macros.h) by making the generated directory a # system include path. build --copt=-isystem --copt bazel-out/k8-fastbuild/bin build --copt=-isystem --copt bazel-out/darwin-fastbuild/bin diff --git a/CMakeLists.txt b/CMakeLists.txt index 99c0b9e0ea0c9..d1f8a13fb9fd3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1245,6 +1245,7 @@ if(USE_MIMALLOC AND USE_MIMALLOC_ON_MKL) endif() # ---[ Main build +add_subdirectory(torch/headeronly) # headeronly headers add_subdirectory(c10) add_subdirectory(caffe2) diff --git a/c10/BUCK.oss b/c10/BUCK.oss index 4b2cbd049b85f..4ec4ab5beabb4 100644 --- a/c10/BUCK.oss +++ b/c10/BUCK.oss @@ -37,8 +37,6 @@ cxx_library( ), exported_linker_flags = [], exported_preprocessor_flags = [ - '-DC10_USING_CUSTOM_GENERATED_MACROS', - '-DC10_USE_GLOG', '-DC10_USE_MINIMAL_GLOG', '-DC10_MOBILE', '-fexceptions', diff --git a/c10/CMakeLists.txt b/c10/CMakeLists.txt index 8e9d267352dd2..f82e460cafc31 100644 --- a/c10/CMakeLists.txt +++ b/c10/CMakeLists.txt @@ -18,16 +18,12 @@ else() set(C10_LIB c10) endif() - # ---[ Configure macro file. - set(C10_USE_GFLAGS ${USE_GFLAGS}) # used in cmake_macros.h.in - set(C10_USE_GLOG ${USE_GLOG}) # used in cmake_macros.h.in - set(C10_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}) # used in cmake_macros.h.in - set(C10_USE_NUMA ${USE_NUMA}) - set(C10_USE_MSVC_STATIC_RUNTIME ${CAFFE2_USE_MSVC_STATIC_RUNTIME}) - set(C10_USE_ROCM_KERNEL_ASSERT ${USE_ROCM_KERNEL_ASSERT}) - configure_file( - ${CMAKE_CURRENT_LIST_DIR}/macros/cmake_macros.h.in - ${CMAKE_BINARY_DIR}/c10/macros/cmake_macros.h) +set(C10_USE_GFLAGS ${USE_GFLAGS}) # also used in torch/headeronly +set(C10_USE_GLOG ${USE_GLOG}) # also used in torch/headeronly +set(C10_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}) # also used in torch/headeronly +set(C10_USE_NUMA ${USE_NUMA}) # also used in torch/headeronly +set(C10_USE_MSVC_STATIC_RUNTIME ${CAFFE2_USE_MSVC_STATIC_RUNTIME}) # also used in torch/headeronly +set(C10_USE_ROCM_KERNEL_ASSERT ${USE_ROCM_KERNEL_ASSERT}) # also used in torch/headeronly # Note: if you want to add ANY dependency to the c10 library, make sure you # check with the core PyTorch developers as the dependency will be @@ -94,6 +90,8 @@ if(NOT BUILD_LIBTORCHLESS) if(C10_USE_GLOG) target_link_libraries(c10 PUBLIC glog::glog) endif() + + target_link_libraries(c10 PUBLIC headeronly) target_link_libraries(c10 PRIVATE fmt::fmt-header-only) target_link_libraries(c10 PRIVATE nlohmann) target_link_libraries(c10 PRIVATE moodycamel) @@ -170,8 +168,6 @@ endif() install(DIRECTORY ${CMAKE_CURRENT_LIST_DIR} DESTINATION include FILES_MATCHING PATTERN "*.h") -install(FILES ${CMAKE_BINARY_DIR}/c10/macros/cmake_macros.h - DESTINATION include/c10/macros) if(MSVC AND C10_BUILD_SHARED_LIBS) install(FILES $ DESTINATION lib OPTIONAL) diff --git a/c10/macros/Export.h b/c10/macros/Export.h index b013910902b26..3d91266102613 100644 --- a/c10/macros/Export.h +++ b/c10/macros/Export.h @@ -2,7 +2,7 @@ #define C10_MACROS_EXPORT_H_ #ifndef C10_USING_CUSTOM_GENERATED_MACROS -#include +#include #endif // C10_USING_CUSTOM_GENERATED_MACROS #include diff --git a/c10/macros/Macros.h b/c10/macros/Macros.h index 6b51a39f2a943..55a79ee67430c 100644 --- a/c10/macros/Macros.h +++ b/c10/macros/Macros.h @@ -19,7 +19,7 @@ // file. #ifndef C10_USING_CUSTOM_GENERATED_MACROS -#include +#include #endif // C10_USING_CUSTOM_GENERATED_MACROS #include diff --git a/c10/macros/build.bzl b/c10/macros/build.bzl index 129b2b1e05702..d5809d36687d7 100644 --- a/c10/macros/build.bzl +++ b/c10/macros/build.bzl @@ -1,13 +1,13 @@ def define_targets(rules): rules.cc_library( name = "macros", - srcs = [":cmake_macros_h"], hdrs = [ "Macros.h", # Despite the documentation in Macros.h, Export.h is included # directly by many downstream files. Thus, we declare it as a # public header in this file. "Export.h", + "cmake_macros.h", ], linkstatic = True, local_defines = ["C10_BUILD_MAIN_LIB"], @@ -17,22 +17,6 @@ def define_targets(rules): ], ) - rules.cmake_configure_file( - name = "cmake_macros_h", - src = "cmake_macros.h.in", - out = "cmake_macros.h", - definitions = [ - "C10_BUILD_SHARED_LIBS", - "C10_USE_MSVC_STATIC_RUNTIME", - ] + rules.select({ - "//c10:using_gflags": ["C10_USE_GFLAGS"], - "//conditions:default": [], - }) + rules.select({ - "//c10:using_glog": ["C10_USE_GLOG"], - "//conditions:default": [], - }), - ) - rules.filegroup( name = "headers", srcs = rules.glob( diff --git a/c10/macros/cmake_macros.h b/c10/macros/cmake_macros.h new file mode 100644 index 0000000000000..4358f6906c972 --- /dev/null +++ b/c10/macros/cmake_macros.h @@ -0,0 +1,5 @@ +// This file exists for backwards compatibility and has been moved to +// torch/headeronly/macros/cmake_macros.h.in. No end user library should be +// including this file directly anyway (cuz they should be including +// Macros.h instead). +#include diff --git a/c10/ovrsource_defs.bzl b/c10/ovrsource_defs.bzl index 4abf8b0014dea..aafe5a4de8c42 100644 --- a/c10/ovrsource_defs.bzl +++ b/c10/ovrsource_defs.bzl @@ -73,8 +73,7 @@ def define_c10_ovrsource(name, is_mobile): ], }), exported_deps = [ - "//xplat/caffe2/torch/headeronly:torch_headeronly", - ":ovrsource_c10_cmake_macros.h", + "//xplat/caffe2/torch/headeronly:torch_headeronly_ovrsource", "//arvr/third-party/gflags:gflags", "//third-party/cpuinfo:cpuinfo", "//third-party/fmt:fmt", @@ -83,55 +82,6 @@ def define_c10_ovrsource(name, is_mobile): ) def define_ovrsource_targets(): - common_c10_cmake_defines = [ - ("#cmakedefine C10_BUILD_SHARED_LIBS", ""), - ("#cmakedefine C10_USE_NUMA", ""), - ("#cmakedefine C10_USE_MSVC_STATIC_RUNTIME", ""), - ("#cmakedefine C10_USE_ROCM_KERNEL_ASSERT", ""), - ] - - mobile_c10_cmake_defines = [ - ("#cmakedefine C10_USE_GLOG", ""), - ("#cmakedefine C10_USE_GFLAGS", ""), - ] - - non_mobile_c10_cmake_defines = [ - ("#cmakedefine C10_USE_GLOG", "#define C10_USE_GLOG 1"), - ("#cmakedefine C10_USE_GFLAGS", "#define C10_USE_GFLAGS 1"), - ] - - gen_cmake_header( - src = "macros/cmake_macros.h.in", - defines = common_c10_cmake_defines + mobile_c10_cmake_defines, - header = "c10/macros/cmake_macros.h", - prefix = "ovrsource_c10_mobile_", - ) - - gen_cmake_header( - src = "macros/cmake_macros.h.in", - defines = common_c10_cmake_defines + non_mobile_c10_cmake_defines, - header = "c10/macros/cmake_macros.h", - prefix = "ovrsource_c10_non_mobile_", - ) - - oxx_static_library( - name = "ovrsource_c10_cmake_macros.h", - compatible_with = [ - "ovr_config//os:android", - "ovr_config//os:iphoneos", - "ovr_config//os:linux", - "ovr_config//os:macos", - "ovr_config//os:windows", - ], - deps = select({ - "ovr_config//os:android": [":ovrsource_c10_mobile_cmake_macros.h"], - "ovr_config//os:iphoneos": [":ovrsource_c10_mobile_cmake_macros.h"], - "ovr_config//os:linux": [":ovrsource_c10_non_mobile_cmake_macros.h"], - "ovr_config//os:macos": [":ovrsource_c10_non_mobile_cmake_macros.h"], - "ovr_config//os:windows": [":ovrsource_c10_non_mobile_cmake_macros.h"], - }), - ) - c10_cuda_macros = gen_cmake_header( src = "cuda/impl/cuda_cmake_macros.h.in", defines = [ diff --git a/tools/bazel.bzl b/tools/bazel.bzl index cd263ba4d3241..9b662859adb46 100644 --- a/tools/bazel.bzl +++ b/tools/bazel.bzl @@ -2,7 +2,7 @@ load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library", "cc_test") load("@rules_cuda//cuda:defs.bzl", "cuda_library", "requires_cuda_enabled") load("@rules_python//python:defs.bzl", "py_binary", "py_library") load("@pip_deps//:requirements.bzl", "requirement") -load("@pytorch//c10/macros:cmake_configure_file.bzl", "cmake_configure_file") +load("@pytorch//torch/headeronly/macros:cmake_configure_file.bzl", "cmake_configure_file") load("@pytorch//tools/config:defs.bzl", "if_cuda") def _genrule(**kwds): diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index bc92f97b3956e..8d761068d1e62 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -29,7 +29,6 @@ endif() set(LIBSHM_SRCDIR ${TORCH_SRC_DIR}/lib/${LIBSHM_SUBDIR}) add_subdirectory(${LIBSHM_SRCDIR}) - # Generate files set(TOOLS_PATH "${TORCH_ROOT}/tools") diff --git a/torch/headeronly/BUCK.oss b/torch/headeronly/BUCK.oss new file mode 100644 index 0000000000000..2b8d77e597a68 --- /dev/null +++ b/torch/headeronly/BUCK.oss @@ -0,0 +1,26 @@ +load("//tools/build_defs:glob_defs.bzl", "subdir_glob") + +cxx_library( + name = "torch_headeronly", + header_namespace = "torch/headeronly", + exported_deps = [], + compiler_flags = [ + "-Werror", + "-Wno-global-constructors", + ], + exported_headers = subdir_glob( + [ + ("", "**/*.h"), + ], + ), + exported_linker_flags = [], + exported_preprocessor_flags = [ + '-DC10_USING_CUSTOM_GENERATED_MACROS', + '-DC10_USE_GLOG', + ], + link_whole = True, + platform_preprocessor_flags = [['windows', ['-D_WINDOWS', '-D_WIN32', '-DWIN32', '-DNOMINMAX', '-D_CRT_SECURE_NO_WARNINGS', '-D_USE_MATH_DEFINES']], ['windows.*64$', ['-D_WIN64']]], + preprocessor_flags = ['-DC10_BUILD_MAIN_LIB'], + reexport_all_header_dependencies = True, + visibility = ['PUBLIC'], +) diff --git a/torch/headeronly/BUILD.bazel b/torch/headeronly/BUILD.bazel index f4a27fac1f7f6..030651b120436 100644 --- a/torch/headeronly/BUILD.bazel +++ b/torch/headeronly/BUILD.bazel @@ -1,9 +1,5 @@ load("@rules_cc//cc:defs.bzl", "cc_library") +load("//:tools/bazel.bzl", "rules") +load(":build.bzl", "define_targets") -cc_library( - name = "torch_headeronly", - hdrs = glob([ - "**/*.h" - ]), - visibility = ["//visibility:public"], -) +define_targets(rules = rules) diff --git a/torch/headeronly/CMakeLists.txt b/torch/headeronly/CMakeLists.txt new file mode 100644 index 0000000000000..08ad713ca8452 --- /dev/null +++ b/torch/headeronly/CMakeLists.txt @@ -0,0 +1,34 @@ +cmake_minimum_required(VERSION 3.27 FATAL_ERROR) + +project(headeronly CXX) + +set(CMAKE_CXX_STANDARD 17 CACHE STRING "The C++ standard whose features are requested to build this target.") +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +# Main build file for torch/headeronly, except there's no build cuz this lib is header-only! + +# ---[ Configure macro file. +set(C10_USE_GFLAGS ${USE_GFLAGS}) # used in cmake_macros.h.in +set(C10_USE_GLOG ${USE_GLOG}) # used in cmake_macros.h.in +set(C10_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}) # used in cmake_macros.h.in +set(C10_USE_NUMA ${USE_NUMA}) # used in cmake_macros.h.in +set(C10_USE_MSVC_STATIC_RUNTIME ${CAFFE2_USE_MSVC_STATIC_RUNTIME}) # used in cmake_macros.h.in +set(C10_USE_ROCM_KERNEL_ASSERT ${USE_ROCM_KERNEL_ASSERT}) # used in cmake_macros.h.in +configure_file( + ${CMAKE_CURRENT_LIST_DIR}/macros/cmake_macros.h.in + ${CMAKE_BINARY_DIR}/torch/headeronly/macros/cmake_macros.h) + +file(GLOB HEADERONLY_HEADERS + *.h + macros/*.h +) + +add_library(headeronly INTERFACE ${HEADERONLY_HEADERS}) + +install(FILES ${CMAKE_BINARY_DIR}/torch/headeronly/macros/cmake_macros.h + DESTINATION include/torch/headeronly/macros) + +if(NOT BUILD_LIBTORCHLESS) + # ---[ Installation copied from c10/CMakeLists.txt + install(TARGETS headeronly EXPORT Caffe2Targets DESTINATION lib) +endif() diff --git a/torch/headeronly/build.bzl b/torch/headeronly/build.bzl new file mode 100644 index 0000000000000..6ec9a843e8848 --- /dev/null +++ b/torch/headeronly/build.bzl @@ -0,0 +1,11 @@ +def define_targets(rules): + rules.cc_library( + name = "torch_headeronly", + hdrs = rules.glob([ + "**/*.h" + ]), + visibility = ["//visibility:public"], + deps = [ + "//torch/headeronly/macros", + ], + ) diff --git a/torch/headeronly/macros/BUILD.bazel b/torch/headeronly/macros/BUILD.bazel new file mode 100644 index 0000000000000..d1a0db360d230 --- /dev/null +++ b/torch/headeronly/macros/BUILD.bazel @@ -0,0 +1,4 @@ +load("//:tools/bazel.bzl", "rules") +load(":build.bzl", "define_targets") + +define_targets(rules = rules) diff --git a/torch/headeronly/macros/build.bzl b/torch/headeronly/macros/build.bzl new file mode 100644 index 0000000000000..5217c2f7d37d6 --- /dev/null +++ b/torch/headeronly/macros/build.bzl @@ -0,0 +1,28 @@ +def define_targets(rules): + rules.cc_library( + name = "macros", + srcs = [":cmake_macros_h"], + hdrs = [ + # Following the example from c10 + "Export.h", + ], + linkstatic = True, + local_defines = ["C10_BUILD_MAIN_LIB"], + visibility = ["//visibility:public"], + ) + + rules.cmake_configure_file( + name = "cmake_macros_h", + src = "cmake_macros.h.in", + out = "cmake_macros.h", + definitions = [ + "C10_BUILD_SHARED_LIBS", + "C10_USE_MSVC_STATIC_RUNTIME", + ] + rules.select({ + "//c10:using_gflags": ["C10_USE_GFLAGS"], + "//conditions:default": [], + }) + rules.select({ + "//c10:using_glog": ["C10_USE_GLOG"], + "//conditions:default": [], + }), + ) diff --git a/c10/macros/cmake_configure_file.bzl b/torch/headeronly/macros/cmake_configure_file.bzl similarity index 100% rename from c10/macros/cmake_configure_file.bzl rename to torch/headeronly/macros/cmake_configure_file.bzl diff --git a/c10/macros/cmake_macros.h.in b/torch/headeronly/macros/cmake_macros.h.in similarity index 80% rename from c10/macros/cmake_macros.h.in rename to torch/headeronly/macros/cmake_macros.h.in index 76c185b55236c..e624221202dfe 100644 --- a/c10/macros/cmake_macros.h.in +++ b/torch/headeronly/macros/cmake_macros.h.in @@ -2,7 +2,7 @@ #define C10_MACROS_CMAKE_MACROS_H_ // Automatically generated header file for the C10 library. -// Do not include this file directly. Instead, include c10/macros/Macros.h. +// Do not include this file directly. Instead, include torch/headeronly/macros/Macros.h. #cmakedefine C10_BUILD_SHARED_LIBS #cmakedefine C10_USE_GLOG diff --git a/torch/headeronly/ovrsource_defs.bzl b/torch/headeronly/ovrsource_defs.bzl index 55e1947b5e76c..5ba9b593c2974 100644 --- a/torch/headeronly/ovrsource_defs.bzl +++ b/torch/headeronly/ovrsource_defs.bzl @@ -1,3 +1,4 @@ +load("//arvr/tools/build_defs:genrule_utils.bzl", "gen_cmake_header") load("//arvr/tools/build_defs:oxx.bzl", "oxx_static_library") cpu_supported_platforms = [ @@ -18,29 +19,77 @@ def define_torch_headeronly_ovrsource(name, is_mobile): oxx_static_library( name = name, - srcs = [] + srcs = [], compatible_with = cpu_supported_platforms, compiler_flags = select({ "DEFAULT": [], }), - include_directories = [".."], - preprocessor_flags = [], + preprocessor_flags = ["-DC10_BUILD_MAIN_LIB=1",], fbobjc_compiler_flags = [], - public_include_directories = [".."], + public_include_directories = ["../.."], public_preprocessor_flags = pp_flags, public_raw_headers = native.glob([ "macros/*.h", ]), reexport_all_header_dependencies = False, visibility = [ - "//xplat/caffe2/torch/headeronly:torch_headeronly", + "//xplat/caffe2/torch/headeronly:torch_headeronly_ovrsource", + ], + exported_deps = [ + ":ovrsource_torch_headeronly_cmake_macros.h", + ], + ) + +def define_ovrsource_targets(): + common_c10_cmake_defines = [ + ("#cmakedefine C10_BUILD_SHARED_LIBS", ""), + ("#cmakedefine C10_USE_NUMA", ""), + ("#cmakedefine C10_USE_MSVC_STATIC_RUNTIME", ""), + ("#cmakedefine C10_USE_ROCM_KERNEL_ASSERT", ""), + ] + + mobile_c10_cmake_defines = [ + ("#cmakedefine C10_USE_GLOG", ""), + ("#cmakedefine C10_USE_GFLAGS", ""), + ] + + non_mobile_c10_cmake_defines = [ + ("#cmakedefine C10_USE_GLOG", "#define C10_USE_GLOG 1"), + ("#cmakedefine C10_USE_GFLAGS", "#define C10_USE_GFLAGS 1"), + ] + + gen_cmake_header( + src = "macros/cmake_macros.h.in", + defines = common_c10_cmake_defines + mobile_c10_cmake_defines, + header = "torch/headeronly/macros/cmake_macros.h", + prefix = "ovrsource_torch_headeronly_mobile_", + ) + + gen_cmake_header( + src = "macros/cmake_macros.h.in", + defines = common_c10_cmake_defines + non_mobile_c10_cmake_defines, + header = "torch/headeronly/macros/cmake_macros.h", + prefix = "ovrsource_torch_headeronly_non_mobile_", + ) + + oxx_static_library( + name = "ovrsource_torch_headeronly_cmake_macros.h", + compatible_with = [ + "ovr_config//os:android", + "ovr_config//os:iphoneos", + "ovr_config//os:linux", + "ovr_config//os:macos", + "ovr_config//os:windows", ], deps = select({ - "DEFAULT": [], + "ovr_config//os:android": [":ovrsource_torch_headeronly_mobile_cmake_macros.h"], + "ovr_config//os:iphoneos": [":ovrsource_torch_headeronly_mobile_cmake_macros.h"], + "ovr_config//os:linux": [":ovrsource_torch_headeronly_non_mobile_cmake_macros.h"], + "ovr_config//os:macos": [":ovrsource_torch_headeronly_non_mobile_cmake_macros.h"], + "ovr_config//os:windows": [":ovrsource_torch_headeronly_non_mobile_cmake_macros.h"], }), ) -def define_ovrsource_targets(): oxx_static_library( name = "torch_headeronly_ovrsource", compatible_with = cpu_supported_platforms, From b86d5cef68d56f3924dc199424e65904a32d0743 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Mon, 14 Jul 2025 11:35:27 -0700 Subject: [PATCH 070/457] [dynamo][tensor] Skip HASATTR attribute on tensor guards (#158215) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158215 Approved by: https://github.com/StrongerXi --- torch/_dynamo/variables/tensor.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index 79471bda50ae9..62d0542dcab04 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -316,17 +316,18 @@ def dynamic_getattr(self, tx: "InstructionTranslator", name): real_value = getattr(_input_associated_real_value, name) attr_source = AttrSource(self.source, name) - install_guard(attr_source.make_guard(GuardBuilder.HASATTR)) # Typically we'd want to use variable builder here # but unfortunately id(real_value.__self__) is not id() if is_bound_tensor_method(real_value): + # No need to install the guard because its a bound tensor method from .misc import GetAttrVariable return GetAttrVariable( self, name, source=attr_source, py_type=type(real_value) ) + install_guard(attr_source.make_guard(GuardBuilder.HASATTR)) return VariableTracker.build(tx, real_value, attr_source) def method_attr_ndim(self, tx): From dbf7d421dabced2335d17c7d7e57c1770f2f12c0 Mon Sep 17 00:00:00 2001 From: Sam Larsen Date: Mon, 14 Jul 2025 13:06:31 -0700 Subject: [PATCH 071/457] [BE][testing] fix aot_inductor_package internally (#158270) Summary: We have internal test failure for several aot_inductor_package tests. It looks like we're translating args like: ``` -Wl,--script=/home/slarsen/local/fbsource2/buck-out/v2/gen/fbcode/7ce8f48f92bc4ee6/caffe2/test/inductor/__aot_inductor_package__/aot_inductor_package#link-tree/torch/_inductor/script.ld ``` To: ``` -Wl,--script=/home/slarsen/local/fbsource2/buck-out/v2/gen/fbcode/7ce8f48f92bc4ee6/caffe2/test/inductor/__aot_inductor_package__/aot_inductor_package#link-tree/torch/_inductor//tmp/jZMktZ/tmpsqoxb_cq/data/aotinductor/model/script.ld ``` This PR changes to strings like: ``` -Wl,--script=/tmp/jZMktZ/tmpsqoxb_cq/data/aotinductor/model/script.ld ``` Test Plan: `buck test '@fbcode//mode/opt' fbcode//caffe2/test/inductor:aot_inductor_package --run-disabled` Pull Request resolved: https://github.com/pytorch/pytorch/pull/158270 Approved by: https://github.com/desertfire --- .../inductor/aoti_package/model_package_loader.cpp | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/torch/csrc/inductor/aoti_package/model_package_loader.cpp b/torch/csrc/inductor/aoti_package/model_package_loader.cpp index dacdc9eac3882..cca993406f7f3 100644 --- a/torch/csrc/inductor/aoti_package/model_package_loader.cpp +++ b/torch/csrc/inductor/aoti_package/model_package_loader.cpp @@ -136,15 +136,12 @@ std::tuple get_cpp_compile_command( } std::string passthrough_parameters_args; + std::regex script_regex(R"(--script=[^,]*script\.ld)"); + std::string replacement = + "--script=" + target_dir + k_separator + "script.ld"; for (auto& arg : compile_options["passthrough_args"]) { - std::string arg_str = arg.get(); - std::string target = "script.ld"; - std::string replacement = target_dir; - replacement.append(k_separator).append(target); - size_t pos = arg_str.find(target); - if (pos != std::string::npos) { - arg_str.replace(pos, target.length(), replacement); - } + std::string arg_str = + std::regex_replace(arg.get(), script_regex, replacement); passthrough_parameters_args += arg_str + " "; } From 19625daf889f0a6192a76e200205817e3ee27f26 Mon Sep 17 00:00:00 2001 From: Yiming Zhou Date: Tue, 15 Jul 2025 20:16:50 +0000 Subject: [PATCH 072/457] [1/n] Remove references to TorchScript in PyTorch docs (#158305) Summary: Removed jit_language_reference_v2.md Test Plan: CI Rollback Plan: Differential Revision: D78308009 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158305 Approved by: https://github.com/jingsh, https://github.com/svekars --- docs/source/jit_language_reference_v2.md | 1831 +--------------------- 1 file changed, 4 insertions(+), 1827 deletions(-) diff --git a/docs/source/jit_language_reference_v2.md b/docs/source/jit_language_reference_v2.md index 12bd2a18a201c..40da0740963ba 100644 --- a/docs/source/jit_language_reference_v2.md +++ b/docs/source/jit_language_reference_v2.md @@ -25,1830 +25,7 @@ # TorchScript Language Reference -This reference manual describes the syntax and core semantics of the TorchScript language. -TorchScript is a statically typed subset of the Python language. This document explains the supported features of -Python in TorchScript and also how the language diverges from regular Python. Any features of Python that are not mentioned in -this reference manual are not part of TorchScript. TorchScript focuses specifically on the features of Python that are needed to -represent neural network models in PyTorch. - -```{contents} -:depth: 1 -:local: true -``` - -(type-system)= - -## Terminology - -This document uses the following terminologies: - -```{eval-rst} -.. list-table:: - :widths: 25 25 - :header-rows: 1 - - * - Pattern - - Notes - * - ``::=`` - - Indicates that the given symbol is defined as. - * - ``" "`` - - Represents real keywords and delimiters that are part of the syntax. - * - ``A | B`` - - Indicates either A or B. - * - ``( )`` - - Indicates grouping. - * - ``[]`` - - Indicates optional. - * - ``A+`` - - Indicates a regular expression where term A is repeated at least once. - * - ``A*`` - - Indicates a regular expression where term A is repeated zero or more times. -``` - -## Type System - -TorchScript is a statically typed subset of Python. The largest difference between TorchScript and the full Python language is that TorchScript only supports a small set of types that are needed to express -neural net models. - -### TorchScript Types - -The TorchScript type system consists of `TSType` and `TSModuleType` as defined below. - -``` -TSAllType ::= TSType | TSModuleType -TSType ::= TSMetaType | TSPrimitiveType | TSStructuralType | TSNominalType -``` - -`TSType` represents the majority of TorchScript types that are composable and that can be used in TorchScript type annotations. -`TSType` refers to any of the following: - -- Meta Types, e.g., `Any` -- Primitive Types, e.g., `int`, `float`, and `str` -- Structural Types, e.g., `Optional[int]` or `List[MyClass]` -- Nominal Types (Python classes), e.g., `MyClass` (user-defined), `torch.tensor` (built-in) - -`TSModuleType` represents `torch.nn.Module` and its subclasses. It is treated differently from `TSType` because its type schema is inferred partly from the object instance and partly from the class definition. -As such, instances of a `TSModuleType` may not follow the same static type schema. `TSModuleType` cannot be used as a TorchScript type annotation or be composed with `TSType` for type safety considerations. - -### Meta Types - -Meta types are so abstract that they are more like type constraints than concrete types. -Currently TorchScript defines one meta-type, `Any`, that represents any TorchScript type. - -#### `Any` Type - -The `Any` type represents any TorchScript type. `Any` specifies no type constraints, thus there is no type-checking on `Any`. -As such it can be bound to any Python or TorchScript data types (e.g., `int`, TorchScript `tuple`, or an arbitrary Python class that is not scripted). - -``` -TSMetaType ::= "Any" -``` - -Where: - -- `Any` is the Python class name from the typing module. Therefore, to use the `Any` type, you must import it from `typing` (e.g., `from typing import Any`). -- Since `Any` can represent any TorchScript type, the set of operators that are allowed to operate on values of this type on `Any` is limited. - -#### Operators Supported for `Any` Type - -- Assignment to data of `Any` type. -- Binding to parameter or return of `Any` type. -- `x is`, `x is not` where `x` is of `Any` type. -- `isinstance(x, Type)` where `x` is of `Any` type. -- Data of `Any` type is printable. -- Data of `List[Any]` type may be sortable if the data is a list of values of the same type `T` and that `T` supports comparison operators. - -**Compared to Python** - -`Any` is the least constrained type in the TorchScript type system. In that sense, it is quite similar to the -`Object` class in Python. However, `Any` only supports a subset of the operators and methods that are supported by `Object`. - -#### Design Notes - -When we script a PyTorch module, we may encounter data that is not involved in the execution of the script. Nevertheless, it has to be described -by a type schema. It is not only cumbersome to describe static types for unused data (in the context of the script), but also may lead to unnecessary -scripting failures. `Any` is introduced to describe the type of the data where precise static types are not necessary for compilation. - -**Example 1** - -This example illustrates how `Any` can be used to allow the second element of the tuple parameter to be of any type. This is possible -because `x[1]` is not involved in any computation that requires knowing its precise type. - -```{eval-rst} -.. testcode:: - - import torch - - from typing import Tuple - from typing import Any - - @torch.jit.export - def inc_first_element(x: Tuple[int, Any]): - return (x[0]+1, x[1]) - - m = torch.jit.script(inc_first_element) - print(m((1,2.0))) - print(m((1,(100,200)))) -``` - -The example above produces the following output: - -```{eval-rst} -.. testoutput:: - - (2, 2.0) - (2, (100, 200)) -``` - -The second element of the tuple is of `Any` type, thus can bind to multiple types. -For example, `(1, 2.0)` binds a float type to `Any` as in `Tuple[int, Any]`, -whereas `(1, (100, 200))` binds a tuple to `Any` in the second invocation. - -**Example 2** - -This example illustrates how we can use `isinstance` to dynamically check the type of the data that is annotated as `Any` type: - -```{eval-rst} -.. testcode:: - - import torch - from typing import Any - - def f(a:Any): - print(a) - return (isinstance(a, torch.Tensor)) - - ones = torch.ones([2]) - m = torch.jit.script(f) - print(m(ones)) -``` - -The example above produces the following output: - -```{eval-rst} -.. testoutput:: - - 1 - 1 - [ CPUFloatType{2} ] - True -``` - -### Primitive Types - -Primitive TorchScript types are types that represent a single type of value and go with a single pre-defined -type name. - -``` -TSPrimitiveType ::= "int" | "float" | "double" | "complex" | "bool" | "str" | "None" -``` - -### Structural Types - -Structural types are types that are structurally defined without a user-defined name (unlike nominal types), -such as `Future[int]`. Structural types are composable with any `TSType`. - -``` -TSStructuralType ::= TSTuple | TSNamedTuple | TSList | TSDict | - TSOptional | TSUnion | TSFuture | TSRRef | TSAwait - -TSTuple ::= "Tuple" "[" (TSType ",")* TSType "]" -TSNamedTuple ::= "namedtuple" "(" (TSType ",")* TSType ")" -TSList ::= "List" "[" TSType "]" -TSOptional ::= "Optional" "[" TSType "]" -TSUnion ::= "Union" "[" (TSType ",")* TSType "]" -TSFuture ::= "Future" "[" TSType "]" -TSRRef ::= "RRef" "[" TSType "]" -TSAwait ::= "Await" "[" TSType "]" -TSDict ::= "Dict" "[" KeyType "," TSType "]" -KeyType ::= "str" | "int" | "float" | "bool" | TensorType | "Any" -``` - -Where: - -- `Tuple`, `List`, `Optional`, `Union`, `Future`, `Dict` represent Python type class names that are defined in the module `typing`. To use these type names, you must import them from `typing` (e.g., `from typing import Tuple`). -- `namedtuple` represents the Python class `collections.namedtuple` or `typing.NamedTuple`. -- `Future` and `RRef` represent the Python classes `torch.futures` and `torch.distributed.rpc`. -- `Await` represent the Python class `torch._awaits._Await` - -**Compared to Python** - -Apart from being composable with TorchScript types, these TorchScript structural types often support a common subset of the operators and methods of their Python counterparts. - -**Example 1** - -This example uses `typing.NamedTuple` syntax to define a tuple: - -```{eval-rst} -.. testcode:: - - import torch - from typing import NamedTuple - from typing import Tuple - - class MyTuple(NamedTuple): - first: int - second: int - - def inc(x: MyTuple) -> Tuple[int, int]: - return (x.first+1, x.second+1) - - t = MyTuple(first=1, second=2) - scripted_inc = torch.jit.script(inc) - print("TorchScript:", scripted_inc(t)) -``` - -The example above produces the following output: - -```{eval-rst} -.. testoutput:: - - TorchScript: (2, 3) -``` - -**Example 2** - -This example uses `collections.namedtuple` syntax to define a tuple: - -```{eval-rst} -.. testcode:: - - import torch - from typing import NamedTuple - from typing import Tuple - from collections import namedtuple - - _AnnotatedNamedTuple = NamedTuple('_NamedTupleAnnotated', [('first', int), ('second', int)]) - _UnannotatedNamedTuple = namedtuple('_NamedTupleAnnotated', ['first', 'second']) - - def inc(x: _AnnotatedNamedTuple) -> Tuple[int, int]: - return (x.first+1, x.second+1) - - m = torch.jit.script(inc) - print(inc(_UnannotatedNamedTuple(1,2))) -``` - -The example above produces the following output: - -```{eval-rst} -.. testoutput:: - - (2, 3) -``` - -**Example 3** - -This example illustrates a common mistake of annotating structural types, i.e., not importing the composite type -classes from the `typing` module: - -```python -import torch - -# ERROR: Tuple not recognized because not imported from typing -@torch.jit.export -def inc(x: Tuple[int, int]): - return (x[0]+1, x[1]+1) - -m = torch.jit.script(inc) -print(m((1,2))) -``` - -Running the above code yields the following scripting error: - -```python -File "test-tuple.py", line 5, in - def inc(x: Tuple[int, int]): -NameError: name 'Tuple' is not defined -``` - -The remedy is to add the line `from typing import Tuple` to the beginning of the code. - -### Nominal Types - -Nominal TorchScript types are Python classes. These types are called nominal because they are declared with a custom -name and are compared using class names. Nominal classes are further classified into the following categories: - -``` -TSNominalType ::= TSBuiltinClasses | TSCustomClass | TSEnum -``` - -Among them, `TSCustomClass` and `TSEnum` must be compilable to TorchScript Intermediate Representation (IR). This is enforced by the type-checker. - -### Built-in Class - -Built-in nominal types are Python classes whose semantics are built into the TorchScript system (e.g., tensor types). -TorchScript defines the semantics of these built-in nominal types, and often supports only a subset of the methods or -attributes of its Python class definition. - -``` -TSBuiltinClass ::= TSTensor | "torch.device" | "torch.Stream" | "torch.dtype" | - "torch.nn.ModuleList" | "torch.nn.ModuleDict" | ... -TSTensor ::= "torch.Tensor" | "common.SubTensor" | "common.SubWithTorchFunction" | - "torch.nn.parameter.Parameter" | and subclasses of torch.Tensor -``` - -#### Special Note on torch.nn.ModuleList and torch.nn.ModuleDict - -Although `torch.nn.ModuleList` and `torch.nn.ModuleDict` are defined as a list and dictionary in Python, -they behave more like tuples in TorchScript: - -- In TorchScript, instances of `torch.nn.ModuleList` or `torch.nn.ModuleDict` are immutable. -- Code that iterates over `torch.nn.ModuleList` or `torch.nn.ModuleDict` is completely unrolled so that elements of `torch.nn.ModuleList` or keys of `torch.nn.ModuleDict` can be of different subclasses of `torch.nn.Module`. - -**Example** - -The following example highlights the use of a few built-in Torchscript classes (`torch.*`): - -```python -import torch - -@torch.jit.script -class A: - def __init__(self): - self.x = torch.rand(3) - - def f(self, y: torch.device): - return self.x.to(device=y) - -def g(): - a = A() - return a.f(torch.device("cpu")) - -script_g = torch.jit.script(g) -print(script_g.graph) -``` - -### Custom Class - -Unlike built-in classes, semantics of custom classes are user-defined and the entire class definition must be compilable to TorchScript IR and subject to TorchScript type-checking rules. - -``` -TSClassDef ::= [ "@torch.jit.script" ] - "class" ClassName [ "(object)" ] ":" - MethodDefinition | - [ "@torch.jit.ignore" ] | [ "@torch.jit.unused" ] - MethodDefinition -``` - -Where: - -- Classes must be new-style classes. Python 3 supports only new-style classes. In Python 2.x, a new-style class is specified by subclassing from the object. -- Instance data attributes are statically typed, and instance attributes must be declared by assignments inside the `__init__()` method. -- Method overloading is not supported (i.e., you cannot have multiple methods with the same method name). -- `MethodDefinition` must be compilable to TorchScript IR and adhere to TorchScript’s type-checking rules, (i.e., all methods must be valid TorchScript functions and class attribute definitions must be valid TorchScript statements). -- `torch.jit.ignore` and `torch.jit.unused` can be used to ignore the method or function that is not fully torchscriptable or should be ignored by the compiler. - -**Compared to Python** - -TorchScript custom classes are quite limited compared to their Python counterpart. Torchscript custom classes: - -- Do not support class attributes. -- Do not support subclassing except for subclassing an interface type or object. -- Do not support method overloading. -- Must initialize all its instance attributes in `__init__()`; this is because TorchScript constructs a static schema of the class by inferring attribute types in `__init__()`. -- Must contain only methods that satisfy TorchScript type-checking rules and are compilable to TorchScript IRs. - -**Example 1** - -Python classes can be used in TorchScript if they are annotated with `@torch.jit.script`, similar to how a TorchScript function would be declared: - -```python -@torch.jit.script -class MyClass: - def __init__(self, x: int): - self.x = x - - def inc(self, val: int): - self.x += val -``` - -**Example 2** - -A TorchScript custom class type must "declare" all its instance attributes by assignments in `__init__()`. If an instance attribute is not defined in `__init__()` but accessed in other methods of the class, the class cannot be compiled as a TorchScript class, as shown in the following example: - -```python -import torch - -@torch.jit.script -class foo: - def __init__(self): - self.y = 1 - -# ERROR: self.x is not defined in __init__ -def assign_x(self): - self.x = torch.rand(2, 3) -``` - -The class will fail to compile and issue the following error: - -``` -RuntimeError: -Tried to set nonexistent attribute: x. Did you forget to initialize it in __init__()?: -def assign_x(self): - self.x = torch.rand(2, 3) - ~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE -``` - -**Example 3** - -In this example, a TorchScript custom class defines a class variable name, which is not allowed: - -```python -import torch - -@torch.jit.script -class MyClass(object): - name = "MyClass" - def __init__(self, x: int): - self.x = x - -def fn(a: MyClass): - return a.name -``` - -It leads to the following compile-time error: - -``` -RuntimeError: -'__torch__.MyClass' object has no attribute or method 'name'. Did you forget to initialize an attribute in __init__()?: - File "test-class2.py", line 10 -def fn(a: MyClass): - return a.name - ~~~~~~ <--- HERE -``` - -### Enum Type - -Like custom classes, semantics of the enum type are user-defined and the entire class definition must be compilable to TorchScript IR and adhere to TorchScript type-checking rules. - -``` -TSEnumDef ::= "class" Identifier "(enum.Enum | TSEnumType)" ":" - ( MemberIdentifier "=" Value )+ - ( MethodDefinition )* -``` - -Where: - -- Value must be a TorchScript literal of type `int`, `float`, or `str`, and must be of the same TorchScript type. -- `TSEnumType` is the name of a TorchScript enumerated type. Similar to Python enum, TorchScript allows restricted `Enum` subclassing, that is, subclassing an enumerated is allowed only if it does not define any members. - -**Compared to Python** - -- TorchScript supports only `enum.Enum`. It does not support other variations such as `enum.IntEnum`, `enum.Flag`, `enum.IntFlag`, and `enum.auto`. -- Values of TorchScript enum members must be of the same type and can only be `int`, `float`, or `str` types, whereas Python enum members can be of any type. -- Enums containing methods are ignored in TorchScript. - -**Example 1** - -The following example defines the class `Color` as an `Enum` type: - -```python -import torch -from enum import Enum - -class Color(Enum): - RED = 1 - GREEN = 2 - -def enum_fn(x: Color, y: Color) -> bool: - if x == Color.RED: - return True - return x == y - -m = torch.jit.script(enum_fn) - -print("Eager: ", enum_fn(Color.RED, Color.GREEN)) -print("TorchScript: ", m(Color.RED, Color.GREEN)) -``` - -**Example 2** - -The following example shows the case of restricted enum subclassing, where `BaseColor` does not define any member, thus can be subclassed by `Color`: - -```python -import torch -from enum import Enum - -class BaseColor(Enum): - def foo(self): - pass - -class Color(BaseColor): - RED = 1 - GREEN = 2 - -def enum_fn(x: Color, y: Color) -> bool: - if x == Color.RED: - return True - return x == y - -m = torch.jit.script(enum_fn) - -print("TorchScript: ", m(Color.RED, Color.GREEN)) -print("Eager: ", enum_fn(Color.RED, Color.GREEN)) -``` - -### TorchScript Module Class - -`TSModuleType` is a special class type that is inferred from object instances that are created outside TorchScript. `TSModuleType` is named by the Python class of the object instance. The `__init__()` method of the Python class is not considered a TorchScript method, so it does not have to comply with TorchScript’s type-checking rules. - -The type schema of a module instance class is constructed directly from an instance object (created outside the scope of TorchScript) rather than inferred from `__init__()` like custom classes. It is possible that two objects of the same instance class type follow two different type schemas. - -In this sense, `TSModuleType` is not really a static type. Therefore, for type safety considerations, `TSModuleType` cannot be used in a TorchScript type annotation or be composed with `TSType`. - -### Module Instance Class - -TorchScript module type represents the type schema of a user-defined PyTorch module instance. When scripting a PyTorch module, the module object is always created outside TorchScript (i.e., passed in as parameter to `forward`). The Python module class is treated as a module instance class, so the `__init__()` method of the Python module class is not subject to the type-checking rules of TorchScript. - -``` -TSModuleType ::= "class" Identifier "(torch.nn.Module)" ":" - ClassBodyDefinition -``` - -Where: - -- `forward()` and other methods decorated with `@torch.jit.export` must be compilable to TorchScript IR and subject to TorchScript’s type-checking rules. - -Unlike custom classes, only the forward method and other methods decorated with `@torch.jit.export` of the module type need to be compilable. Most notably, `__init__()` is not considered a TorchScript method. Consequently, module type constructors cannot be invoked within the scope of TorchScript. Instead, TorchScript module objects are always constructed outside and passed into `torch.jit.script(ModuleObj)`. - -**Example 1** - -This example illustrates a few features of module types: - -- The `TestModule` instance is created outside the scope of TorchScript (i.e., before invoking `torch.jit.script`). -- `__init__()` is not considered a TorchScript method, therefore, it does not have to be annotated and can contain arbitrary Python code. In addition, the `__init__()` method of an instance class cannot be invoked in TorchScript code. Because `TestModule` instances are instantiated in Python, in this example, `TestModule(2.0)` and `TestModule(2)` create two instances with different types for its data attributes. `self.x` is of type `float` for `TestModule(2.0)`, whereas `self.y` is of type `int` for `TestModule(2.0)`. -- TorchScript automatically compiles other methods (e.g., `mul()`) invoked by methods annotated via `@torch.jit.export` or `forward()` methods. -- Entry-points to a TorchScript program are either `forward()` of a module type, functions annotated as `torch.jit.script`, or methods annotated as `torch.jit.export`. - -```{eval-rst} -.. testcode:: - - import torch - - class TestModule(torch.nn.Module): - def __init__(self, v): - super().__init__() - self.x = v - - def forward(self, inc: int): - return self.x + inc - - m = torch.jit.script(TestModule(1)) - print(f"First instance: {m(3)}") - - m = torch.jit.script(TestModule(torch.ones([5]))) - print(f"Second instance: {m(3)}") -``` - -The example above produces the following output: - -```{eval-rst} -.. testoutput:: - - First instance: 4 - Second instance: tensor([4., 4., 4., 4., 4.]) -``` - -**Example 2** - -The following example shows an incorrect usage of module type. Specifically, this example invokes the constructor of `TestModule` inside the scope of TorchScript: - -```{eval-rst} -.. testcode:: - - import torch - - class TestModule(torch.nn.Module): - def __init__(self, v): - super().__init__() - self.x = v - - def forward(self, x: int): - return self.x + x - - class MyModel: - def __init__(self, v: int): - self.val = v - - @torch.jit.export - def doSomething(self, val: int) -> int: - # error: should not invoke the constructor of module type - myModel = TestModule(self.val) - return myModel(val) - - # m = torch.jit.script(MyModel(2)) # Results in below RuntimeError - # RuntimeError: Could not get name of python class object -``` - -(type-annotation)= - -## Type Annotation - -Since TorchScript is statically typed, programmers need to annotate types at *strategic points* of TorchScript code so that every local variable or -instance data attribute has a static type, and every function and method has a statically typed signature. - -### When to Annotate Types - -In general, type annotations are only needed in places where static types cannot be automatically inferred (e.g., parameters or sometimes return types to -methods or functions). Types of local variables and data attributes are often automatically inferred from their assignment statements. Sometimes an inferred type -may be too restrictive, e.g., `x` being inferred as `NoneType` through assignment `x = None`, whereas `x` is actually used as an `Optional`. In such -cases, type annotations may be needed to overwrite auto inference, e.g., `x: Optional[int] = None`. Note that it is always safe to type annotate a local variable -or data attribute even if its type can be automatically inferred. The annotated type must be congruent with TorchScript’s type-checking. - -When a parameter, local variable, or data attribute is not type annotated and its type cannot be automatically inferred, TorchScript assumes it to be a -default type of `TensorType`, `List[TensorType]`, or `Dict[str, TensorType]`. - -### Annotate Function Signature - -Since a parameter may not be automatically inferred from the body of the function (including both functions and methods), they need to be type annotated. Otherwise, they assume the default type `TensorType`. - -TorchScript supports two styles for method and function signature type annotation: - -- **Python3-style** annotates types directly on the signature. As such, it allows individual parameters to be left unannotated (whose type will be the default type of `TensorType`), or allows the return type to be left unannotated (whose type will be automatically inferred). - -``` -Python3Annotation ::= "def" Identifier [ "(" ParamAnnot* ")" ] [ReturnAnnot] ":" - FuncOrMethodBody -ParamAnnot ::= Identifier [ ":" TSType ] "," -ReturnAnnot ::= "->" TSType -``` - -Note that when using Python3 style, the type `self` is automatically inferred and should not be annotated. - -- **Mypy style** annotates types as a comment right below the function/method declaration. In the Mypy style, since parameter names do not appear in the annotation, all parameters have to be annotated. - -``` -MyPyAnnotation ::= "# type:" "(" ParamAnnot* ")" [ ReturnAnnot ] -ParamAnnot ::= TSType "," -ReturnAnnot ::= "->" TSType -``` - -**Example 1** - -In this example: - -- `a` is not annotated and assumes the default type of `TensorType`. -- `b` is annotated as type `int`. -- The return type is not annotated and is automatically inferred as type `TensorType` (based on the type of the value being returned). - -```python -import torch - -def f(a, b: int): - return a+b - -m = torch.jit.script(f) -print("TorchScript:", m(torch.ones([6]), 100)) -``` - -**Example 2** - -The following example uses Mypy style annotation. Note that parameters or return values must be annotated even if some of -them assume the default type. - -```python -import torch - -def f(a, b): - # type: (torch.Tensor, int) → torch.Tensor - return a+b - -m = torch.jit.script(f) -print("TorchScript:", m(torch.ones([6]), 100)) -``` - -### Annotate Variables and Data Attributes - -In general, types of data attributes (including class and instance data attributes) and local variables can be automatically inferred from assignment statements. -Sometimes, however, if a variable or attribute is associated with values of different types (e.g., as `None` or `TensorType`), then they may need to be explicitly -type annotated as a *wider* type such as `Optional[int]` or `Any`. - -#### Local Variables - -Local variables can be annotated according to Python3 typing module annotation rules, i.e., - -``` -LocalVarAnnotation ::= Identifier [":" TSType] "=" Expr -``` - -In general, types of local variables can be automatically inferred. In some cases, however, you may need to annotate a multi-type for local variables -that may be associated with different concrete types. Typical multi-types include `Optional[T]` and `Any`. - -**Example** - -```python -import torch - -def f(a, setVal: bool): - value: Optional[torch.Tensor] = None - if setVal: - value = a - return value - -ones = torch.ones([6]) -m = torch.jit.script(f) -print("TorchScript:", m(ones, True), m(ones, False)) -``` - -#### Instance Data Attributes - -For `ModuleType` classes, instance data attributes can be annotated according to Python3 typing module annotation rules. Instance data attributes can be annotated (optionally) as final -via `Final`. - -``` -"class" ClassIdentifier "(torch.nn.Module):" -InstanceAttrIdentifier ":" ["Final("] TSType [")"] -... -``` - -Where: - -- `InstanceAttrIdentifier` is the name of an instance attribute. -- `Final` indicates that the attribute cannot be re-assigned outside of `__init__` or overridden in subclasses. - -**Example** - -```python -import torch - -class MyModule(torch.nn.Module): - offset_: int - -def __init__(self, offset): - self.offset_ = offset - -... -``` - -### Type Annotation APIs - -#### `torch.jit.annotate(T, expr)` - -This API annotates type `T` to an expression `expr`. This is often used when the default type of an expression is not the type intended by the programmer. -For instance, an empty list (dictionary) has the default type of `List[TensorType]` (`Dict[TensorType, TensorType]`), but sometimes it may be used to initialize -a list of some other types. Another common use case is for annotating the return type of `tensor.tolist()`. Note, however, that it cannot be used to annotate -the type of a module attribute in `__init__`; `torch.jit.Attribute` should be used for this instead. - -**Example** - -In this example, `[]` is declared as a list of integers via `torch.jit.annotate` (instead of assuming `[]` to be the default type of `List[TensorType]`): - -```python -import torch -from typing import List - -def g(l: List[int], val: int): - l.append(val) - return l - -def f(val: int): - l = g(torch.jit.annotate(List[int], []), val) - return l - -m = torch.jit.script(f) -print("Eager:", f(3)) -print("TorchScript:", m(3)) -``` - -See {meth}`torch.jit.annotate` for more information. - -### Type Annotation Appendix - -#### TorchScript Type System Definition - -``` -TSAllType ::= TSType | TSModuleType -TSType ::= TSMetaType | TSPrimitiveType | TSStructuralType | TSNominalType - -TSMetaType ::= "Any" -TSPrimitiveType ::= "int" | "float" | "double" | "complex" | "bool" | "str" | "None" - -TSStructuralType ::= TSTuple | TSNamedTuple | TSList | TSDict | TSOptional | - TSUnion | TSFuture | TSRRef | TSAwait -TSTuple ::= "Tuple" "[" (TSType ",")* TSType "]" -TSNamedTuple ::= "namedtuple" "(" (TSType ",")* TSType ")" -TSList ::= "List" "[" TSType "]" -TSOptional ::= "Optional" "[" TSType "]" -TSUnion ::= "Union" "[" (TSType ",")* TSType "]" -TSFuture ::= "Future" "[" TSType "]" -TSRRef ::= "RRef" "[" TSType "]" -TSAwait ::= "Await" "[" TSType "]" -TSDict ::= "Dict" "[" KeyType "," TSType "]" -KeyType ::= "str" | "int" | "float" | "bool" | TensorType | "Any" - -TSNominalType ::= TSBuiltinClasses | TSCustomClass | TSEnum -TSBuiltinClass ::= TSTensor | "torch.device" | "torch.stream"| - "torch.dtype" | "torch.nn.ModuleList" | - "torch.nn.ModuleDict" | ... -TSTensor ::= "torch.tensor" and subclasses -``` - -#### Unsupported Typing Constructs - -TorchScript does not support all features and types of the Python3 [typing](https://docs.python.org/3/library/typing.html#module-typing) module. -Any functionality from the [typing](https://docs.python.org/3/library/typing.html#module-typing) module that is not explicitly specified in this -documentation is unsupported. The following table summarizes `typing` constructs that are either unsupported or supported with restrictions in TorchScript. - -```{eval-rst} -============================= ================ - Item Description ------------------------------ ---------------- -``typing.Any`` In development -``typing.NoReturn`` Not supported -``typing.Callable`` Not supported -``typing.Literal`` Not supported -``typing.ClassVar`` Not supported -``typing.Final`` Supported for module attributes, class attribute, and annotations, but not for functions. -``typing.AnyStr`` Not supported -``typing.overload`` In development -Type aliases Not supported -Nominal typing In development -Structural typing Not supported -NewType Not supported -Generics Not supported -============================= ================ -``` - -(expressions)= - -## Expressions - -The following section describes the grammar of expressions that are supported in TorchScript. -It is modeled after [the expressions chapter of the Python language reference](https://docs.python.org/3/reference/expressions.html). - -### Arithmetic Conversions - -There are a number of implicit type conversions that are performed in TorchScript: - -- A `Tensor` with a `float` or `int` data type can be implicitly converted to an instance of `FloatType` or `IntType` provided that it has a size of 0, does not have `require_grad` set to `True`, and will not require narrowing. -- Instances of `StringType` can be implicitly converted to `DeviceType`. -- The implicit conversion rules from the two bullet points above can be applied to instances of `TupleType` to produce instances of `ListType` with the appropriate contained type. - -Explicit conversions can be invoked using the `float`, `int`, `bool`, and `str` built-in functions -that accept primitive data types as arguments and can accept user-defined types if they implement -`__bool__`, `__str__`, etc. - -### Atoms - -Atoms are the most basic elements of expressions. - -``` -atom ::= identifier | literal | enclosure -enclosure ::= parenth_form | list_display | dict_display -``` - -#### Identifiers - -The rules that dictate what is a legal identifier in TorchScript are the same as -their [Python counterparts](https://docs.python.org/3/reference/lexical_analysis.html#identifiers). - -#### Literals - -``` -literal ::= stringliteral | integer | floatnumber -``` - -Evaluation of a literal yields an object of the appropriate type with the specific value -(with approximations applied as necessary for floats). Literals are immutable, and multiple evaluations -of identical literals may obtain the same object or distinct objects with the same value. -[stringliteral](https://docs.python.org/3/reference/lexical_analysis.html#string-and-bytes-literals), -[integer](https://docs.python.org/3/reference/lexical_analysis.html#integer-literals), and -[floatnumber](https://docs.python.org/3/reference/lexical_analysis.html#floating-point-literals) -are defined in the same way as their Python counterparts. - -#### Parenthesized Forms - -``` -parenth_form ::= '(' [expression_list] ')' -``` - -A parenthesized expression list yields whatever the expression list yields. If the list contains at least one -comma, it yields a `Tuple`; otherwise, it yields the single expression inside the expression list. An empty -pair of parentheses yields an empty `Tuple` object (`Tuple[]`). - -#### List and Dictionary Displays - -``` -list_comprehension ::= expression comp_for -comp_for ::= 'for' target_list 'in' or_expr -list_display ::= '[' [expression_list | list_comprehension] ']' -dict_display ::= '{' [key_datum_list | dict_comprehension] '}' -key_datum_list ::= key_datum (',' key_datum)* -key_datum ::= expression ':' expression -dict_comprehension ::= key_datum comp_for -``` - -Lists and dicts can be constructed by either listing the container contents explicitly or by providing -instructions on how to compute them via a set of looping instructions (i.e. a *comprehension*). A comprehension -is semantically equivalent to using a for loop and appending to an ongoing list. -Comprehensions implicitly create their own scope to make sure that the items of the target list do not leak into the -enclosing scope. In the case that container items are explicitly listed, the expressions in the expression list -are evaluated left-to-right. If a key is repeated in a `dict_display` that has a `key_datum_list`, the -resultant dictionary uses the value from the rightmost datum in the list that uses the repeated key. - -### Primaries - -``` -primary ::= atom | attributeref | subscription | slicing | call -``` - -#### Attribute References - -``` -attributeref ::= primary '.' identifier -``` - -The `primary` must evaluate to an object of a type that supports attribute references that have an attribute named -`identifier`. - -#### Subscriptions - -``` -subscription ::= primary '[' expression_list ']' -``` - -The `primary` must evaluate to an object that supports subscription. - -- If the primary is a `List`, `Tuple`, or `str`, the expression list must evaluate to an integer or slice. -- If the primary is a `Dict`, the expression list must evaluate to an object of the same type as the key type of the `Dict`. -- If the primary is a `ModuleList`, the expression list must be an `integer` literal. -- If the primary is a `ModuleDict`, the expression must be a `stringliteral`. - -#### Slicings - -A slicing selects a range of items in a `str`, `Tuple`, `List`, or `Tensor`. Slicings may be used as -expressions or targets in assignment or `del` statements. - -``` -slicing ::= primary '[' slice_list ']' -slice_list ::= slice_item (',' slice_item)* [','] -slice_item ::= expression | proper_slice -proper_slice ::= [expression] ':' [expression] [':' [expression] ] -``` - -Slicings with more than one slice item in their slice lists can only be used with primaries that evaluate to an -object of type `Tensor`. - -#### Calls - -``` -call ::= primary '(' argument_list ')' -argument_list ::= args [',' kwargs] | kwargs -args ::= [arg (',' arg)*] -kwargs ::= [kwarg (',' kwarg)*] -kwarg ::= arg '=' expression -arg ::= identifier -``` - -The `primary` must desugar or evaluate to a callable object. All argument expressions are evaluated -before the call is attempted. - -### Power Operator - -``` -power ::= primary ['**' u_expr] -``` - -The power operator has the same semantics as the built-in pow function (not supported); it computes its -left argument raised to the power of its right argument. It binds more tightly than unary operators on the -left, but less tightly than unary operators on the right; i.e. `-2 ** -3 == -(2 ** (-3))`. The left and right -operands can be `int`, `float` or `Tensor`. Scalars are broadcast in the case of scalar-tensor/tensor-scalar -exponentiation operations, and tensor-tensor exponentiation is done elementwise without any broadcasting. - -### Unary and Arithmetic Bitwise Operations - -``` -u_expr ::= power | '-' power | '~' power -``` - -The unary `-` operator yields the negation of its argument. The unary `~` operator yields the bitwise inversion -of its argument. `-` can be used with `int`, `float`, and `Tensor` of `int` and `float`. -`~` can only be used with `int` and `Tensor` of `int`. - -### Binary Arithmetic Operations - -``` -m_expr ::= u_expr | m_expr '*' u_expr | m_expr '@' m_expr | m_expr '//' u_expr | m_expr '/' u_expr | m_expr '%' u_expr -a_expr ::= m_expr | a_expr '+' m_expr | a_expr '-' m_expr -``` - -The binary arithmetic operators can operate on `Tensor`, `int`, and `float`. For tensor-tensor ops, both arguments must -have the same shape. For scalar-tensor or tensor-scalar ops, the scalar is usually broadcast to the size of the -tensor. Division ops can only accept scalars as their right-hand side argument, and do not support broadcasting. -The `@` operator is for matrix multiplication and only operates on `Tensor` arguments. The multiplication operator -(`*`) can be used with a list and integer in order to get a result that is the original list repeated a certain -number of times. - -### Shifting Operations - -``` -shift_expr ::= a_expr | shift_expr ( '<<' | '>>' ) a_expr -``` - -These operators accept two `int` arguments, two `Tensor` arguments, or a `Tensor` argument and an `int` or -`float` argument. In all cases, a right shift by `n` is defined as floor division by `pow(2, n)`, and a left shift -by `n` is defined as multiplication by `pow(2, n)`. When both arguments are `Tensors`, they must have the same -shape. When one is a scalar and the other is a `Tensor`, the scalar is logically broadcast to match the size of -the `Tensor`. - -### Binary Bitwise Operations - -``` -and_expr ::= shift_expr | and_expr '&' shift_expr -xor_expr ::= and_expr | xor_expr '^' and_expr -or_expr ::= xor_expr | or_expr '|' xor_expr -``` - -The `&` operator computes the bitwise AND of its arguments, the `^` the bitwise XOR, and the `|` the bitwise OR. -Both operands must be `int` or `Tensor`, or the left operand must be `Tensor` and the right operand must be -`int`. When both operands are `Tensor`, they must have the same shape. When the right operand is `int`, and -the left operand is `Tensor`, the right operand is logically broadcast to match the shape of the `Tensor`. - -### Comparisons - -``` -comparison ::= or_expr (comp_operator or_expr)* -comp_operator ::= '<' | '>' | '==' | '>=' | '<=' | '!=' | 'is' ['not'] | ['not'] 'in' -``` - -A comparison yields a boolean value (`True` or `False`), or if one of the operands is a `Tensor`, a boolean -`Tensor`. Comparisons can be chained arbitrarily as long as they do not yield boolean `Tensors` that have more -than one element. `a op1 b op2 c ...` is equivalent to `a op1 b and b op2 c and ...`. - -#### Value Comparisons - -The operators `<`, `>`, `==`, `>=`, `<=`, and `!=` compare the values of two objects. The two objects generally need to be of -the same type, unless there is an implicit type conversion available between the objects. User-defined types can -be compared if rich comparison methods (e.g., `__lt__`) are defined on them. Built-in type comparison works like -Python: - -- Numbers are compared mathematically. -- Strings are compared lexicographically. -- `lists`, `tuples`, and `dicts` can be compared only to other `lists`, `tuples`, and `dicts` of the same type and are compared using the comparison operator of corresponding elements. - -#### Membership Test Operations - -The operators `in` and `not in` test for membership. `x in s` evaluates to `True` if `x` is a member of `s` and `False` otherwise. -`x not in s` is equivalent to `not x in s`. This operator is supported for `lists`, `dicts`, and `tuples`, and can be used with -user-defined types if they implement the `__contains__` method. - -#### Identity Comparisons - -For all types except `int`, `double`, `bool`, and `torch.device`, operators `is` and `is not` test for the object’s identity; -`x is y` is `True` if and only if `x` and `y` are the same object. For all other types, `is` is equivalent to -comparing them using `==`. `x is not y` yields the inverse of `x is y`. - -### Boolean Operations - -``` -or_test ::= and_test | or_test 'or' and_test -and_test ::= not_test | and_test 'and' not_test -not_test ::= 'bool' '(' or_expr ')' | comparison | 'not' not_test -``` - -User-defined objects can customize their conversion to `bool` by implementing a `__bool__` method. The operator `not` -yields `True` if its operand is false, `False` otherwise. The expression `x` and `y` first evaluates `x`; if it is `False`, its -value (`False`) is returned; otherwise, `y` is evaluated and its value is returned (`False` or `True`). The expression `x` or `y` -first evaluates `x`; if it is `True`, its value (`True`) is returned; otherwise, `y` is evaluated and its value is returned -(`False` or `True`). - -### Conditional Expressions - -``` -conditional_expression ::= or_expr ['if' or_test 'else' conditional_expression] -expression ::= conditional_expression -``` - -The expression `x if c else y` first evaluates the condition `c` rather than x. If `c` is `True`, `x` is -evaluated and its value is returned; otherwise, `y` is evaluated and its value is returned. As with if-statements, -`x` and `y` must evaluate to a value of the same type. - -### Expression Lists - -``` -expression_list ::= expression (',' expression)* [','] -starred_item ::= '*' primary -``` - -A starred item can only appear on the left-hand side of an assignment statement, e.g., `a, *b, c = ...`. - -% statements: - -## Simple Statements - -The following section describes the syntax of simple statements that are supported in TorchScript. -It is modeled after [the simple statements chapter of the Python language reference](https://docs.python.org/3/reference/simple_stmts.html). - -### Expression Statements - -``` -expression_stmt ::= starred_expression -starred_expression ::= expression | (starred_item ",")* [starred_item] -starred_item ::= assignment_expression | "*" or_expr -``` - -### Assignment Statements - -``` -assignment_stmt ::= (target_list "=")+ (starred_expression) -target_list ::= target ("," target)* [","] -target ::= identifier - | "(" [target_list] ")" - | "[" [target_list] "]" - | attributeref - | subscription - | slicing - | "*" target -``` - -### Augmented Assignment Statements - -``` -augmented_assignment_stmt ::= augtarget augop (expression_list) -augtarget ::= identifier | attributeref | subscription -augop ::= "+=" | "-=" | "*=" | "/=" | "//=" | "%=" | - "**="| ">>=" | "<<=" | "&=" | "^=" | "|=" -``` - -### Annotated Assignment Statements - -``` -annotated_assignment_stmt ::= augtarget ":" expression - ["=" (starred_expression)] -``` - -### The `raise` Statement - -``` -raise_stmt ::= "raise" [expression ["from" expression]] -``` - -Raise statements in TorchScript do not support `try\except\finally`. - -### The `assert` Statement - -``` -assert_stmt ::= "assert" expression ["," expression] -``` - -Assert statements in TorchScript do not support `try\except\finally`. - -### The `return` Statement - -``` -return_stmt ::= "return" [expression_list] -``` - -Return statements in TorchScript do not support `try\except\finally`. - -### The `del` Statement - -``` -del_stmt ::= "del" target_list -``` - -### The `pass` Statement - -``` -pass_stmt ::= "pass" -``` - -### The `print` Statement - -``` -print_stmt ::= "print" "(" expression [, expression] [.format{expression_list}] ")" -``` - -### The `break` Statement - -``` -break_stmt ::= "break" -``` - -### The `continue` Statement: - -``` -continue_stmt ::= "continue" -``` - -## Compound Statements - -The following section describes the syntax of compound statements that are supported in TorchScript. -The section also highlights how Torchscript differs from regular Python statements. -It is modeled after [the compound statements chapter of the Python language reference](https://docs.python.org/3/reference/compound_stmts.html). - -### The `if` Statement - -Torchscript supports both basic `if/else` and ternary `if/else`. - -#### Basic `if/else` Statement - -``` -if_stmt ::= "if" assignment_expression ":" suite - ("elif" assignment_expression ":" suite) - ["else" ":" suite] -``` - -`elif` statements can repeat for an arbitrary number of times, but it needs to be before `else` statement. - -#### Ternary `if/else` Statement - -``` -if_stmt ::= return [expression_list] "if" assignment_expression "else" [expression_list] -``` - -**Example 1** - -A `tensor` with 1 dimension is promoted to `bool`: - -```{eval-rst} -.. testcode:: - - import torch - - @torch.jit.script - def fn(x: torch.Tensor): - if x: # The tensor gets promoted to bool - return True - return False - print(fn(torch.rand(1))) -``` - -The example above produces the following output: - -```{eval-rst} -.. testoutput:: - - True -``` - -**Example 2** - -A `tensor` with multi dimensions are not promoted to `bool`: - -```python -import torch - -# Multi dimensional Tensors error out. - -@torch.jit.script -def fn(): - if torch.rand(2): - print("Tensor is available") - - if torch.rand(4,5,6): - print("Tensor is available") - -print(fn()) -``` - -Running the above code yields the following `RuntimeError`. - -``` -RuntimeError: The following operation failed in the TorchScript interpreter. -Traceback of TorchScript (most recent call last): -@torch.jit.script -def fn(): - if torch.rand(2): - ~~~~~~~~~~~~ <--- HERE - print("Tensor is available") -RuntimeError: Boolean value of Tensor with more than one value is ambiguous -``` - -If a conditional variable is annotated as `final`, either the true or false branch is evaluated depending on the evaluation of the conditional variable. - -**Example 3** - -In this example, only the True branch is evaluated, since `a` is annotated as `final` and set to `True`: - -```python -import torch - -a : torch.jit.final[Bool] = True - -if a: - return torch.empty(2,3) -else: - return [] -``` - -### The `while` Statement - -``` -while_stmt ::= "while" assignment_expression ":" suite -``` - -`while...else` statements are not supported in Torchscript. It results in a `RuntimeError`. - -### The `for-in` Statement - -``` -for_stmt ::= "for" target_list "in" expression_list ":" suite - ["else" ":" suite] -``` - -`for...else` statements are not supported in Torchscript. It results in a `RuntimeError`. - -**Example 1** - -For loops on tuples: these unroll the loop, generating a body for each member of the tuple. The body must type-check correctly for each member. - -```{eval-rst} -.. testcode:: - - import torch - from typing import Tuple - - @torch.jit.script - def fn(): - tup = (3, torch.ones(4)) - for x in tup: - print(x) - - fn() -``` - -The example above produces the following output: - -```{eval-rst} -.. testoutput:: - - 3 - 1 - 1 - 1 - 1 - [ CPUFloatType{4} ] - -``` - -**Example 2** - -For loops on lists: for loops over a `nn.ModuleList` will unroll the body of the loop at compile time, with each member of the module list. - -```python -class SubModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.weight = nn.Parameter(torch.randn(2)) - - def forward(self, input): - return self.weight + input - -class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.mods = torch.nn.ModuleList([SubModule() for i in range(10)]) - - def forward(self, v): - for module in self.mods: - v = module(v) - return v - -model = torch.jit.script(MyModule()) -``` - -### The `with` Statement - -The `with` statement is used to wrap the execution of a block with methods defined by a context manager. - -``` -with_stmt ::= "with" with_item ("," with_item) ":" suite -with_item ::= expression ["as" target] -``` - -- If a target was included in the `with` statement, the return value from the context manager’s `__enter__()` is assigned to it. Unlike python, if an exception caused the suite to be exited, its type, value, and traceback are not passed as arguments to `__exit__()`. Three `None` arguments are supplied. -- `try`, `except`, and `finally` statements are not supported inside `with` blocks. -- Exceptions raised within `with` block cannot be suppressed. - -### The `tuple` Statement - -``` -tuple_stmt ::= tuple([iterables]) -``` - -- Iterable types in TorchScript include `Tensors`, `lists`, `tuples`, `dictionaries`, `strings`, `torch.nn.ModuleList`, and `torch.nn.ModuleDict`. -- You cannot convert a List to Tuple by using this built-in function. - -Unpacking all outputs into a tuple is covered by: - -``` -abc = func() # Function that returns a tuple -a,b = func() -``` - -### The `getattr` Statement - -``` -getattr_stmt ::= getattr(object, name[, default]) -``` - -- Attribute name must be a literal string. -- Module type object is not supported (e.g., torch.\_C). -- Custom class object is not supported (e.g., torch.classes.\*). - -### The `hasattr` Statement - -``` -hasattr_stmt ::= hasattr(object, name) -``` - -- Attribute name must be a literal string. -- Module type object is not supported (e.g., torch.\_C). -- Custom class object is not supported (e.g., torch.classes.\*). - -### The `zip` Statement - -``` -zip_stmt ::= zip(iterable1, iterable2) -``` - -- Arguments must be iterables. -- Two iterables of same outer container type but different length are supported. - -**Example 1** - -Both the iterables must be of the same container type: - -```{eval-rst} -.. testcode:: - - a = [1, 2] # List - b = [2, 3, 4] # List - zip(a, b) # works -``` - -**Example 2** - -This example fails because the iterables are of different container types: - -``` -a = (1, 2) # Tuple -b = [2, 3, 4] # List -zip(a, b) # Runtime error -``` - -Running the above code yields the following `RuntimeError`. - -``` -RuntimeError: Can not iterate over a module list or - tuple with a value that does not have a statically determinable length. -``` - -**Example 3** - -Two iterables of the same container Type but different data type is supported: - -```{eval-rst} -.. testcode:: - - a = [1.3, 2.4] - b = [2, 3, 4] - zip(a, b) # Works -``` - -Iterable types in TorchScript include `Tensors`, `lists`, `tuples`, `dictionaries`, `strings`, `torch.nn.ModuleList`, and `torch.nn.ModuleDict`. - -### The `enumerate` Statement - -``` -enumerate_stmt ::= enumerate([iterable]) -``` - -- Arguments must be iterables. -- Iterable types in TorchScript include `Tensors`, `lists`, `tuples`, `dictionaries`, `strings`, `torch.nn.ModuleList` and `torch.nn.ModuleDict`. - -(python-values-torch-script)= - -## Python Values - -(python-builtin-functions-values-resolution)= - -### Resolution Rules - -When given a Python value, TorchScript attempts to resolve it in the following five different ways: - -- Compilable Python Implementation: - : - When a Python value is backed by a Python implementation that can be compiled by TorchScript, TorchScript compiles and uses the underlying Python implementation. - - Example: `torch.jit.Attribute` -- Op Python Wrapper: - : - When a Python value is a wrapper of a native PyTorch op, TorchScript emits the corresponding operator. - - Example: `torch.jit._logging.add_stat_value` -- Python Object Identity Match: - : - For a limited set of `torch.*` API calls (in the form of Python values) that TorchScript supports, TorchScript attempts to match a Python value against each item in the set. - - When matched, TorchScript generates a corresponding `SugaredValue` instance that contains lowering logic for these values. - - Example: `torch.jit.isinstance()` -- Name Match: - : - For Python built-in functions and constants, TorchScript identifies them by name, and creates a corresponding `SugaredValue` instance that implements their functionality. - - Example: `all()` -- Value Snapshot: - : - For Python values from unrecognized modules, TorchScript attempts to take a snapshot of the value and converts it to a constant in the graph of the function(s) or method(s) that are being compiled. - - Example: `math.pi` - -(python-builtin-functions-support)= - -### Python Built-in Functions Support - -```{eval-rst} -.. list-table:: TorchScript Support for Python Built-in Functions - :widths: 25 25 50 - :header-rows: 1 - - * - Built-in Function - - Support Level - - Notes - * - ``abs()`` - - Partial - - Only supports ``Tensor``/``Int``/``Float`` type inputs. | Doesn't honor ``__abs__`` override. - * - ``all()`` - - Full - - - * - ``any()`` - - Full - - - * - ``ascii()`` - - None - - - * - ``bin()`` - - Partial - - Only supports ``Int`` type input. - * - ``bool()`` - - Partial - - Only supports ``Tensor``/``Int``/``Float`` type inputs. - * - ``breakpoint()`` - - None - - - * - ``bytearray()`` - - None - - - * - ``bytes()`` - - None - - - * - ``callable()`` - - None - - - * - ``chr()`` - - Partial - - Only ASCII character set is supported. - * - ``classmethod()`` - - Full - - - * - ``compile()`` - - None - - - * - ``complex()`` - - None - - - * - ``delattr()`` - - None - - - * - ``dict()`` - - Full - - - * - ``dir()`` - - None - - - * - ``divmod()`` - - Full - - - * - ``enumerate()`` - - Full - - - * - ``eval()`` - - None - - - * - ``exec()`` - - None - - - * - ``filter()`` - - None - - - * - ``float()`` - - Partial - - Doesn't honor ``__index__`` override. - * - ``format()`` - - Partial - - Manual index specification not supported. | Format type modifier not supported. - * - ``frozenset()`` - - None - - - * - ``getattr()`` - - Partial - - Attribute name must be string literal. - * - ``globals()`` - - None - - - * - ``hasattr()`` - - Partial - - Attribute name must be string literal. - * - ``hash()`` - - Full - - ``Tensor``'s hash is based on identity not numeric value. - * - ``hex()`` - - Partial - - Only supports ``Int`` type input. - * - ``id()`` - - Full - - Only supports ``Int`` type input. - * - ``input()`` - - None - - - * - ``int()`` - - Partial - - ``base`` argument not supported. | Doesn't honor ``__index__`` override. - * - ``isinstance()`` - - Full - - ``torch.jit.isintance`` provides better support when checking against container types like ``Dict[str, int]``. - * - ``issubclass()`` - - None - - - * - ``iter()`` - - None - - - * - ``len()`` - - Full - - - * - ``list()`` - - Full - - - * - ``ord()`` - - Partial - - Only ASCII character set is supported. - * - ``pow()`` - - Full - - - * - ``print()`` - - Partial - - ``separate``, ``end`` and ``file`` arguments are not supported. - * - ``property()`` - - None - - - * - ``range()`` - - Full - - - * - ``repr()`` - - None - - - * - ``reversed()`` - - None - - - * - ``round()`` - - Partial - - ``ndigits`` argument is not supported. - * - ``set()`` - - None - - - * - ``setattr()`` - - None - - - * - ``slice()`` - - Full - - - * - ``sorted()`` - - Partial - - ``key`` argument is not supported. - * - ``staticmethod()`` - - Full - - - * - ``str()`` - - Partial - - ``encoding`` and ``errors`` arguments are not supported. - * - ``sum()`` - - Full - - - * - ``super()`` - - Partial - - It can only be used in ``nn.Module``'s ``__init__`` method. - * - ``type()`` - - None - - - * - ``vars()`` - - None - - - * - ``zip()`` - - Full - - - * - ``__import__()`` - - None - - -``` - -(python-builtin-values-support)= - -### Python Built-in Values Support - -```{eval-rst} -.. list-table:: TorchScript Support for Python Built-in Values - :widths: 25 25 50 - :header-rows: 1 - - * - Built-in Value - - Support Level - - Notes - * - ``False`` - - Full - - - * - ``True`` - - Full - - - * - ``None`` - - Full - - - * - ``NotImplemented`` - - None - - - * - ``Ellipsis`` - - Full - - - -``` - -(torch-apis-in-torchscript)= - -## torch.\* APIs - -(torch-apis-in-torchscript-rpc)= - -### Remote Procedure Calls - -TorchScript supports a subset of RPC APIs that supports running a function on -a specified remote worker instead of locally. - -Specifically, following APIs are fully supported: - -- `torch.distributed.rpc.rpc_sync()` - : - `rpc_sync()` makes a blocking RPC call to run a function on a remote worker. RPC messages are sent and received in parallel to execution of Python code. - - More details about its usage and examples can be found in {meth}`~torch.distributed.rpc.rpc_sync`. -- `torch.distributed.rpc.rpc_async()` - : - `rpc_async()` makes a non-blocking RPC call to run a function on a remote worker. RPC messages are sent and received in parallel to execution of Python code. - - More details about its usage and examples can be found in {meth}`~torch.distributed.rpc.rpc_async`. -- `torch.distributed.rpc.remote()` - : - `remote.()` executes a remote call on a worker and gets a Remote Reference `RRef` as the return value. - - More details about its usage and examples can be found in {meth}`~torch.distributed.rpc.remote`. - -(torch-apis-in-torchscript-async)= - -### Asynchronous Execution - -TorchScript enables you to create asynchronous computation tasks to make better use -of computation resources. This is done via supporting a list of APIs that are -only usable within TorchScript: - -- `torch.jit.fork()` - : - Creates an asynchronous task executing func and a reference to the value of the result of this execution. Fork will return immediately. - - Synonymous to `torch.jit._fork()`, which is only kept for backward compatibility reasons. - - More details about its usage and examples can be found in {meth}`~torch.jit.fork`. -- `torch.jit.wait()` - : - Forces completion of a `torch.jit.Future[T]` asynchronous task, returning the result of the task. - - Synonymous to `torch.jit._wait()`, which is only kept for backward compatibility reasons. - - More details about its usage and examples can be found in {meth}`~torch.jit.wait`. - -(torch-apis-in-torchscript-annotation)= - -### Type Annotations - -TorchScript is statically-typed. It provides and supports a set of utilities to help annotate variables and attributes: - -- `torch.jit.annotate()` - : - Provides a type hint to TorchScript where Python 3 style type hints do not work well. - - One common example is to annotate type for expressions like `[]`. `[]` is treated as `List[torch.Tensor]` by default. When a different type is needed, you can use this code to hint TorchScript: `torch.jit.annotate(List[int], [])`. - - More details can be found in {meth}`~torch.jit.annotate` -- `torch.jit.Attribute` - : - Common use cases include providing type hint for `torch.nn.Module` attributes. Because their `__init__` methods are not parsed by TorchScript, `torch.jit.Attribute` should be used instead of `torch.jit.annotate` in the module's `__init__` methods. - - More details can be found in {meth}`~torch.jit.Attribute` -- `torch.jit.Final` - : - An alias for Python's `typing.Final`. `torch.jit.Final` is kept only for backward compatibility reasons. - -(torch-apis-in-torchscript-meta-programming)= - -### Meta Programming - -TorchScript provides a set of utilities to facilitate meta programming: - -- `torch.jit.is_scripting()` - : - Returns a boolean value indicating whether the current program is compiled by `torch.jit.script` or not. - - When used in an `assert` or an `if` statement, the scope or branch where `torch.jit.is_scripting()` evaluates to `False` is not compiled. - - Its value can be evaluated statically at compile time, thus commonly used in `if` statements to stop TorchScript from compiling one of the branches. - - More details and examples can be found in {meth}`~torch.jit.is_scripting` -- `torch.jit.is_tracing()` - : - Returns a boolean value indicating whether the current program is traced by `torch.jit.trace` / `torch.jit.trace_module` or not. - - More details can be found in {meth}`~torch.jit.is_tracing` -- `@torch.jit.ignore` - : - This decorator indicates to the compiler that a function or method should be ignored and left as a Python function. - - This allows you to leave code in your model that is not yet TorchScript compatible. - - If a function decorated by `@torch.jit.ignore` is called from TorchScript, ignored functions will dispatch the call to the Python interpreter. - - Models with ignored functions cannot be exported. - - More details and examples can be found in {meth}`~torch.jit.ignore` -- `@torch.jit.unused` - : - This decorator indicates to the compiler that a function or method should be ignored and replaced with the raising of an exception. - - This allows you to leave code in your model that is not yet TorchScript compatible and still export your model. - - If a function decorated by `@torch.jit.unused` is called from TorchScript, a runtime error will be raised. - - More details and examples can be found in {meth}`~torch.jit.unused` - -(torch-apis-in-torchscript-type-refinement)= - -### Type Refinement - -- `torch.jit.isinstance()` - : - Returns a boolean indicating whether a variable is of the specified type. - - More details about its usage and examples can be found in {meth}`~torch.jit.isinstance`. \ No newline at end of file +:::{warning} +TorchScript is deprecated, please use +[torch.export](https://docs.pytorch.org/docs/stable/export.html) instead. +::: \ No newline at end of file From 205241a0d5149d05e44dc113dc0273e8eceff9f0 Mon Sep 17 00:00:00 2001 From: Ti-Tai Wang Date: Tue, 15 Jul 2025 17:32:59 +0000 Subject: [PATCH 073/457] [ONNX] Remove legacy dynamo graph extractor (#158262) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158262 Approved by: https://github.com/justinchuby ghstack dependencies: #158258 --- torch/onnx/_internal/_exporter_legacy.py | 332 +----------------- .../_internal/fx/dynamo_graph_extractor.py | 160 --------- .../_internal/fx/onnxfunction_dispatcher.py | 5 +- 3 files changed, 3 insertions(+), 494 deletions(-) delete mode 100644 torch/onnx/_internal/fx/dynamo_graph_extractor.py diff --git a/torch/onnx/_internal/_exporter_legacy.py b/torch/onnx/_internal/_exporter_legacy.py index 5447e503801d5..f9ae42b26b84f 100644 --- a/torch/onnx/_internal/_exporter_legacy.py +++ b/torch/onnx/_internal/_exporter_legacy.py @@ -3,31 +3,18 @@ __all__ = [ - "ExportOptions", - "ONNXRuntimeOptions", - "OnnxRegistry", "enable_fake_mode", ] -import abc import contextlib import dataclasses import logging -import warnings -from collections import defaultdict -from typing import Any, Callable, TYPE_CHECKING -from typing_extensions import deprecated +from typing import Any, TYPE_CHECKING import torch import torch._ops -from torch.onnx._internal._lazy_import import onnxscript_apis -from torch.onnx._internal.exporter import _constants -from torch.onnx._internal.fx import ( - decomposition_table, - patcher as patcher, - registration, -) +from torch.onnx._internal.fx import patcher as patcher # We can only import onnx from this module in a type-checking context to ensure that @@ -35,10 +22,6 @@ # 'import onnx' inside of dynamo_export (by way of _assert_dependencies). if TYPE_CHECKING: import io - from collections.abc import Mapping, Sequence - - import onnxruntime - import onnxscript from torch._subclasses import fake_tensor @@ -61,219 +44,6 @@ class ONNXFakeContext: """List of paths of files that contain the model :meth:`state_dict`""" -@deprecated( - "torch.onnx.dynamo_export is deprecated since 2.7.0. Please use torch.onnx.export(..., dynamo=True) instead.", -) -class OnnxRegistry: - """Registry for ONNX functions. - - .. deprecated:: 2.7 - Please use ``torch.onnx.export(..., dynamo=True)`` instead. - - The registry maintains a mapping from qualified names to symbolic functions under a - fixed opset version. It supports registering custom onnx-script functions and for - dispatcher to dispatch calls to the appropriate function. - - """ - - def __init__(self) -> None: - """Initializes the registry""" - - # NOTE: _registry is the registry maps OpNameto a list of ONNXFunctions. It is important - # not to directly modify this variable. Instead, access to it should be done through - # the public methods: register_custom_op, get_ops, and is_registered_op. - self._registry: dict[registration.OpName, list[registration.ONNXFunction]] = ( - defaultdict(list) - ) - - self._opset_version = _constants.TORCHLIB_OPSET - warnings.warn( - f"torch.onnx.dynamo_export only implements opset version {self._opset_version} for now. If you need to use a " - "different opset version, please register them with register_custom_op." - ) - - self._initiate_registry_from_torchlib() - - @property - def opset_version(self) -> int: - """The ONNX opset version the exporter should target.""" - - return self._opset_version - - def _initiate_registry_from_torchlib(self) -> None: - """Populates the registry with ATen functions from torchlib. - - Args: - torchlib_registry: The torchlib registry to use for populating the registry. - """ - for meta in onnxscript_apis.get_torchlib_ops(): - internal_name_instance = registration.OpName.from_qualified_name( - meta.qualified_name - ) - symbolic_function = registration.ONNXFunction( - onnx_function=meta.function, # type: ignore[arg-type] - op_full_name=internal_name_instance.qualified_name(), - is_custom=False, - is_complex=meta.is_complex, - ) - self._register(internal_name_instance, symbolic_function) - - def _register( - self, - internal_qualified_name: registration.OpName, - symbolic_function: registration.ONNXFunction, - ) -> None: - """Registers a ONNXFunction to an operator. - - Args: - internal_qualified_name: The qualified name of the operator to register: OpName. - symbolic_function: The ONNXFunction to register. - """ - self._registry[internal_qualified_name].append(symbolic_function) - - def register_op( - self, - function: onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction, - namespace: str, - op_name: str, - overload: str | None = None, - is_complex: bool = False, - ) -> None: - """Registers a custom operator: torch.ops.... - - Args: - function: The onnx-sctip function to register. - namespace: The namespace of the operator to register. - op_name: The name of the operator to register. - overload: The overload of the operator to register. If it's default overload, - leave it to None. - is_complex: Whether the function is a function that handles complex valued inputs. - - Raises: - ValueError: If the name is not in the form of 'namespace::op'. - """ - internal_name_instance = registration.OpName.from_name_parts( - namespace=namespace, op_name=op_name, overload=overload - ) - symbolic_function = registration.ONNXFunction( - onnx_function=function, - op_full_name=internal_name_instance.qualified_name(), - is_custom=True, - is_complex=is_complex, - ) - self._register(internal_name_instance, symbolic_function) - - def get_op_functions( - self, namespace: str, op_name: str, overload: str | None = None - ) -> list[registration.ONNXFunction] | None: - """Returns a list of ONNXFunctions for the given op: torch.ops.... - - The list is ordered by the time of registration. The custom operators should be - in the second half of the list. - - Args: - namespace: The namespace of the operator to get. - op_name: The name of the operator to get. - overload: The overload of the operator to get. If it's default overload, - leave it to None. - Returns: - A list of ONNXFunctions corresponding to the given name, or None if - the name is not in the registry. - """ - internal_name_instance = registration.OpName.from_name_parts( - namespace=namespace, op_name=op_name, overload=overload - ) - return self._registry.get(internal_name_instance) - - def is_registered_op( - self, namespace: str, op_name: str, overload: str | None = None - ) -> bool: - """Returns whether the given op is registered: torch.ops.... - - Args: - namespace: The namespace of the operator to check. - op_name: The name of the operator to check. - overload: The overload of the operator to check. If it's default overload, - leave it to None. - - Returns: - True if the given op is registered, otherwise False. - """ - functions = self.get_op_functions( - namespace=namespace, op_name=op_name, overload=overload - ) - return functions is not None - - def _all_registered_ops(self) -> set[str]: - """Returns the set of all registered function names.""" - return { - op_name_class.qualified_name() for op_name_class in self._registry.keys() - } - - -@deprecated( - "torch.onnx.dynamo_export is deprecated since 2.7.0. Please use torch.onnx.export(..., dynamo=True) instead.", - category=None, -) -class ExportOptions: - """Options to influence the TorchDynamo ONNX exporter. - - .. deprecated:: 2.7 - Please use ``torch.onnx.export(..., dynamo=True)`` instead. - - Attributes: - dynamic_shapes: Shape information hint for input/output tensors. - When ``None``, the exporter determines the most compatible setting. - When ``True``, all input shapes are considered dynamic. - When ``False``, all input shapes are considered static. - fake_context: The fake context used for symbolic tracing. - onnx_registry: The ONNX registry used to register ATen operators to ONNX functions. - """ - - def __init__( - self, - *, - dynamic_shapes: bool | None = True, - fake_context: ONNXFakeContext | None = None, - onnx_registry: OnnxRegistry | None = None, - ): - self.dynamic_shapes = dynamic_shapes - self.fake_context = fake_context - self.onnx_registry = onnx_registry - - -@deprecated( - "torch.onnx.dynamo_export is deprecated since 2.7.0. Please use torch.onnx.export(..., dynamo=True) instead.", - category=None, -) -class ResolvedExportOptions(ExportOptions): - """Consolidates :class:`ExportOptions` with default values. - All unspecified options from :class:`ExportOptions` are assigned a default value. - This is an internal class and its API may be changed at any time without notice. - """ - - def __init__(self): - from torch.onnx._internal.fx import ( - dynamo_graph_extractor, - onnxfunction_dispatcher, - ) - - self.dynamic_shapes: bool = True - self.fx_tracer: dynamo_graph_extractor.DynamoExport = ( - dynamo_graph_extractor.DynamoExport() - ) - self.fake_context = None - self.onnx_registry: OnnxRegistry = OnnxRegistry() - self.decomposition_table = ( - decomposition_table.create_onnx_friendly_decomposition_table( # type: ignore[assignment] - self.onnx_registry - ) - ) - self.onnxfunction_dispatcher = onnxfunction_dispatcher.OnnxFunctionDispatcher( - self.onnx_registry, - ) - - @contextlib.contextmanager def enable_fake_mode(): """Enable fake mode for the duration of the context. @@ -346,101 +116,3 @@ def enable_fake_mode(): fake_context.state_dict_paths = tuple( patcher_context.paths, ) # type: ignore[assignment] - - -@deprecated( - "torch.onnx.dynamo_export is deprecated since 2.7.0. Please use torch.onnx.export(..., dynamo=True) instead.", -) -class ONNXRuntimeOptions: - """Options to influence the execution of the ONNX model through ONNX Runtime. - - .. deprecated:: 2.7 - Please use ``torch.onnx.export(..., dynamo=True)`` instead. - - Attributes: - session_options: ONNX Runtime session options. - execution_providers: ONNX Runtime execution providers to use during model execution. - execution_provider_options: ONNX Runtime execution provider options. - """ - - session_options: Sequence[onnxruntime.SessionOptions] | None = None - """ONNX Runtime session options.""" - - execution_providers: Sequence[str | tuple[str, dict[Any, Any]]] | None = None - """ONNX Runtime execution providers to use during model execution.""" - - execution_provider_options: Sequence[dict[Any, Any]] | None = None - """ONNX Runtime execution provider options.""" - - def __init__( - self, - *, - session_options: Sequence[onnxruntime.SessionOptions] | None = None, - execution_providers: Sequence[str | tuple[str, dict[Any, Any]]] | None = None, - execution_provider_options: Sequence[dict[Any, Any]] | None = None, - ): - self.session_options = session_options - self.execution_providers = execution_providers - self.execution_provider_options = execution_provider_options - - -class FXGraphExtractor(abc.ABC): - """Abstract interface for FX graph extractor engines. - This class isolates FX extraction logic from the rest of the export logic. - That allows a single ONNX exporter that can leverage different FX graphs.""" - - def __init__(self) -> None: - super().__init__() - - @abc.abstractmethod - def generate_fx( - self, - options: ResolvedExportOptions, - model: torch.nn.Module | Callable, - model_args: Sequence[Any], - model_kwargs: Mapping[str, Any], - ) -> torch.fx.GraphModule: - """Analyzes user ``model`` and generates a FX graph. - Args: - options: The export options. - model: The user model. - model_args: The model's positional input arguments. - model_kwargs: The model's keyword input arguments. - Returns: - The generated FX Graph. - """ - ... - - # TODO: Design the passes API - @abc.abstractmethod - def pre_export_passes( - self, - options: ResolvedExportOptions, - original_model: torch.nn.Module | Callable, - fx_module: torch.fx.GraphModule, - fx_module_args: Sequence[Any], - ): - """Applies pre-export passes to the FX graph. - - Pre-export passes are FX-to-FX graph transformations that make the graph - more palatable for the FX-to-ONNX conversion. - For example, it can be used to flatten model input/output, add explicit - casts to the graph, replace/decompose operators, functionalize the graph, etc. - """ - ... - - -def common_pre_export_passes( - options: ResolvedExportOptions, - original_model: torch.nn.Module | Callable, - fx_module: torch.fx.GraphModule, - fx_module_args: Sequence[Any], -): - # TODO: Import here to prevent circular dependency - from torch.onnx._internal.fx import passes - - # ONNX does not support concept of (implicit) type promotion. - # Insert type casts explicitly where needed. - module = passes.InsertTypePromotion(fx_module).run() - - return module diff --git a/torch/onnx/_internal/fx/dynamo_graph_extractor.py b/torch/onnx/_internal/fx/dynamo_graph_extractor.py deleted file mode 100644 index 73720ec39d560..0000000000000 --- a/torch/onnx/_internal/fx/dynamo_graph_extractor.py +++ /dev/null @@ -1,160 +0,0 @@ -# mypy: allow-untyped-defs -# NOTE: This file is referenced by name at -# /opt/pytorch/torch/_dynamo/eval_frame.py::DONT_WRAP_FILES. -# introduced by https://github.com/pytorch/pytorch/pull/98894. -# If this file is renamed, moved, etc please update the reference there! - -from __future__ import annotations - -import contextlib -import inspect -from typing import Any, Callable, TYPE_CHECKING - -import torch._dynamo -import torch.fx -from torch.onnx._internal import _exporter_legacy -from torch.utils import _pytree as pytree - - -if TYPE_CHECKING: - from collections.abc import Mapping, Sequence - - -class _PyTreeExtensionContext: - """Context manager to register PyTree extension.""" - - _extensions: dict[type, tuple[pytree.FlattenFunc, pytree.UnflattenFunc]] - - def __init__(self) -> None: - self._extensions = {} - # Register PyTree extension for HuggingFace model output. - self._register_huggingface_model_output_extension() - - def __enter__(self): - for class_type, (flatten_func, unflatten_func) in self._extensions.items(): - pytree._private_register_pytree_node( - class_type, - flatten_func, - unflatten_func, - ) - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - for class_type in self._extensions: - pytree.SUPPORTED_NODES.pop(class_type) - - def register_pytree_node( - self, - class_type: type, - flatten_func: pytree.FlattenFunc, - unflatten_func: pytree.UnflattenFunc, - ): - """Register PyTree extension for a custom python type. - - Args: - class_type: The custom python type. - flatten_func: The flatten function. - unflatten_func: The unflatten function. - - Raises: - AssertionError: If the custom python type is already registered. - """ - if class_type in pytree.SUPPORTED_NODES or class_type in self._extensions: - # PyTree node already registered. - # E.g., `huggingface/transformer` registers `ModelOutput` as PyTree node after - # https://github.com/huggingface/transformers/pull/25358. - return - self._extensions[class_type] = (flatten_func, unflatten_func) - - def _register_huggingface_model_output_extension(self): - try: - from transformers import modeling_outputs # type: ignore[import] - except ImportError: - return - - def model_output_flatten( - output: modeling_outputs.ModelOutput, - ) -> tuple[list[Any], pytree.Context]: - return list(output.values()), (type(output), list(output.keys())) - - def model_output_unflatten( - values: list[Any], context: pytree.Context - ) -> modeling_outputs.ModelOutput: - output_type, keys = context - return output_type(**dict(zip(keys, values))) - - # All 'ModelOutput' subclasses are defined under module 'modeling_outputs'. - named_model_output_classes = inspect.getmembers( - modeling_outputs, - lambda x: ( - inspect.isclass(x) - and issubclass(x, modeling_outputs.ModelOutput) - and x is not modeling_outputs.ModelOutput - ), - ) - - for _, class_type in named_model_output_classes: - self.register_pytree_node( - class_type, - model_output_flatten, - model_output_unflatten, # type: ignore[arg-type ] - ) - - -class DynamoExport(_exporter_legacy.FXGraphExtractor): - """Generates a FX GraphModule using torch.dynamo.export API - Args: - aten_graph: If True, exports a graph with ATen operators. - If False, exports a graph with Python operators. - """ - - def __init__( - self, - aten_graph: bool | None = None, - ): - super().__init__() - self.aten_graph = aten_graph or True - - def generate_fx( - self, - options: _exporter_legacy.ResolvedExportOptions, - model: torch.nn.Module | Callable, - model_args: Sequence[Any], - model_kwargs: Mapping[str, Any], - ) -> torch.fx.GraphModule: - # `dynamo.export` does not recognize custom user defined classes as output type. - # Apply wrapper to adapt the outputs back to `dynamo.export` compatible types, - # i.e. :class:`torch.Tensor`. - wrapped_model = model - - # Translate callable to FX graph. - # - fake_mode = ( - options.fake_context.fake_mode - if options.fake_context - else contextlib.nullcontext() - ) - fx_mode = "symbolic" if options.dynamic_shapes else "fake" - with fake_mode: # type: ignore[attr-defined] - graph_module, graph_guard = torch._dynamo.export( - wrapped_model, - tracing_mode=fx_mode, - )( - *model_args, - **model_kwargs, - ) - del graph_guard # Unused - torch._dynamo.reset() - - return self.pre_export_passes(options, model, graph_module, model_args) # type: ignore[return-value] - - def pre_export_passes( - self, - options: _exporter_legacy.ResolvedExportOptions, - original_model: torch.nn.Module | Callable, - fx_module: torch.fx.GraphModule, - fx_module_args: Sequence[Any], - ): - return _exporter_legacy.common_pre_export_passes( - options, original_model, fx_module, fx_module_args - ) diff --git a/torch/onnx/_internal/fx/onnxfunction_dispatcher.py b/torch/onnx/_internal/fx/onnxfunction_dispatcher.py index 516eb36368886..f90e7efd8ac98 100644 --- a/torch/onnx/_internal/fx/onnxfunction_dispatcher.py +++ b/torch/onnx/_internal/fx/onnxfunction_dispatcher.py @@ -25,9 +25,6 @@ graph_building as onnxscript_graph_building, ) - from torch.onnx._internal._exporter_legacy import OnnxRegistry - - logger = logging.getLogger(__name__) @@ -58,7 +55,7 @@ class OnnxFunctionDispatcher: def __init__( self, - onnx_registry: OnnxRegistry, + onnx_registry, ): """Initialize the ONNX Function dispatcher. From cc0faeb80fff17b3d170aa70041865aafb1790a9 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Mon, 14 Jul 2025 11:35:28 -0700 Subject: [PATCH 074/457] [dynamo][guards] Instruction count for guard eval for development work (#158214) Its turned off by default. Even the code is hidden before of the define preprocessing flag. It will be used only for development work. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158214 Approved by: https://github.com/StrongerXi ghstack dependencies: #158215 --- torch/csrc/dynamo/guards.cpp | 46 ++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp index 83fb0adbe6c9a..c98119d6adbd3 100644 --- a/torch/csrc/dynamo/guards.cpp +++ b/torch/csrc/dynamo/guards.cpp @@ -36,6 +36,45 @@ #include #include +// Uncomment next line to count instructions for guard eval. +// #define GUARD_INSTRUCTION_COUNT +#ifdef GUARD_INSTRUCTION_COUNT +#include +#include +#include +#include +#include +#include + +int open_counter() { + perf_event_attr attr{}; + attr.type = PERF_TYPE_HARDWARE; + attr.size = sizeof(attr); + attr.config = PERF_COUNT_HW_INSTRUCTIONS; // retired instructions + attr.disabled = 1; // start stopped + attr.exclude_kernel = 1; // user-space only + attr.exclude_hv = 1; + + return syscall(__NR_perf_event_open, &attr, 0, -1, -1, 0); +} + +uint64_t count_instructions(const std::function& fn) { + int fd = open_counter(); + if (fd == -1) + throw std::runtime_error("perf_event_open failed"); + + ioctl(fd, PERF_EVENT_IOC_RESET, 0); + ioctl(fd, PERF_EVENT_IOC_ENABLE, 0); + fn(); // run the code you care about + ioctl(fd, PERF_EVENT_IOC_DISABLE, 0); + + uint64_t count; + read(fd, &count, sizeof(count)); + close(fd); + return count; +} +#endif + // Certain CPython data structures are defined in `.c` files in earlier Python // versions, e.g., for TupleIteratorGetItemAccessor, we need a fast way to // retrieve the underlying tuple and access the item. Before Python 3.12 @@ -5547,6 +5586,13 @@ bool run_root_guard_manager(void* root, FrameLocalsMapping* f_locals) { if (root == nullptr) { return false; } + +#ifdef GUARD_INSTRUCTION_COUNT + auto n = count_instructions( + [&] { ((RootGuardManager*)root)->check_nopybind(f_locals); }); + std::cout << "#instructions in guard eval = " << n << std::endl << std::flush; +#endif + return ((RootGuardManager*)root)->check_nopybind(f_locals); } From e4c17d5e1ccd0e730caef484af291243bc1d9cde Mon Sep 17 00:00:00 2001 From: Ti-Tai Wang Date: Tue, 15 Jul 2025 17:32:59 +0000 Subject: [PATCH 075/457] [ONNX] Remove fx_onnx_interpreter.py (#158282) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158282 Approved by: https://github.com/Skylion007, https://github.com/justinchuby ghstack dependencies: #158258, #158262 --- .../onnx/_internal/fx/fx_onnx_interpreter.py | 718 ------------------ 1 file changed, 718 deletions(-) delete mode 100644 torch/onnx/_internal/fx/fx_onnx_interpreter.py diff --git a/torch/onnx/_internal/fx/fx_onnx_interpreter.py b/torch/onnx/_internal/fx/fx_onnx_interpreter.py deleted file mode 100644 index 424f2d171b978..0000000000000 --- a/torch/onnx/_internal/fx/fx_onnx_interpreter.py +++ /dev/null @@ -1,718 +0,0 @@ -# mypy: allow-untyped-defs -from __future__ import annotations - -import inspect -import operator -from typing import Callable, TYPE_CHECKING - -import onnxscript -from onnxscript.function_libs.torch_lib import ( - graph_building as onnxscript_graph_building, -) - -import torch -import torch.fx -from torch.onnx import _type_utils as jit_type_utils -from torch.onnx._internal.fx import ( - _pass, - onnxfunction_dispatcher, - type_utils as fx_type_utils, -) -from torch.utils import _pytree - - -if TYPE_CHECKING: - from collections.abc import Sequence - - -def _fx_node_to_onnx_message_formatter( - fn: Callable, - self, - node: torch.fx.Node, - *args, - **kwargs, -) -> str: - return f"FX Node: {node.op}:{node.target}[name={node.name}]. " - - -def _fx_graph_to_onnx_message_formatter( - fn: Callable, - self, - fx_graph_module: torch.fx.GraphModule, - *args, - **kwargs, -) -> str: - return f"FX Graph: {fx_graph_module._get_name()}. " - - -def _retrieve_or_adapt_input_to_graph_set( - fx_node_arg: fx_type_utils.Argument, - fx_name_to_onnxscript_value: dict[ - str, - onnxscript_graph_building.TorchScriptTensor - | tuple[onnxscript_graph_building.TorchScriptTensor, ...], - ], - tracer: onnxscript_graph_building.TorchScriptTracingEvaluator, -): - """Map FX value to TorchScript value. - - When creating TorchScript graph from FX graph, we need a mapping from FX variable - to TorchScript variable. This function maps FX variable, fx_node_arg, to torch.jit.Value. - """ - from onnxscript import opset18 as op - - onnx_tensor = fx_node_arg - if isinstance(onnx_tensor, torch.fx.Node): - # 1. fx_node_arg is a torch.fx.Node, which means - # fx_node_arg stands for the output of that torch.fx.Node. - # 2. fx_node_arg (variable in torch.fx.Graph) is be mapped to - # torch.jit.Value, fx_name_to_onnxscript_value[fx_node_arg.name], - # in TorchScript graph. - return fx_name_to_onnxscript_value[onnx_tensor.name] - elif isinstance(onnx_tensor, (tuple, list)) and any( - isinstance(node, torch.fx.Node) - and fx_type_utils.is_torch_symbolic_type(node.meta.get("val")) - for node in onnx_tensor - ): - # This intends to handle dynamic axes. for example, if the input size of op.Expand - # is dynamic, each dimension would be variable (i.e., sym variable in Pytorch - # FX graph. Note that sym variable is mapped to tensor in ONNX Script world) - # calculated by other operators. - sequence_mixed_elements: list[ - onnxscript_graph_building.TorchScriptTensor - | tuple[onnxscript_graph_building.TorchScriptTensor, ...] - | list[int] - ] = [] - # onnx_tensor contains a list of scalars which could be one of - # - tensor with empty shape, - # - tensor with tensor with shape (1,), - # - torch.SymInt, - # - int - # - ... - # They should all be promoted to tensor with shape (1,) - # in order to call ONNX's Concat. - for tensor in onnx_tensor: - # Prepare `tensor` as input of ONNX's Concat. - - if isinstance( - tensor, torch.fx.Node - ) and fx_type_utils.is_torch_symbolic_type(tensor.meta.get("val")): - # In this case, tensor is a torch.SymInt from Dynamo's perspective. - # It might be mapped to tensor with shape () or (1,) in ONNX. - element_value = fx_name_to_onnxscript_value[tensor.name] - if isinstance( - element_value, onnxscript_graph_building.TorchScriptTensor - ): - # All elements sequence_mixed_elements will be send to onnx's Concat - # as inputs. Therefore, they are required to have the same rank. - # Since tensors with rank=0 (i.e., scalar) cannot be concated, all - # scalars are promoted to tensors with shape (1,). - with onnxscript.evaluator.default_as(tracer): - element_value = op.Reshape( - element_value, # type: ignore[arg-type, type-var] - [1], # type: ignore[arg-type, type-var] - ) - sequence_mixed_elements.append(element_value) - elif isinstance(tensor, int): - # NOTE: op.Concat doesn't support scalar, so we need to wrap it with - # dim, and onnx-script will promote it to tensor(int64) - sequence_mixed_elements.append([tensor]) - else: - raise RuntimeError( - f"Unsupported type in sequence_mixed_elements: {type(tensor)}" - ) - # Concat all the elements in the sequence. - # shapes are mapped to tensors in ONNX graph (TorchScriptGraph), - # so list of sym_ints is concatenated to a tensor before calling ONNX op. - - # For example: - # inputs: [[2], [4], fx.Node(SymIntA), [1], fx.Node(SymIntB)] - # outputs: op.Concat([op.Constant(2), op.Constant(4), TorchScriptTensor(A), op.Constant(1), TorchScriptTensor(B)]) - - # onnx-script auto wraps python number with op.Constants, - # so we don't need to specifically process them. - with onnxscript.evaluator.default_as(tracer): - output = op.Concat(*sequence_mixed_elements, axis=0) # type: ignore[type-var] - output.dtype = torch.int64 # type: ignore[union-attr] - output.shape = [len(sequence_mixed_elements)] # type: ignore[union-attr] - return output - elif isinstance(onnx_tensor, (tuple, list)) and all( - isinstance(node, torch.fx.Node) or node is None for node in onnx_tensor - ): - sequence_elements: list[ - onnxscript_graph_building.TorchScriptTensor - | None - | tuple[onnxscript_graph_building.TorchScriptTensor, ...] - ] = [] - for tensor in onnx_tensor: - sequence_elements.append( - fx_name_to_onnxscript_value[tensor.name] if tensor is not None else None # type: ignore[index, union-attr] - ) - return sequence_elements - if isinstance(onnx_tensor, torch.dtype): - onnx_tensor = int( # type: ignore[call-overload] - jit_type_utils.JitScalarType.from_dtype(onnx_tensor).onnx_type() - ) - # NOTE: if device is specified in kwargs (not consumed), it's free to ignored. But - # if it's in args, we need to set it to string for dispatcher to match schema. - if isinstance(onnx_tensor, torch.device): - # torch.device is not supported by onnxscript (no op). We turn it into - # a string. - return str(onnx_tensor) - # all other cases, we do nothing. - return onnx_tensor - - -def filter_incompatible_and_dtype_convert_kwargs(kwargs): - """Filter out kwargs that are not supported by onnxscript.""" - filtered = {} - for key, value in kwargs.items(): - if key in { - "layout", - "device", - "requires_grad", - "pin_memory", - "memory_format", - "implicit", - }: - continue - if key == "dtype": - if value is None: - # We omit if dtype is not provided, because onnxscript handles the - # default case. - continue - else: - value = int(jit_type_utils.JitScalarType.from_dtype(value).onnx_type()) # type: ignore[call-overload] - filtered[key] = value - return filtered - - -def _fill_tensor_shape_type( - onnxscript_values: onnxscript_graph_building.TorchScriptTensor - | tuple[onnxscript_graph_building.TorchScriptTensor, ...], - name: str, - expected_values: fx_type_utils.META_VALUE_TYPE - | list[fx_type_utils.META_VALUE_TYPE] - | tuple[fx_type_utils.META_VALUE_TYPE | None, ...], -): - """Fill the meta information of onnxscript_values with that from the fx FakeTensor.""" - - if isinstance(expected_values, (list, tuple)) and not isinstance( - onnxscript_values, (list, tuple) - ): - # ex: aten::split - in onnx_dtype: seq(tensor) - # onnxscript_values is a single tensor, but expected_values is a list of tensors. - return - - flat_onnxscript_values, _ = _pytree.tree_flatten(onnxscript_values) - flat_expected_values, _ = _pytree.tree_flatten(expected_values) - for i, (onnxscript_value, expected_value) in enumerate( - zip(flat_onnxscript_values, flat_expected_values) - ): - if expected_value is None: - # There is no shape/type from None. - # NOTE: according to https://github.com/pytorch/pytorch/blob/main/torch/_meta_registrations.py, - # None could be a valid value for return type, so we need to handle it. - # e.g. the function: meta__scaled_dot_product_flash() in cpu mode. - continue - elif fx_type_utils.is_torch_symbolic_type(expected_value): - # aten::sym_size output is a int, not a tensor, which stands - # for the size of one dim. We treat it as 1-D tensor. - onnxscript_value.dtype = fx_type_utils.from_sym_value_to_torch_dtype( - expected_value - ) - onnxscript_value.shape = torch.Size([1]) - elif isinstance(expected_value, (int, float, bool)): - onnxscript_value.dtype = fx_type_utils.from_scalar_type_to_torch_dtype( - type(expected_value) - ) - onnxscript_value.shape = torch.Size([]) - elif isinstance(expected_value, complex): - # From complex scalar to real representation - onnxscript_value_to_torch_dtype = ( - fx_type_utils.from_scalar_type_to_torch_dtype(type(expected_value)) - ) - onnxscript_value.dtype = ( - fx_type_utils.from_complex_to_float(onnxscript_value_to_torch_dtype) - if onnxscript_value_to_torch_dtype is not None - else None - ) - onnxscript_value.shape = torch.Size([2]) - elif fx_type_utils.is_torch_complex_dtype(expected_value.dtype): - # Like torch.view_as_real, we flatten complex tensors to real tensors with - # additional last dimension of 2 - onnxscript_value.shape = torch.Size((*expected_value.size(), 2)) - # complex64 -> float32, complex128 -> float64, etc. - onnxscript_value.dtype = fx_type_utils.from_complex_to_float( - expected_value.dtype - ) - # Dispatcher needs to know the value is complex - onnxscript_value.is_complex = True - else: - # We set node output sizes to be dynamic to continue the model conversion, - # and inputs are also set to be dynamic in add_input(). - onnxscript_value.shape = expected_value.size() - onnxscript_value.dtype = expected_value.dtype - - # naming - if i > 0: - onnxscript_value.name = f"{name}_{i}" - else: - onnxscript_value.name = name - - -def _fill_in_default_kwargs( - node: torch.fx.Node, -) -> tuple[list[fx_type_utils.Argument], dict[str, fx_type_utils.Argument]]: - """Find and Fill in the not provided kwargs with default values.""" - - # TODO: aten::sym_size has overload, but fx graph is using - # overloadpacket for some reasons. - # https://github.com/pytorch/pytorch/issues/97201 - # We manually assigned overload for aten::sym_size. - if hasattr(node.target, "_schema"): - node_schema = node.target._schema # type: ignore[union-attr] - else: - node_schema = torch.ops.aten.sym_size.int._schema # type: ignore[union-attr] - - # This function assumes the order of arguments in FX op is the - # same as the order of arguments in TorchScript op. - complete_args: list[fx_type_utils.Argument] = [] - complete_kwargs: dict[str, fx_type_utils.Argument] = {} - - if inspect.isbuiltin(node.target): - complete_args = list(node.args) - else: - for i, expected_arg in enumerate(node_schema.arguments): - if i < len(node.args): - complete_args.append(node.args[i]) - elif expected_arg.name in node.kwargs: - complete_kwargs[expected_arg.name] = node.kwargs[expected_arg.name] - else: - # Get default from schema. - complete_kwargs[expected_arg.name] = expected_arg.default_value - - return complete_args, complete_kwargs - - -def _wrap_fx_args_as_onnxscript_args( - complete_args: list[fx_type_utils.Argument], - complete_kwargs: dict[str, fx_type_utils.Argument], - fx_name_to_onnxscript_value: dict[ - str, - onnxscript_graph_building.TorchScriptTensor - | tuple[onnxscript_graph_building.TorchScriptTensor, ...], - ], - tracer: onnxscript_graph_building.TorchScriptTracingEvaluator, -) -> tuple[ - Sequence[ - onnxscript_graph_building.TorchScriptTensor - | str - | int - | float - | bool - | list - | complex - | None - ], - dict[str, fx_type_utils.Argument], -]: - """Map all FX arguments of a node to arguments in TorchScript graph.""" - - onnxscript_args = tuple( - _retrieve_or_adapt_input_to_graph_set(arg, fx_name_to_onnxscript_value, tracer) - for arg in complete_args - ) - onnxscript_kwargs = filter_incompatible_and_dtype_convert_kwargs(complete_kwargs) - - return onnxscript_args, onnxscript_kwargs - - -class FxOnnxInterpreter: - """Stateless class to process FX graph Nodes and translate them into their ONNX counterparts. - - All FX nodes described by [FX Graph](https://pytorch.org/docs/stable/fx.html#torch.fx.Graph) are supported. - Similarly to [FX Interpreter pattern](https://pytorch.org/docs/stable/fx.html#torch.fx.Interpreter), each FX node - must be implemented on its own method in this class. - - Each operator's implementation returns either an `onnxscript.OnnxFunction` or - `onnxscript.TracedOnnxFunction` instance based on the dispatch algorithm. They can - also raise RuntimeError: If there are no overloaded functions available for the given FX node. - """ - - def run_node( - self, - node, - fx_graph_module: torch.fx.GraphModule, - onnxfunction_dispatcher: onnxfunction_dispatcher.OnnxFunctionDispatcher, - onnxscript_graph: onnxscript_graph_building.TorchScriptGraph, - onnxscript_tracer: onnxscript_graph_building.TorchScriptTracingEvaluator, - fx_name_to_onnxscript_value: dict[ - str, - onnxscript_graph_building.TorchScriptTensor - | tuple[onnxscript_graph_building.TorchScriptTensor, ...], - ], - ): - """Execute a single FX node to produce its ONNX counterpart. - - Args: - node: The FX node to be translated. - fx_graph_module: The FX graph module containing the node. - onnxfunction_dispatcher: The dispatcher to find the best matched ONNX op. - onnxscript_graph: The ONNX graph to be populated. - onnxscript_tracer: The tracer to trace the ONNX graph. - fx_name_to_onnxscript_value: The mapping from FX node name to ONNX Script value. - - Raises: - RuntimeError: When a node.op is not supported. - """ - if node.op == "placeholder": - self.placeholder(node, onnxscript_graph, fx_name_to_onnxscript_value) - elif node.op == "get_attr": - self.get_attr( - node, - onnxscript_graph, - fx_name_to_onnxscript_value, - fx_graph_module, - ) - elif node.op == "call_function": - self.call_function( - node, - onnxscript_tracer, - fx_name_to_onnxscript_value, - onnxfunction_dispatcher, - fx_graph_module, - ) - elif node.op == "call_method": - self.call_method(node) - elif node.op == "call_module": - self.call_module( - node, - onnxscript_graph, - fx_name_to_onnxscript_value, - onnxscript_tracer, - fx_graph_module, - onnxfunction_dispatcher, - ) - elif node.op == "output": - self.output(node, onnxscript_graph, fx_name_to_onnxscript_value) - else: - raise RuntimeError(f"Found node type not defined in torch.fx: {node.op}") - - def run( - self, - fx_graph_module: torch.fx.GraphModule, - onnxfunction_dispatcher: onnxfunction_dispatcher.OnnxFunctionDispatcher, - parent_onnxscript_graph: onnxscript_graph_building.TorchScriptGraph - | None = None, - ) -> onnxscript_graph_building.TorchScriptGraph: - """Analyze all FX nodes and trigger their ONNX translation. - - Args: - fx_graph_module: FX graph module to be translated. - onnxfunction_dispatcher: ONNX function dispatcher. - parent_onnxscript_graph: The parent TorchScript graph. Must be provided if - `fx_graph_module` is a submodule. If not provided, - `fx_graph_module` is assumed to be the root module. - """ - if parent_onnxscript_graph is not None: - # If parent_onnxscript_graph is provided, we assume fx_graph_module is a - # submodule representing a forward call of an nn.Module. - # Compose package and version where the nn.Module is defined as domain name - # for the local function. - - onnx_meta: _pass.GraphModuleOnnxMeta | None = fx_graph_module.meta.get( - "onnx" - ) - if onnx_meta is None: - raise RuntimeError( - f"ONNX meta is not found in submodule {fx_graph_module._get_name()}. " - f"Only submodules produced by `Modularize` pass is supported in ONNX export." - ) - - onnx_domain = onnx_meta.package_info.to_onnx_domain_string() - else: - # Leave as default domain name for the root module. - onnx_domain = None - - onnxscript_graph = onnxscript_graph_building.TorchScriptGraph( - parent_onnxscript_graph, domain_name=onnx_domain - ) - onnxscript_tracer = onnxscript_graph_building.TorchScriptTracingEvaluator( - onnxscript_graph - ) - # In the following loop, a TorchScript graph is created to - # represent the input FX graph with ONNX symbols (e.g., onnx::add). - # To connect the values to nodes in the TorchScript graph, we maintain - # fx_name_to_onnxscript_value. Basically, we want to translate - # fx_tensor_x (type: torch.fx.Node) -> fx_node_1 -> fx_tensor_y (type: torch.fx.Node) - # to - # fx_name_to_onnxscript_value[fx_tensor_x.name] -> onnx_node_1 -> fx_name_to_onnxscript_value[fx_tensor_y.name] - fx_name_to_onnxscript_value: dict[ - str, - onnxscript_graph_building.TorchScriptTensor - | tuple[onnxscript_graph_building.TorchScriptTensor, ...], - ] = {} - - # TODO: Fix FakeTensorMode limitation asap - # We want to pass list of ints and floats to TorchScript graph correctly - # in _export_fx_to_ts, so we must disable FakeTensorMode. Otherwise, graph may - # receive FakeTensor and results runtime error. In addition, TorchScript-based - # ONNX exporter used in _ts_graph_to_onnx_model_in_protobuf is not compatible - # with FakeTensorMode. - with torch.utils._mode_utils.no_dispatch(): - for node in fx_graph_module.graph.nodes: - self.run_node( - node, - fx_graph_module, - onnxfunction_dispatcher, - onnxscript_graph, - onnxscript_tracer, - fx_name_to_onnxscript_value, - ) - - return onnxscript_graph - - def placeholder( - self, - node: torch.fx.Node, - onnxscript_graph: onnxscript_graph_building.TorchScriptGraph, - fx_name_to_onnxscript_value: dict[ - str, - onnxscript_graph_building.TorchScriptTensor - | tuple[onnxscript_graph_building.TorchScriptTensor, ...], - ], - ): - # Input of graph. - # The node.meta["val"] is generated by FakeTensorProp. - # NOTE: add_input() intends to create nodes with shape/type - fake_tensor = node.meta.get("val", None) - # NOTE: During the tracing, when inputs are constants, they are represented - # by nodes with node.meta['val'] being None (nn.Module to dynamo_export) - # or nodes with node.meta['val'] being a builtin value (ExportedProgram to dynamo_export). - # Nonethless, the nodes are not consumed by others, so we don't need to - # create a TorchScriptTensor for them. - if fake_tensor is None or isinstance(fake_tensor, (int, float, bool, str)): - output = onnxscript_graph.add_input( - input_name=None, - ) - elif isinstance(fake_tensor, torch.Tensor): - # NOTE: ONNX doesn't support tensor of complex64/complex128, so we - # convert them to float32/float64 with real representation. - if fx_type_utils.is_torch_complex_dtype(fake_tensor.dtype): - fake_tensor = torch.view_as_real(fake_tensor.resolve_conj()) - output = onnxscript_graph.add_input( - input_name=node.name, - shape=fake_tensor.shape, - dtype=fake_tensor.dtype, - ) - - elif fx_type_utils.is_torch_symbolic_type(fake_tensor): - output = onnxscript_graph.add_input( - input_name=node.name, - shape=torch.Size([]), - dtype=fx_type_utils.from_sym_value_to_torch_dtype(fake_tensor), - ) - else: - raise RuntimeError( - f"Unsupported type(node.meta['val']) for placeholder: {type(fake_tensor)}" - ) - assert output is not None, ( - f"Node creates None with target={node.target} and name={node.name}" - ) - - assert isinstance(output, onnxscript_graph_building.TorchScriptTensor) - assert isinstance(output, onnxscript.tensor.Tensor) - - fx_name_to_onnxscript_value[node.name] = output - - def call_function( - self, - node: torch.fx.Node, - onnxscript_tracer: onnxscript_graph_building.TorchScriptTracingEvaluator, - fx_name_to_onnxscript_value: dict[ - str, - onnxscript_graph_building.TorchScriptTensor - | tuple[onnxscript_graph_building.TorchScriptTensor, ...], - ], - onnxfunction_dispatcher: onnxfunction_dispatcher.OnnxFunctionDispatcher, - fx_graph_module: torch.fx.GraphModule, - ): - # aten ops and other stateless functions. - if node.target == operator.getitem and isinstance( - fx_name_to_onnxscript_value[node.args[0].name], # type: ignore[union-attr,index] - tuple, - ): - onnx_tensor_tuple = fx_name_to_onnxscript_value[node.args[0].name] # type: ignore[union-attr,index] - index = node.args[1] - value = onnx_tensor_tuple[index] # type: ignore[index] - assert value is not None, ( - f"Node creates None with target={node.target} and name={node.name}" - ) - assert isinstance( - value, (onnxscript_graph_building.TorchScriptTensor, tuple) - ), type(value) - - fx_name_to_onnxscript_value[node.name] = value - return - - # Map FX inputs to ONNX inputs and fill optional inputs with default values. - # torch_args and torch_kwargs are for op-level validation - fx_args, fx_kwargs = _fill_in_default_kwargs(node) - - onnx_args, onnx_kwargs = _wrap_fx_args_as_onnxscript_args( - fx_args, - fx_kwargs, - fx_name_to_onnxscript_value, - onnxscript_tracer, - ) - # Dispatch to ONNX op through OpShema. The input argument dtypes are compared to - # function signature in OpSchema, and find the best matched overload. - symbolic_fn = onnxfunction_dispatcher.dispatch( - node=node, - onnx_args=onnx_args, # type: ignore[arg-type] - onnx_kwargs=onnx_kwargs, - ) - with onnxscript.evaluator.default_as(onnxscript_tracer): - output: ( - onnxscript_graph_building.TorchScriptTensor - | tuple[onnxscript_graph_building.TorchScriptTensor, ...] - ) = symbolic_fn(*onnx_args, **onnx_kwargs) - assert output is not None, ( - f"Node creates None with target={node.target}, name={node.name}, args={onnx_args}, kwargs={onnx_kwargs}" - ) - # Assign type and shape from fx graph. - _fill_tensor_shape_type(output, node.name, node.meta["val"]) - # One fx node could produce multiple outputs (e.g., tuple of tensors); in - # that case, v is a tuple of TorchScriptTensors. - assert isinstance( - output, (onnxscript_graph_building.TorchScriptTensor, tuple) - ), type(output) - fx_name_to_onnxscript_value[node.name] = output - - def output( - self, - node: torch.fx.Node, - onnxscript_graph: onnxscript_graph_building.TorchScriptGraph, - fx_name_to_onnxscript_value: dict[ - str, - onnxscript_graph_building.TorchScriptTensor - | tuple[onnxscript_graph_building.TorchScriptTensor, ...], - ], - ): - if isinstance(node.args[0], torch.fx.Node): - onnx_tensor_or_tensor_tuple = fx_name_to_onnxscript_value[node.args[0].name] - onnxscript_graph.register_outputs(onnx_tensor_or_tensor_tuple) - else: - # ONNX can't represent collection types (e.g., dictionary, tuple of tuple of - # tensor, etc), we flatten the collection and register each element as output. - flat_args, _ = _pytree.tree_flatten(node.args[0]) - for arg in flat_args: - assert isinstance(arg, torch.fx.Node), ( - f"arg must be a torch.fx.Node, not {type(arg)}" - ) - onnx_tensor_or_tensor_tuple = fx_name_to_onnxscript_value[arg.name] - onnxscript_graph.register_outputs(onnx_tensor_or_tensor_tuple) - - def call_method(self, node: torch.fx.Node): - # TODO(wechi): Support call_method. - raise RuntimeError("call_method is not supported yet.") - - def call_module( - self, - node: torch.fx.Node, - parent_onnxscript_graph: onnxscript_graph_building.TorchScriptGraph, - fx_name_to_onnxscript_value: dict[ - str, - onnxscript_graph_building.TorchScriptTensor - | tuple[onnxscript_graph_building.TorchScriptTensor, ...], - ], - tracer: onnxscript_graph_building.TorchScriptTracingEvaluator, - root_fx_graph_module: torch.fx.GraphModule, - onnxfunction_dispatcher: onnxfunction_dispatcher.OnnxFunctionDispatcher, - ) -> None: - """Export a fx.GraphModule submodule to ONNXScript graph. - - The export process specifically targets `call_module` nodes that are created by - the exporter's `Modularize` pass. Each `call_module` node has an associated fx.GraphModule - by `node.target` underneath the root fx.GraphModule. These `call_module` nodes are exported as ONNX - function nodes. The related `sub_module` is then exported as an ONNX model local function, - which is represented by another `TorchScriptGraph`. This `TorchScriptGraph` sets the current - `onnxscript_graph` as its parent. - - Args: - node: The call_module node in the FX graph that represents the submodule call. - parent_onnxscript_graph: The parent ONNXScript graph to which the ONNX function and - function node belong. - fx_name_to_onnxscript_value: The mapping from FX node name to ONNXScript value. - tracer: The tracer used to trace the ONNXScript graph. - root_fx_graph_module: The root FX module. - onnxfunction_dispatcher: The dispatcher. - """ - assert isinstance(node.target, str), ( - f"node.target must be a str, not {type(node.target)} for node {node}." - ) - - sub_module = root_fx_graph_module.get_submodule(node.target) - - assert isinstance(sub_module, torch.fx.GraphModule), ( - f"sub_module must be a torch.fx.GraphModule, not {type(sub_module)} for node {node}." - ) - - sub_onnxscript_graph = self.run( - sub_module, onnxfunction_dispatcher, parent_onnxscript_graph - ) - - onnx_args, _ = _wrap_fx_args_as_onnxscript_args( - list(node.args), {}, fx_name_to_onnxscript_value, tracer - ) - - # TODO: We may want to consider other naming styles. The goal is to be stable and - # unique such that it can be easily identified in case of kernel substitution. - # Example for current style is combination of qualified module class name and - # module attribute name: `torch_nn_modules_conv_Conv2d_conv1`. - # Other naming styles such as qualified module class name made unique can also - # be considered. - unique_module_name = f"{sub_module._get_name()}_{node.target}" - - outputs: ( - onnxscript_graph_building.TorchScriptTensor - | tuple[onnxscript_graph_building.TorchScriptTensor, ...] - ) = parent_onnxscript_graph.add_module_call( # type: ignore[assignment] - unique_module_name, sub_onnxscript_graph, onnx_args - ) - - assert isinstance( - outputs, (onnxscript_graph_building.TorchScriptTensor, tuple) - ), f"Unexpected outputs type {type(outputs)} for node {node}." - - _fill_tensor_shape_type(outputs, node.name, node.meta["val"]) - fx_name_to_onnxscript_value[node.name] = outputs - - # Skip op_level_validation for call_module. Subgraph nodes are validated individually. - - def get_attr( - self, - node: torch.fx.Node, - onnxscript_graph: onnxscript_graph_building.TorchScriptGraph, - fx_name_to_onnxscript_value: dict[ - str, - onnxscript_graph_building.TorchScriptTensor - | tuple[onnxscript_graph_building.TorchScriptTensor, ...], - ], - fx_graph_module: torch.fx.GraphModule, - ): - assert isinstance(node.target, str), f"node.target {node.target} is not a str." - attr_tensor = getattr(fx_graph_module, node.target) - assert isinstance(attr_tensor, torch.Tensor), f"{attr_tensor} is not a tensor." - - # Parameter/buffer name cannot contain "." - # Revert from "/" to restore namespace formatting. - input_ = onnxscript_graph.add_initializer( - name=node.target.replace("/", "."), - value=attr_tensor, - ) - - assert isinstance(input_, onnxscript_graph_building.TorchScriptTensor) - assert isinstance(input_, onnxscript.tensor.Tensor) - fx_name_to_onnxscript_value[node.name] = input_ From 0640cfa38c1426a41ab4a0b3e3dab7c730cdc2ad Mon Sep 17 00:00:00 2001 From: Yiming Zhou Date: Tue, 15 Jul 2025 20:57:23 +0000 Subject: [PATCH 076/457] [2/n] Remove references to TorchScript in PyTorch docs (#158306) Summary: Removed jit_language_reference.md Test Plan: CI Rollback Plan: Differential Revision: D78308133 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158306 Approved by: https://github.com/svekars, https://github.com/zhxchen17 --- docs/source/jit_language_reference.md | 922 +------------------------- torch/jit/_script.py | 2 +- torch/jit/supported_ops.py | 6 +- 3 files changed, 7 insertions(+), 923 deletions(-) diff --git a/docs/source/jit_language_reference.md b/docs/source/jit_language_reference.md index 9737309482080..f2b31768e2d58 100644 --- a/docs/source/jit_language_reference.md +++ b/docs/source/jit_language_reference.md @@ -30,923 +30,7 @@ # TorchScript Language Reference -TorchScript is a statically typed subset of Python that can either be written directly (using -the {func}`@torch.jit.script ` decorator) or generated automatically from Python code via -tracing. When using tracing, code is automatically converted into this subset of -Python by recording only the actual operators on tensors and simply executing and -discarding the other surrounding Python code. - -When writing TorchScript directly using `@torch.jit.script` decorator, the programmer must -only use the subset of Python supported in TorchScript. This section documents -what is supported in TorchScript as if it were a language reference for a stand -alone language. Any features of Python not mentioned in this reference are not -part of TorchScript. See `Builtin Functions` for a complete reference of available -PyTorch tensor methods, modules, and functions. - -As a subset of Python, any valid TorchScript function is also a valid Python -function. This makes it possible to `disable TorchScript` and debug the -function using standard Python tools like `pdb`. The reverse is not true: there -are many valid Python programs that are not valid TorchScript programs. -Instead, TorchScript focuses specifically on the features of Python that are -needed to represent neural network models in PyTorch. - -(types)= - -(supported-type)= - -## Types - -The largest difference between TorchScript and the full Python language is that -TorchScript only supports a small set of types that are needed to express neural -net models. In particular, TorchScript supports: - -```{eval-rst} -.. csv-table:: - :header: "Type", "Description" - - "``Tensor``", "A PyTorch tensor of any dtype, dimension, or backend" - "``Tuple[T0, T1, ..., TN]``", "A tuple containing subtypes ``T0``, ``T1``, etc. (e.g. ``Tuple[Tensor, Tensor]``)" - "``bool``", "A boolean value" - "``int``", "A scalar integer" - "``float``", "A scalar floating point number" - "``str``", "A string" - "``List[T]``", "A list of which all members are type ``T``" - "``Optional[T]``", "A value which is either None or type ``T``" - "``Dict[K, V]``", "A dict with key type ``K`` and value type ``V``. Only ``str``, ``int``, and ``float`` are allowed as key types." - "``T``", "A {ref}`TorchScript Class`" - "``E``", "A {ref}`TorchScript Enum`" - "``NamedTuple[T0, T1, ...]``", "A :func:`collections.namedtuple ` tuple type" - "``Union[T0, T1, ...]``", "One of the subtypes ``T0``, ``T1``, etc." -``` - -Unlike Python, each variable in TorchScript function must have a single static type. -This makes it easier to optimize TorchScript functions. - -Example (a type mismatch) - -```{eval-rst} -.. testcode:: - - import torch - - @torch.jit.script - def an_error(x): - if x: - r = torch.rand(1) - else: - r = 4 - return r - -``` - -```{eval-rst} -.. testoutput:: - - Traceback (most recent call last): - ... - RuntimeError: ... - - Type mismatch: r is set to type Tensor in the true branch and type int in the false branch: - @torch.jit.script - def an_error(x): - if x: - ~~~~~ - r = torch.rand(1) - ~~~~~~~~~~~~~~~~~ - else: - ~~~~~ - r = 4 - ~~~~~ <--- HERE - return r - and was used here: - else: - r = 4 - return r - ~ <--- HERE... -``` - -### Unsupported Typing Constructs - -TorchScript does not support all features and types of the {mod}`typing` module. Some of these -are more fundamental things that are unlikely to be added in the future while others -may be added if there is enough user demand to make it a priority. - -These types and features from the {mod}`typing` module are unavailable in TorchScript. - -```{eval-rst} -.. csv-table:: - :header: "Item", "Description" - - ":any:`typing.Any`", ":any:`typing.Any` is currently in development but not yet released" - ":any:`typing.NoReturn`", "Not implemented" - ":any:`typing.Sequence`", "Not implemented" - ":any:`typing.Callable`", "Not implemented" - ":any:`typing.Literal`", "Not implemented" - ":any:`typing.ClassVar`", "Not implemented" - ":any:`typing.Final`", "This is supported for :any:`module attributes ` class attribute annotations but not for functions" - ":any:`typing.AnyStr`", "TorchScript does not support :any:`bytes` so this type is not used" - ":any:`typing.overload`", ":any:`typing.overload` is currently in development but not yet released" - "Type aliases", "Not implemented" - "Nominal vs structural subtyping", "Nominal typing is in development, but structural typing is not" - "NewType", "Unlikely to be implemented" - "Generics", "Unlikely to be implemented" -``` - -Any other functionality from the {any}`typing` module not explicitly listed in this documentation is unsupported. - -### Default Types - -By default, all parameters to a TorchScript function are assumed to be Tensor. -To specify that an argument to a TorchScript function is another type, it is possible to use -MyPy-style type annotations using the types listed above. - -```{eval-rst} -.. testcode:: - - import torch - - @torch.jit.script - def foo(x, tup): - # type: (int, Tuple[Tensor, Tensor]) -> Tensor - t0, t1 = tup - return t0 + t1 + x - - print(foo(3, (torch.rand(3), torch.rand(3)))) -``` - -```{eval-rst} -.. testoutput:: - :hide: - - ... -``` - -:::{note} -It is also possible to annotate types with Python 3 type hints from the -`typing` module. - -```{eval-rst} -.. testcode:: - - import torch - from typing import Tuple - - @torch.jit.script - def foo(x: int, tup: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: - t0, t1 = tup - return t0 + t1 + x - - print(foo(3, (torch.rand(3), torch.rand(3)))) -``` - -```{eval-rst} -.. testoutput:: - :hide: - - ... -``` -::: - -An empty list is assumed to be `List[Tensor]` and empty dicts -`Dict[str, Tensor]`. To instantiate an empty list or dict of other types, -use `Python 3 type hints`. - -Example (type annotations for Python 3): - -```{eval-rst} -.. testcode:: - - import torch - import torch.nn as nn - from typing import Dict, List, Tuple - - class EmptyDataStructures(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x: torch.Tensor) -> Tuple[List[Tuple[int, float]], Dict[str, int]]: - # This annotates the list to be a `List[Tuple[int, float]]` - my_list: List[Tuple[int, float]] = [] - for i in range(10): - my_list.append((i, x.item())) - - my_dict: Dict[str, int] = {} - return my_list, my_dict - - x = torch.jit.script(EmptyDataStructures()) - - - -``` - -### Optional Type Refinement - -TorchScript will refine the type of a variable of type `Optional[T]` when -a comparison to `None` is made inside the conditional of an if-statement or checked in an `assert`. -The compiler can reason about multiple `None` checks that are combined with -`and`, `or`, and `not`. Refinement will also occur for else blocks of if-statements -that are not explicitly written. - -The `None` check must be within the if-statement's condition; assigning -a `None` check to a variable and using it in the if-statement's condition will -not refine the types of variables in the check. -Only local variables will be refined, an attribute like `self.x` will not and must assigned to -a local variable to be refined. - -Example (refining types on parameters and locals): - -```{eval-rst} -.. testcode:: - - import torch - import torch.nn as nn - from typing import Optional - - class M(nn.Module): - z: Optional[int] - - def __init__(self, z): - super().__init__() - # If `z` is None, its type cannot be inferred, so it must - # be specified (above) - self.z = z - - def forward(self, x, y, z): - # type: (Optional[int], Optional[int], Optional[int]) -> int - if x is None: - x = 1 - x = x + 1 - - # Refinement for an attribute by assigning it to a local - z = self.z - if y is not None and z is not None: - x = y + z - - # Refinement via an `assert` - assert z is not None - x += z - return x - - module = torch.jit.script(M(2)) - module = torch.jit.script(M(None)) - -``` - -(TorchScript Class)= - -(TorchScript Classes)= - -(torchscript-classes)= - -### TorchScript Classes - :::{warning} -TorchScript class support is experimental. Currently it is best suited -for simple record-like types (think a `NamedTuple` with methods -attached). -::: - -Python classes can be used in TorchScript if they are annotated with {func}`@torch.jit.script `, -similar to how you would declare a TorchScript function: - -```{eval-rst} -.. testcode:: - :skipif: True # TODO: fix the source file resolving so this can be tested - - @torch.jit.script - class Foo: - def __init__(self, x, y): - self.x = x - - def aug_add_x(self, inc): - self.x += inc - -``` - -This subset is restricted: - -- All functions must be valid TorchScript functions (including `__init__()`). - -- Classes must be new-style classes, as we use `__new__()` to construct them with pybind11. - -- TorchScript classes are statically typed. Members can only be declared by assigning to - self in the `__init__()` method. - - > For example, assigning to `self` outside of the `__init__()` method: - > - > ``` - > @torch.jit.script - > class Foo: - > def assign_x(self): - > self.x = torch.rand(2, 3) - > ``` - > - > Will result in: - > - > ``` - > RuntimeError: - > Tried to set nonexistent attribute: x. Did you forget to initialize it in __init__()?: - > def assign_x(self): - > self.x = torch.rand(2, 3) - > ~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE - > ``` - -- No expressions except method definitions are allowed in the body of the class. - -- No support for inheritance or any other polymorphism strategy, except for inheriting - from `object` to specify a new-style class. - -After a class is defined, it can be used in both TorchScript and Python interchangeably -like any other TorchScript type: - -``` -# Declare a TorchScript class -@torch.jit.script -class Pair: - def __init__(self, first, second): - self.first = first - self.second = second - -@torch.jit.script -def sum_pair(p): - # type: (Pair) -> Tensor - return p.first + p.second - -p = Pair(torch.rand(2, 3), torch.rand(2, 3)) -print(sum_pair(p)) -``` - -(TorchScript Enum)= - -(TorchScript Enums)= - -(torchscript-enums)= - -### TorchScript Enums - -Python enums can be used in TorchScript without any extra annotation or code: - -``` -from enum import Enum - - -class Color(Enum): - RED = 1 - GREEN = 2 - -@torch.jit.script -def enum_fn(x: Color, y: Color) -> bool: - if x == Color.RED: - return True - - return x == y -``` - -After an enum is defined, it can be used in both TorchScript and Python interchangeably -like any other TorchScript type. The type of the values of an enum must be `int`, -`float`, or `str`. All values must be of the same type; heterogeneous types for enum -values are not supported. - -### Named Tuples - -Types produced by {func}`collections.namedtuple ` can be used in TorchScript. - -```{eval-rst} -.. testcode:: - - import torch - import collections - - Point = collections.namedtuple('Point', ['x', 'y']) - - @torch.jit.script - def total(point): - # type: (Point) -> Tensor - return point.x + point.y - - p = Point(x=torch.rand(3), y=torch.rand(3)) - print(total(p)) -``` - -```{eval-rst} -.. testoutput:: - :hide: - - ... - -``` - -(jit_iterables)= - -### Iterables - -Some functions (for example, {any}`zip` and {any}`enumerate`) can only operate on iterable types. -Iterable types in TorchScript include `Tensor`s, lists, tuples, dictionaries, strings, -{any}`torch.nn.ModuleList` and {any}`torch.nn.ModuleDict`. - -## Expressions - -The following Python Expressions are supported. - -### Literals - -``` -True -False -None -'string literals' -"string literals" -3 # interpreted as int -3.4 # interpreted as a float -``` - -#### List Construction - -An empty list is assumed have type `List[Tensor]`. -The types of other list literals are derived from the type of the members. -See [Default Types] for more details. - -``` -[3, 4] -[] -[torch.rand(3), torch.rand(4)] -``` - -#### Tuple Construction - -``` -(3, 4) -(3,) -``` - -#### Dict Construction - -An empty dict is assumed have type `Dict[str, Tensor]`. -The types of other dict literals are derived from the type of the members. -See [Default Types] for more details. - -``` -{'hello': 3} -{} -{'a': torch.rand(3), 'b': torch.rand(4)} -``` - -### Variables - -See [Variable Resolution] for how variables are resolved. - -``` -my_variable_name -``` - -### Arithmetic Operators - -``` -a + b -a - b -a * b -a / b -a ^ b -a @ b -``` - -### Comparison Operators - -``` -a == b -a != b -a < b -a > b -a <= b -a >= b -``` - -### Logical Operators - -``` -a and b -a or b -not b -``` - -### Subscripts and Slicing - -``` -t[0] -t[-1] -t[0:2] -t[1:] -t[:1] -t[:] -t[0, 1] -t[0, 1:2] -t[0, :1] -t[-1, 1:, 0] -t[1:, -1, 0] -t[i:j, i] -``` - -### Function Calls - -Calls to `builtin functions` - -``` -torch.rand(3, dtype=torch.int) -``` - -Calls to other script functions: - -```{eval-rst} -.. testcode:: - - import torch - - @torch.jit.script - def foo(x): - return x + 1 - - @torch.jit.script - def bar(x): - return foo(x) -``` - -### Method Calls - -Calls to methods of builtin types like tensor: `x.mm(y)` - -On modules, methods must be compiled before they can be called. The TorchScript -compiler recursively compiles methods it sees when compiling other methods. By default, -compilation starts on the `forward` method. Any methods called by `forward` will -be compiled, and any methods called by those methods, and so on. To start compilation at -a method other than `forward`, use the {func}`@torch.jit.export ` decorator -(`forward` implicitly is marked `@torch.jit.export`). - -Calling a submodule directly (e.g. `self.resnet(input)`) is equivalent to -calling its `forward` method (e.g. `self.resnet.forward(input)`). - -```{eval-rst} -.. testcode:: - :skipif: torchvision is None - - import torch - import torch.nn as nn - import torchvision - - class MyModule(nn.Module): - def __init__(self): - super().__init__() - means = torch.tensor([103.939, 116.779, 123.68]) - self.means = torch.nn.Parameter(means.resize_(1, 3, 1, 1)) - resnet = torchvision.models.resnet18() - self.resnet = torch.jit.trace(resnet, torch.rand(1, 3, 224, 224)) - - def helper(self, input): - return self.resnet(input - self.means) - - def forward(self, input): - return self.helper(input) - - # Since nothing in the model calls `top_level_method`, the compiler - # must be explicitly told to compile this method - @torch.jit.export - def top_level_method(self, input): - return self.other_helper(input) - - def other_helper(self, input): - return input + 10 - - # `my_script_module` will have the compiled methods `forward`, `helper`, - # `top_level_method`, and `other_helper` - my_script_module = torch.jit.script(MyModule()) - -``` - -### Ternary Expressions - -``` -x if x > y else y -``` - -### Casts - -``` -float(ten) -int(3.5) -bool(ten) -str(2)`` -``` - -### Accessing Module Parameters - -``` -self.my_parameter -self.my_submodule.my_parameter -``` - -## Statements - -TorchScript supports the following types of statements: - -### Simple Assignments - -``` -a = b -a += b # short-hand for a = a + b, does not operate in-place on a -a -= b -``` - -### Pattern Matching Assignments - -``` -a, b = tuple_or_list -a, b, *c = a_tuple -``` - -Multiple Assignments - -``` -a = b, c = tup -``` - -### Print Statements - -``` -print("the result of an add:", a + b) -``` - -### If Statements - -``` -if a < 4: - r = -a -elif a < 3: - r = a + a -else: - r = 3 * a -``` - -In addition to bools, floats, ints, and Tensors can be used in a conditional -and will be implicitly casted to a boolean. - -### While Loops - -``` -a = 0 -while a < 4: - print(a) - a += 1 -``` - -### For loops with range - -``` -x = 0 -for i in range(10): - x *= i -``` - -### For loops over tuples - -These unroll the loop, generating a body for -each member of the tuple. The body must type-check correctly for each member. - -``` -tup = (3, torch.rand(4)) -for x in tup: - print(x) -``` - -### For loops over constant nn.ModuleList - -To use a `nn.ModuleList` inside a compiled method, it must be marked -constant by adding the name of the attribute to the `__constants__` -list for the type. For loops over a `nn.ModuleList` will unroll the body of the -loop at compile time, with each member of the constant module list. - -```{eval-rst} -.. testcode:: - - class SubModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.weight = nn.Parameter(torch.randn(2)) - - def forward(self, input): - return self.weight + input - - class MyModule(torch.nn.Module): - __constants__ = ['mods'] - - def __init__(self): - super().__init__() - self.mods = torch.nn.ModuleList([SubModule() for i in range(10)]) - - def forward(self, v): - for module in self.mods: - v = module(v) - return v - - - m = torch.jit.script(MyModule()) - - -``` - -### Break and Continue - -``` -for i in range(5): - if i == 1: - continue - if i == 3: - break - print(i) -``` - -### Return - -``` -return a, b -``` - -## Variable Resolution - -TorchScript supports a subset of Python's variable resolution (i.e. scoping) -rules. Local variables behave the same as in Python, except for the restriction -that a variable must have the same type along all paths through a function. -If a variable has a different type on different branches of an if statement, it -is an error to use it after the end of the if statement. - -Similarly, a variable is not allowed to be used if it is only *defined* along some -paths through the function. - -Example: - -```{eval-rst} -.. testcode:: - - @torch.jit.script - def foo(x): - if x < 0: - y = 4 - print(y) -``` - -```{eval-rst} -.. testoutput:: - - Traceback (most recent call last): - ... - RuntimeError: ... - - y is not defined in the false branch... - @torch.jit.script... - def foo(x): - if x < 0: - ~~~~~~~~~ - y = 4 - ~~~~~ <--- HERE - print(y) - and was used here: - if x < 0: - y = 4 - print(y) - ~ <--- HERE... -``` - -Non-local variables are resolved to Python values at compile time when the -function is defined. These values are then converted into TorchScript values using -the rules described in [Use of Python Values]. - -## Use of Python Values - -To make writing TorchScript more convenient, we allow script code to refer -to Python values in the surrounding scope. For instance, any time there is a -reference to `torch`, the TorchScript compiler is actually resolving it to the -`torch` Python module when the function is declared. These Python values are -not a first class part of TorchScript. Instead they are de-sugared at compile-time -into the primitive types that TorchScript supports. This depends -on the dynamic type of the Python valued referenced when compilation occurs. -This section describes the rules that are used when accessing Python values in TorchScript. - -### Functions - -TorchScript can call Python functions. This functionality is very useful when -incrementally converting a model to TorchScript. The model can be moved function-by-function -to TorchScript, leaving calls to Python functions in place. This way you can incrementally -check the correctness of the model as you go. - -```{eval-rst} -.. autofunction:: torch.jit.is_scripting -``` - -```{eval-rst} -.. autofunction:: torch.jit.is_tracing - -``` - -### Attribute Lookup On Python Modules - -TorchScript can lookup attributes on modules. `Builtin functions` like `torch.add` -are accessed this way. This allows TorchScript to call functions defined in -other modules. - -(constant)= - -### Python-defined Constants - -TorchScript also provides a way to use constants that are defined in Python. -These can be used to hard-code hyper-parameters into the function, or to -define universal constants. There are two ways of specifying that a Python -value should be treated as a constant. - -1. Values looked up as attributes of a module are assumed to be constant: - -```{eval-rst} -.. testcode:: - - import math - import torch - - @torch.jit.script - def fn(): - return math.pi -``` - -2. Attributes of a ScriptModule can be marked constant by annotating them with `Final[T]` - -``` -import torch -import torch.nn as nn - -class Foo(nn.Module): - # `Final` from the `typing_extensions` module can also be used - a : torch.jit.Final[int] - - def __init__(self): - super().__init__() - self.a = 1 + 4 - - def forward(self, input): - return self.a + input - -f = torch.jit.script(Foo()) -``` - -Supported constant Python types are - -- `int` -- `float` -- `bool` -- `torch.device` -- `torch.layout` -- `torch.dtype` -- tuples containing supported types -- `torch.nn.ModuleList` which can be used in a TorchScript for loop - -(module-attributes)= -(Module Attributes)= - -### Module Attributes - -The `torch.nn.Parameter` wrapper and `register_buffer` can be used to assign -tensors to a module. Other values assigned to a module that is compiled -will be added to the compiled module if their types can be inferred. All [types] -available in TorchScript can be used as module attributes. Tensor attributes are -semantically the same as buffers. The type of empty lists and dictionaries and `None` -values cannot be inferred and must be specified via -[PEP 526-style](https://www.python.org/dev/peps/pep-0526/#class-and-instance-variable-annotations) class annotations. -If a type cannot be inferred and is not explicitly annotated, it will not be added as an attribute -to the resulting {class}`ScriptModule`. - -Example: - -```{eval-rst} -.. testcode:: - - from typing import List, Dict - - class Foo(nn.Module): - # `words` is initialized as an empty list, so its type must be specified - words: List[str] - - # The type could potentially be inferred if `a_dict` (below) was not - # empty, but this annotation ensures `some_dict` will be made into the - # proper type - some_dict: Dict[str, int] - - def __init__(self, a_dict): - super().__init__() - self.words = [] - self.some_dict = a_dict - - # `int`s can be inferred - self.my_int = 10 - - def forward(self, input): - # type: (str) -> int - self.words.append(input) - return self.some_dict[input] + self.my_int - - f = torch.jit.script(Foo({'hi': 2})) -``` +TorchScript is deprecated, please use +[torch.export](https://docs.pytorch.org/docs/stable/export.html) instead. +::: \ No newline at end of file diff --git a/torch/jit/_script.py b/torch/jit/_script.py index 79442f57d3063..a9a95cdace452 100644 --- a/torch/jit/_script.py +++ b/torch/jit/_script.py @@ -1246,7 +1246,7 @@ def script( subsequently passed by reference between Python and TorchScript with zero copy overhead. ``torch.jit.script`` can be used as a function for modules, functions, dictionaries and lists - and as a decorator ``@torch.jit.script`` for :ref:`torchscript-classes` and functions. + and as a decorator ``@torch.jit.script`` for torchscript-classes and functions. Args: obj (Callable, class, or nn.Module): The ``nn.Module``, function, class type, diff --git a/torch/jit/supported_ops.py b/torch/jit/supported_ops.py index 791a11a9b3aa7..98229edff6ee8 100644 --- a/torch/jit/supported_ops.py +++ b/torch/jit/supported_ops.py @@ -243,8 +243,8 @@ def _get_global_builtins(): "getattr": "Attribute name must be a literal string", "hasattr": "Attribute name must be a literal string", "isinstance": "Result is static", - "zip": "Arguments must be iterable. See :ref:`Iterables ` for details.", - "enumerate": "Arguments must be iterable. See :ref:`Iterables ` for details.", + "zip": "Arguments must be iterable.", + "enumerate": "Arguments must be iterable.", "range": "Can only be used as an iterator in a for loop", } @@ -295,7 +295,7 @@ def _get_global_builtins(): {schemaless_ops_str} -The following functions will use the corresponding magic method on :any:`TorchScript classes` +The following functions will use the corresponding magic method on TorchScript classes .. csv-table:: :header: "Function", "Magic Method" From 3f83e3eeca0645f4b2cd16fa7d5a591e9cf810d4 Mon Sep 17 00:00:00 2001 From: Ti-Tai Wang Date: Tue, 15 Jul 2025 17:32:59 +0000 Subject: [PATCH 077/457] [ONNX] Remove legacy registration and dispatcher (#158283) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158283 Approved by: https://github.com/Skylion007, https://github.com/justinchuby ghstack dependencies: #158258, #158262, #158282 --- .../onnx/_internal/fx/decomposition_table.py | 116 --- .../_internal/fx/onnxfunction_dispatcher.py | 728 ------------------ torch/onnx/_internal/fx/registration.py | 87 --- 3 files changed, 931 deletions(-) delete mode 100644 torch/onnx/_internal/fx/decomposition_table.py delete mode 100644 torch/onnx/_internal/fx/onnxfunction_dispatcher.py delete mode 100644 torch/onnx/_internal/fx/registration.py diff --git a/torch/onnx/_internal/fx/decomposition_table.py b/torch/onnx/_internal/fx/decomposition_table.py deleted file mode 100644 index 71715e1ad2344..0000000000000 --- a/torch/onnx/_internal/fx/decomposition_table.py +++ /dev/null @@ -1,116 +0,0 @@ -# mypy: allow-untyped-defs -"""Dispatcher for AtenLib functions from onnx-script.""" - -from __future__ import annotations - -from typing import Callable - -import torch -import torch._ops -import torch.fx -from torch.onnx._internal.fx import registration - - -def _create_onnx_supports_op_overload_table( - registry, -) -> set[torch._ops.OperatorBase | Callable]: - """ - Creates a set of OperatorBase and Callable objects that represent ONNX-supported PyTorch operations. - - Args: - registry (OnnxRegistry): The ONNX registry for PyTorch. - - Returns: - A collection of OperatorBase and Callable objects representing ONNX-supported PyTorch operations. - """ - table: set[torch._ops.OperatorBase | Callable] = set() - - # Some ops in `torch.ops.aten` are not discoverable through `dir(torch.ops.aten)`, - # but retrievable via explicit lookup. - # https://github.com/pytorch/pytorch/issues/99681 - # This is a workaround to make sure we register ONNX symbolic functions for these. - onnx_supported_aten_lookup_table = [ - k.split("::")[1].split(".")[0] - for k in registry._all_registered_ops() - if k.startswith("aten::") - ] - - for op_namespace in (torch.ops.aten, torch.ops.prims): - attr_names = dir(op_namespace) - if op_namespace is torch.ops.aten: - attr_names += onnx_supported_aten_lookup_table - for attr_name in attr_names: - if not hasattr(op_namespace, attr_name): - # torchlib owns some attributes that are not aten ops. - continue - op_overload_packet = getattr(op_namespace, attr_name) - if not isinstance(op_overload_packet, torch._ops.OpOverloadPacket): - continue - - for overload_name in op_overload_packet.overloads(): - op_overload = getattr(op_overload_packet, overload_name) - internal_op_name = registration.OpName.from_qualified_name( - qualified_name=op_overload.name() - ) - # NOTE: If the overload is supported in registry or it's default overload is supported in registry, - # we add it to the table. - if registry.is_registered_op( - namespace=internal_op_name.namespace, - op_name=internal_op_name.op_name, - overload=internal_op_name.overload, - ) or registry.is_registered_op( - namespace=internal_op_name.namespace, - op_name=internal_op_name.op_name, - overload=None, - ): - # This line maps torch.ops.aten.add.Tensor, torch.ops.aten.add.Scalar, torch.ops.aten.add.out, etc - # to "aten::add". This means the exporter for "aten::add" is used for all overloads of "aten::add". - # This is applied to all ops under torch.ops.aten. - table.add(op_overload) - return table - - -def create_onnx_friendly_decomposition_table( - registry, -) -> dict[torch._ops.OperatorBase, Callable]: - """ - This function creates a dictionary of op overloads and their decomposition functions - for ops that do not have ONNX symbolic functions. If an op already has an ONNX symbolic function, - its decomposition function is excluded from the table. The decomposition table is a subset of PyTorch's - built-in aten-to-aten decomposition. - - Args: - registry: The ONNX registry for PyTorch. - - Returns: - Dict[torch._ops.OperatorBase, Callable]: A dictionary that maps op overloads to their corresponding - decomposition functions. - """ - decomposition_table: dict[torch._ops.OperatorBase, Callable] = {} - # Dictionary that maps torch.ops.aten.* to exporter look up key; e.g., - # _OP_OVERLOAD_TO_EXPORTER_KEY_TABLE[torch.add.Tensor] is "aten::add". - _ONNX_SUPPORT_OP_OVERLOADS = _create_onnx_supports_op_overload_table(registry) - - # NOTE: If we import torch._decomp, we will get RuntimeError: Only a single - # TORCH_LIBRARY can be used to register the namespace nvprims; please put all of your - # definitions in a single TORCH_LIBRARY block. - for op_overload, decomp_fn in torch._decomp.decomposition_table.items(): - # Skip decomposition into "prim::*" ops (defined in 'torch._refs'), because they - # are not generally supported by ONNX. - # Skip decomposition for op_overload as long as that op_overload has a corresponding ONNX - # symbolic function. - if ( - "torch._refs" in decomp_fn.__module__ - or op_overload in _ONNX_SUPPORT_OP_OVERLOADS - ): - continue - decomposition_table[op_overload] = decomp_fn - - # NOTE: There are ops in core ATen and under torch._refs, - # that are not decomposed to prim::ops. We need to pick them - # back - for op_overload, decomp_fn in torch._decomp.core_aten_decompositions().items(): - if op_overload in _ONNX_SUPPORT_OP_OVERLOADS: - continue - decomposition_table[op_overload] = decomp_fn - return decomposition_table diff --git a/torch/onnx/_internal/fx/onnxfunction_dispatcher.py b/torch/onnx/_internal/fx/onnxfunction_dispatcher.py deleted file mode 100644 index f90e7efd8ac98..0000000000000 --- a/torch/onnx/_internal/fx/onnxfunction_dispatcher.py +++ /dev/null @@ -1,728 +0,0 @@ -# mypy: allow-untyped-defs -"""Dispatcher for AtenLib functions from onnx-script. - -This is a deprecated module to be removed. -""" - -from __future__ import annotations - -import logging -import operator -import types -from typing import Any, TYPE_CHECKING - -import torch -import torch._ops -import torch.fx -from torch.onnx._internal.fx import registration, type_utils as fx_type_utils - - -if TYPE_CHECKING: - from collections.abc import Sequence - - import onnxscript # type: ignore[import] - from onnxscript.function_libs.torch_lib import ( # type: ignore[import] - graph_building as onnxscript_graph_building, - ) - -logger = logging.getLogger(__name__) - - -class OnnxFunctionDispatcher: - """A dispatcher that finds the best ONNX Function for ATen/Custom operators. - - It uses the `torch.ops` name to find the function. If not found, it falls back to default. - Otherwise, the best match is found among all function overloads. An exact match has - higher precedence over the closest ones. - - Below is a breakdown on how the dispatch mechanism works: - - 1. Use the torch.ops name to find the function: - a. Check if the ATen overload exists in the registry. - b. If not, check if the default overload exists in the registry. - - 2. Find the nearest match among all overloaded functions: - a. If the types match perfectly, select the function. - b. Otherwise, find the nearest one with the highest matching score. Because of - the potential wrongly annotated dtypes and attributes matching, we use - nearest match to find the best function once the aten name is targeted. - - 3. Tie-breaker: If there are multiple nearest matches, we will select the one with - the highest matching score. - - NOTE: The nearest match `doesn't guarantee` a correct match, and a warning message is logged. - """ - - def __init__( - self, - onnx_registry, - ): - """Initialize the ONNX Function dispatcher. - - Args: - onnx_registry: The ONNX registry. - """ - self.onnx_registry = onnx_registry - - def dispatch( - self, - node: torch.fx.Node, - onnx_args: Sequence[ - fx_type_utils.TensorLike | str | int | float | bool | list | complex | None - ], - onnx_kwargs: dict[str, fx_type_utils.Argument], - ) -> onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction: - """Dispatches an ONNX function based on the given FX node, arguments, and keyword arguments. - Args: - node: The TorchFX node to dispatch the function for. - onnx_args: The arguments of the ONNX function. - onnx_kwargs: The keyword arguments of the ONNX function. - - Returns: - Either an `onnxscript.OnnxFunction` or `onnxscript.TracedOnnxFunction` instance based on the dispatch algorithm. - Raises: - RuntimeError: If there are no overloaded functions available for the given FX node. - """ - # If there are no overloaded functions available for the given FX node, raise an - # unsupported error - default_and_custom_functions = self.get_function_overloads(node) - - # If there are overloaded functions available, we will find one that perfect or - # nearest matches the given arguments and keyword arguments - return self._find_the_perfect_or_nearest_match_onnxfunction( - node, - default_and_custom_functions, - onnx_args, - onnx_kwargs, - ) - - def _filter_or_keep_complex( - self, - node, - default_and_custom_functions: list[registration.ONNXFunction], - ) -> list[registration.ONNXFunction]: - """Filter the complex functions if the input has complex dtype.""" - - args_with_complex_dtype = [_is_arg_with_complex_dtype(arg) for arg in node.args] - if any(args_with_complex_dtype): - default_and_custom_functions = [ - func for func in default_and_custom_functions if func.is_complex - ] - # If we can't find the complex function group, raise error. - if not default_and_custom_functions: - op_full_name = self._get_aten_name(node).qualified_name() - raise RuntimeError( - f"Cannot find any COMPLEX symbolic function for {op_full_name}, " - f"which should be registered under {node.target}.", - ) - else: - default_and_custom_functions = [ - func for func in default_and_custom_functions if not func.is_complex - ] - # If we can't find the complex function group, raise error. - if not default_and_custom_functions: - op_full_name = self._get_aten_name(node).qualified_name() - raise RuntimeError( - f"Can ONLY find COMPLEX symbolic function for {op_full_name}, " - f"which should be registered under {node.target}.", - ) - return default_and_custom_functions - - def _find_the_perfect_or_nearest_match_onnxfunction( - self, - node: torch.fx.Node, - default_and_custom_functions: list[registration.ONNXFunction], - onnx_args: Sequence[ - fx_type_utils.TensorLike | str | int | float | bool | list | complex | None - ], - onnx_kwargs: dict[str, fx_type_utils.Argument], - ): - """Find the perfect/nearest matched OnnxFunction for the given FX node, arguments, and keyword arguments. - - Args: - default_and_custom_functions: The list includes overloaded functions, with - custom ones appearing after the default ones. - onnx_args: Arguments organized in PyTorch inputs way. - onnx_kwargs: Keyword arguments organized in PyTorch inputs way. - - Returns: - Either an `onnxscript.OnnxFunction` or `onnxscript.TracedOnnxFunction` instance based on the dispatch algorithm. - Raises: - RuntimeError: If there are no overloaded functions available for the given FX node. - """ - overload_match_ranking: dict[registration.ONNXFunction, int | None] = {} - - # Iterate the overloaded functions in reverse order to prioritize the custom ones - # over the default ones, and find the perfect match. - for symbolic_function in reversed(default_and_custom_functions): - function_opschema = _OnnxSchemaChecker(symbolic_function.onnx_function) - - # NOTE: 1. If the perfect match is found, return the function - if function_opschema.perfect_match_inputs(onnx_args, onnx_kwargs): - return symbolic_function.onnx_function - # Record the match score for the nearest match if it's not the perfect match - overload_match_ranking[symbolic_function] = function_opschema.match_score - - # NOTE: 2. If there is no perfect match, find the nearest match among the nearest matche candidates - # If there is no nearest match, raise an error - overload_match_ranking = { - k: v for k, v in overload_match_ranking.items() if v is not None - } - if not overload_match_ranking: - # If there are no overloaded functions available for the given FX node, raise an - # unsupported error - op_full_name = self._get_aten_name(node).qualified_name() - raise RuntimeError( - f"Cannot find any perfect/nearest match of symbolic function for {op_full_name}," - f"which should be registered under {node.target}.", - ) - - # NOTE: 3. Tie breaker: if there are multiple nearest matches, we will choose the one - # that is custom first. If there are multiple custom ones, we will choose the one - # that is added lastly in the list. - symbolic_function_list: list[registration.ONNXFunction] = sorted( - overload_match_ranking, - key=lambda k: ( - overload_match_ranking[k], - k.is_custom, - default_and_custom_functions.index(k), - ), - reverse=True, - ) - return symbolic_function_list[0].onnx_function - - def _get_aten_name(self, node: torch.fx.Node) -> registration.OpName: - """Get the OpName from the target. - - Args: - node: The TorchFX node to get the aten name for. - - Returns: - The internal op name within dataclass: registration.OpName. - """ - if node.target == operator.getitem: - return registration.OpName.from_name_parts( - namespace="aten", op_name="getitem" - ) - if isinstance(node.target, torch._ops.OpOverloadPacket): - # aten::sym_size is the only OverloadPacket that we support. - # schema: aten::sym_size(Tensor self, int dim) -> Tensor - if node.target != torch.ops.aten.sym_size: - raise RuntimeError( - f"Unsupported OverloadPacket: {node.target}, aten.sym_size is the only allowed OverloadPacket!", - ) - # TODO(titaiwang): aten::sym_size has overload, but fx graph is using - # overloadpacket for some reasons. - # https://github.com/pytorch/pytorch/issues/97201 - aten_op_default = node.target.default - return registration.OpName.from_op_overload(op_overload=aten_op_default) # type: ignore[no-any-return] - - if isinstance(node.target, types.BuiltinFunctionType): - # Make sure it's symint/symfloat consuming builtin ops. - for node_arg in node.args: - if (not isinstance(node_arg, (torch.fx.Node, int, float))) or ( - isinstance(node_arg, torch.fx.Node) - and not fx_type_utils.is_torch_symbolic_type(node_arg.meta["val"]) - ): - raise RuntimeError( - f"Unsupported node arg: {node_arg} (type {type(node_arg)}) with builtin function: {node.target}," - " only int/float/SymInt/SymFloat is supported with built-in ops!", - ) - return registration.OpName.from_builtin_function(node.target) - - if isinstance(node.target, torch._ops.OpOverload): - return registration.OpName.from_op_overload(op_overload=node.target) - - # Unexpected target, raise error. - raise RuntimeError(f"Unknown call_function target: {node.target}") - - def get_function_overloads( - self, - node: torch.fx.Node, - ) -> list[registration.ONNXFunction]: - """Get the function overloads from the registry. - - Args: - node: The node to get the function overloads for. - - Returns: - The list contains ONNXFunctions, starting with the default ones and - followed by any custom ones. - """ - - internal_opname: registration.OpName = self._get_aten_name(node=node) - - # If the ATen/Custom operators are not registered, the group will be None. - # And non-registered ATen/Custom operators will trigger error in the next step. - function_group: list[registration.ONNXFunction] | None = None - - function_group = self.onnx_registry.get_op_functions( - namespace=internal_opname.namespace, - op_name=internal_opname.op_name, - overload=internal_opname.overload, - ) - - # NOTE: Fall back to default overload if the ONNX registry doesn't have the overload. - if function_group is None: - function_group = self.onnx_registry.get_op_functions( - namespace=internal_opname.namespace, - op_name=internal_opname.op_name, - overload=None, - ) - if function_group is not None: - op_full_name = internal_opname.qualified_name() - - if function_group is not None: - # NOTE: If the input has complex dtype, we will only dispatch to the complex functions. - function_group = self._filter_or_keep_complex(node, function_group) - return function_group # type: ignore[return-value] - - op_full_name = internal_opname.qualified_name() - raise RuntimeError( - f"Cannot find symbolic function for {op_full_name}, " - f"which should be registered under {node.target}.", - ) - - -class _OnnxSchemaChecker: - """ - The OnnxSchemaChecker class is a checker for ONNX OpSchema and param schema. - - It provides methods to check for input compatibility based on the OpSchema. It also - provides a matching score to indicate how well the OpSchema matches the input and - kwargs types. A function will be evaluated as perfect match, nearest match eligible, - or no match. - - Here are some common examples in categories: - - 1. [NOTE: Perfect match]: The number of inputs and attributes are exactly the same as - the OpSchema. The types of inputs and attributes are exactly the same as the - OpSchema. - - ```python - inputs = (Tensor[2, 3], Tensor[2, 3]) - attributes = {"alpha": 1.0} - - - @torch_op("aten::op") - def aten_op(self: TReal, other: TReal, alpha: float = 1) -> TReal: ... - ``` - Result: Perfect match. - - 2. [NOTE: Optional input]: The dispatcher recognizes optional inputs. However, - the input can't be ignored. None must be provided. - - ```python - inputs = (Tensor([2, 3]), None) - attributes = {} - - aten_op(X: TTensor, Y: Optional[INT64]): - ... - ``` - Result: Perfect match. - Real example: `aten::convolution`. - - 3. [NOTE: Different attributes]: If an attribute is provided with value, it's - a must to match the attribute in function signature. - ```python - inputs = (Tensor([2, 3]),) - attributes = {"a":1, "b":2} - - aten_op(X: TTensor, a: int): - ... - ``` - Result: No match. - Real example: `aten::div` vs `aten::div.Tensor_mode`. - - 4. [NOTE: Default attributes]: Default attribute will fill in the value into - inputs/attributes. - ```python - inputs = (Tensor([2, 3]),) - attributes = {} - - aten_op(X: TTensor, a: int = 3): - ... - ``` - Result: Perfect match. - Real example: `aten::clone` - - 5. [NOTE: Ignore attribute with None value]: The attributes with None value - will be ignored in matching. - ```python - inputs = (Tensor([2, 3]),) - attributes = {"a": None} - - aten_op(X: TTensor): - ... - ``` - Result: Perfect match. - - ```python - inputs = (Tensor([2, 3]),) - attributes = {"a": None} - - aten_op(X: TTensor, a: int = 3): - ... - ``` - Result: Nearest match eligible. - - Real example: `aten::div` vs `aten::div.Tensor_mode`. - - Attributes: - onnxfunction: The OnnxFunction. - param_schema: The parameter schema defined in the OnnxFunction. - op_schema: The ONNX OpSchema. - type_constraints: The type constraints defined in the OpSchema. - attributes: The attributes defined in the OpSchema. - _matching_score: The matching score of the OnnxSchemaChecker . - - """ - - def __init__( - self, - onnxfunction: onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction, - ): - """Initialize the OnnxSchemaChecker . - - Args: - onnxfunction: The OnnxFunction. - """ - self.onnxfunction = onnxfunction - self.param_schema = self.onnxfunction.param_schemas() - op_schema = self.onnxfunction.op_schema - # Both `OnnxFunction` and `TracedOnnxFunction` never return None for `op_schema`. - # However their base class would. Hence return type is annotated as Optional[OpSchema]. - assert op_schema is not None - self.op_schema = op_schema - self.type_constraints = { - # "T": {"tensor(int64)"} - constraint.type_param_str: set(constraint.allowed_type_strs) - for constraint in self.op_schema.type_constraints - } - self.attributes = self.op_schema.attributes - self._matching_score: int | None = None - - @property - def match_score(self) -> int | None: - """The matching score of the OnnxSchemaChecker . - - If this remains None, it means the matching score has not been calculated, - and it's not a nearest match candidate. - - Returns: - The matching score of the OnnxSchemaChecker . - """ - return self._matching_score - - def perfect_match_inputs( - self, - args: Sequence[ - fx_type_utils.TensorLike | str | int | float | bool | list | complex | None - ], - kwargs: dict[str, fx_type_utils.Argument], - ) -> bool: - """Check if the inputs perfectly match the OpSchema requirements. - - The definition of perfect match is that the input types are all in the type - constraints and the number of inputs matches the number of inputs in the - OpSchema. - - Checking steps: - 1. The function signature matches the inputs number, and attribute names. - 2. The input/attribute types are all in the type constraints. - - A function should at least pass the first step to be eligible for the - nearest matching. - - Args: - args: The input arguments organized in PyTorch inputs way. - kwargs: The input keyword arguments organized in PyTorch inputs way. - - Returns: - True if the inputs match the requirements, False otherwise. - """ - - # NOTE: OnnxFunction does not have the same function signature as the original - # PyTorch operator. We need to separate the input/attributes from the arguments. - ( - function_inputs, - function_attributes, - ) = self._separate_input_attributes_from_arguments( - self.param_schema, - args, - kwargs, - fill_defaults=True, # fill defaults for optional arguments to match - ) - # NOTE: 1. Check if the input number and attribute names match the - # OpSchema. If it's not, we know the function is not eligible to be a perfect - # match, nor a nearest match. - # We use is_perfect_match to postpone the return value to the end - # of the function, as we want to log all the mismatch info. - is_perfect_match = True - if len(function_inputs) != len(self.op_schema.inputs): - logger.info( - "Actual %d vs expected %d", - len(function_inputs), - len(self.op_schema.inputs), - ) - logger.info("The function is not a nearest match candidate.") - is_perfect_match = False - - if set(function_attributes) != set(self.attributes): - logger.info("The function is not a nearest match candidate.") - is_perfect_match = False - - # If it's already not a perfect match, we can return False directly. Further - # checking is only for the functions that are eligible for nearest match. - if not is_perfect_match: - return False - - # NOTE: 2. The dtypes of inputs and attributes should be in the - # type constraints of the OpSchema. If they are not, we know the function is not - # eligible to be a perfect match, but can be a nearest match candidate. - for schema_input, torch_input in zip(self.op_schema.inputs, function_inputs): - torch_input_compatible_types = _find_onnx_data_type(torch_input) - allowed_types = self.type_constraints[schema_input.type_str] - if not allowed_types.intersection(torch_input_compatible_types) and not any( - fx_type_utils.is_optional_onnx_dtype_str(onnx_type_str) - for onnx_type_str in allowed_types - ): - # If torch_input_compatible_types isn't in allowed_types - # of this input defined in the OpSchema, we know the function - # and the input are not compatible - logger.info( - "Actual %s vs\nExpected %s", - torch_input_compatible_types, - allowed_types, - ) - is_perfect_match = False - - for attribute_name, attribute in function_attributes.items(): - if not self._match_onnx_attribute_type(attribute_name, attribute): - # If the attribute type of the OpSchema and the attribute type don't match, - # we know the function and the input are not compatible - logger.info( - "Actual %s vs\nExpected %s", - type(attribute), - self.attributes[attribute_name].type, - ) - is_perfect_match = False - - # NOTE: This is still a candidate for nearest match, as it only mismatches attributes on dtype. - self._record_matching_score(function_inputs, function_attributes) - logger.info("match score: %d", self.match_score) - return is_perfect_match - - def _match_onnx_attribute_type( - self, - attribute_name: str, - attribute: fx_type_utils.Argument | onnxscript_graph_building.TorchScriptTensor, - is_sequence: bool = False, - ) -> bool: - if isinstance(attribute, (int, float, bool, str)): - attribute_onnx_type = fx_type_utils.from_python_type_to_onnx_attribute_type( - type(attribute), is_sequence=is_sequence - ) - if attribute_onnx_type != self.attributes[attribute_name].type: - return False - # If the attribute is an empty list, we don't know the type of the list - # so it's a mismatch - elif isinstance(attribute, (list, tuple)) and attribute: - return self._match_onnx_attribute_type( - attribute_name, attribute[0], is_sequence=True - ) - else: - # NOTE: Unrecognized attribute type - return False - return True - - def _record_matching_score( - self, - inputs: Sequence[ - fx_type_utils.TensorLike | str | int | float | bool | list | complex | None - ], - attributes: dict[str, fx_type_utils.Argument], - ): - """Calculate the inputs matching score of the OpSchema requirements to find the nearest match. - - Only the functions which have the same number of inputs and attributes as the - OpSchema are eligible to be a nearest match candidate. Thus, we don't need to - check the length of inputs and attributes here, and only check the types of - inputs and attributes. - - How the matchsing score is calculated: - score += 1 if one input/attribute type is in the type constraints. - - Limitations: - None/NoeType/[] could result in zero matches, and the same score of overloads. - - Args: - inputs: The input arguments. - attributes: The input keyword arguments. - - Returns: - True if the inputs match the requirements, False otherwise. - """ - self._matching_score = 0 - # If they have different length of arguments, the score would be lower to those - # functions which have the same length of arguments. - for schema_input, torch_input in zip(self.op_schema.inputs, inputs): - torch_input_compatible_types = _find_onnx_data_type(torch_input) - allowed_types = self.type_constraints[schema_input.type_str] - if allowed_types.intersection(torch_input_compatible_types): - # If torch_input_compatible_types is in allowed_types - # of this input defined in the OpSchema, we know the function - # and the input are compatible - self._matching_score += 1 - # NOTE: The penalty is applied to those functions which have different attributes. - for attribute_name, attribute_proto in self.attributes.items(): - attribute = attributes[attribute_name] - attribute_onnx_type = fx_type_utils.from_python_type_to_onnx_attribute_type( - type(attribute) - ) - if attribute_onnx_type != attribute_proto.type: - # If the attribute type of the OpSchema and the attribute type don't match, - # we know the function and the input are not compatible - self._matching_score -= 1 - - # NOTE: Referenced from onnxscript internal function. - # Importing this function makes the code less robust, as it is not a public API. - - def _separate_input_attributes_from_arguments( - self, - param_schemas: Sequence[onnxscript.values.ParamSchema], - args: Sequence[ - fx_type_utils.TensorLike | str | int | float | bool | list | complex | None - ], - kwargs: dict[str, fx_type_utils.Argument], - fill_defaults: bool = True, - ) -> tuple[list[Any], dict[str, Any]]: - """Separate Python args and kwargs into ONNX inputs and attributes. - - Extra_kwargs are ignored if their values are None. For example, if the - OpSchema has an attribute "rounding_mode" and the caller provides - "rounding_mode=None", the attribute "rounding_mode" will not be included - in the returned attributes when the OnnxFunction signature doesn't have - "rounding_mode" as an attribute. - - Args: - param_schemas: The parameter schemas of an Op or a OnnxFunction. - args: The Python positional arguments supplied by the caller. - kwargs: The Python keyword arguments supplied by the caller. - fill_defaults: Whether to fill the default values for attributes. - - Returns: - A tuple of two elements: - - A list of ONNX inputs. - - An dictionary of ONNX attribute names and values. - - Raises: - TypeError: When allow_extra_kwargs is False and there are unknown kwargs. - TypeError: When a required input is not provided. - """ - # args, kwargs and param_schemas should be all in order - # user may not specify all inputs or attributes - - import onnx - - onnx_inputs: list[Any] = [] - onnx_attributes: dict[str, Any] = {} - # NOTE: We need to copy kwargs because we will mutate it - copy_kwargs = kwargs.copy() - for i, param in enumerate(param_schemas): - if param.is_variadic_input: - # Exhaust all remaining args - onnx_inputs.extend(args[i:]) - args = [] - continue - if i < len(args): - if param.is_input: - onnx_inputs.append(args[i]) - else: - onnx_attributes[param.name] = args[i] - elif param.name in copy_kwargs: - if param.is_input: - # Move the input from kwargs to inputs - onnx_inputs.append(copy_kwargs[param.name]) - copy_kwargs.pop(param.name) - else: - onnx_attributes[param.name] = copy_kwargs[param.name] - elif ( - param.is_attribute - and self.attributes[param.name].default_value.type - != onnx.AttributeProto.UNDEFINED # type: ignore[attr-defined] - ): - # User did not provide the attribute - if fill_defaults: - onnx_attributes[param.name] = param.default - # optional input - elif param.is_input: - if fill_defaults: - onnx_inputs.append(None) - - # NOTE: Pick up extra kwargs if it's not None. None is not expected - # as an attribute value in torchlib. - for k, v in copy_kwargs.items(): - if k not in onnx_attributes and v is not None: - onnx_attributes[k] = v - return onnx_inputs, onnx_attributes - - -def _is_arg_with_complex_dtype(arg: fx_type_utils.Argument) -> bool: - """Check if the node has complex dtype recursively.""" - if ( - isinstance(arg, torch.fx.Node) - and "val" in arg.meta - and isinstance(arg.meta["val"], torch.Tensor) - and torch.is_complex(arg.meta["val"]) - ): - return True - elif isinstance(arg, list): - for item in arg: - return _is_arg_with_complex_dtype(item) - return False - - -def _find_onnx_data_type( - torch_input: fx_type_utils.TensorLike - | str - | int - | float - | bool - | list - | tuple - | complex - | None, -) -> set[str]: - """Convert inputs data type from torch acceptable dtype to the compatible onnx dtype string.""" - if ( - isinstance(torch_input, fx_type_utils.TensorLike) - and torch_input.dtype is not None - ): - return fx_type_utils.from_torch_dtype_to_onnx_dtype_str(torch_input.dtype) - if isinstance(torch_input, (int, float, bool, str, complex)): - return fx_type_utils.from_torch_dtype_to_onnx_dtype_str(type(torch_input)) - if isinstance(torch_input, (list, tuple)) and torch_input: # [Tensor, Tensor] - the_first_non_none_item = next( - (item for item in torch_input if item is not None), None - ) - set_dtype = _find_onnx_data_type(the_first_non_none_item) - if any(isinstance(input, fx_type_utils.TensorLike) for input in torch_input): - # NOTE: Any Tensor involved in a list would make it a seq(tensor(onnx_type)) - return {f"seq({dtype})" for dtype in set_dtype} - else: - # constant list of non-tensor type - return set_dtype - if ( - torch_input is None - or ( - isinstance(torch_input, fx_type_utils.TensorLike) - and torch_input.dtype is None - ) - or (isinstance(torch_input, (list, tuple)) and not torch_input) - ): - # NOTE: None, No dtype, and empty list are edge cases, we allow it to be any type to relax the type check - # seq(tensor) also goes to here, as it is not supported in torchscript, and it would be None in this case. - return set() - - raise RuntimeError(f"Unknown input type from input: {torch_input}") diff --git a/torch/onnx/_internal/fx/registration.py b/torch/onnx/_internal/fx/registration.py deleted file mode 100644 index ec6fc638e3f2a..0000000000000 --- a/torch/onnx/_internal/fx/registration.py +++ /dev/null @@ -1,87 +0,0 @@ -"""Module for handling ATen to ONNX functions registration.""" - -from __future__ import annotations - -import dataclasses -from typing import TYPE_CHECKING - - -# We can only import onnx from this module in a type-checking context to ensure that -# 'import torch.onnx' continues to work without having 'onnx' installed. We fully -# 'import onnx' inside of dynamo_export (by way of _assert_dependencies). -if TYPE_CHECKING: - import types - - import onnxscript # type: ignore[import] - - import torch._ops - - -@dataclasses.dataclass(frozen=True, eq=True) -class ONNXFunction: - """A wrapper of onnx-script function. - - op_full_name: The qualified name of the function. In the form of '::.'. - onnx_function: The onnx-script function from torchlib. - is_custom: Whether the function is a custom function. - is_complex: Whether the function is a function that handles complex valued inputs. - - """ - - onnx_function: onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction - op_full_name: str - is_custom: bool = False - is_complex: bool = False - - -@dataclasses.dataclass(frozen=True, eq=True) -class OpName: - """A class representing an operator name in internal ONNX converter.""" - - namespace: str - op_name: str - overload: str - - @classmethod - def from_name_parts( - cls, namespace: str, op_name: str, overload: str | None = None - ) -> OpName: - # NOTE: in PyTorch, the overload could be unprovided to indicate the - # default overload - if overload is None or overload == "": - overload = "default" - return cls(namespace, op_name, overload) - - @classmethod - def from_qualified_name(cls, qualified_name: str) -> OpName: - """When the name is ::[.]""" - namespace, opname_overload = qualified_name.split("::") - op_name, *overload = opname_overload.split(".", 1) - overload = overload[0] if overload else "default" - return cls(namespace, op_name, overload) - - @classmethod - def from_op_overload(cls, op_overload: torch._ops.OpOverload) -> OpName: - return cls.from_qualified_name(op_overload.name()) - - @classmethod - def from_builtin_function( - cls, builtin_function: types.BuiltinFunctionType - ) -> OpName: - """From a builtin function, e.g. operator.add, math.ceil, etc, get the OpName. - - FX graph uses built-in functions to calculate sympy expression. This function - is used to get the OpName from a builtin function. - - Args: - builtin_function (types.BuiltinFunctionType): operator.add, math.ceil, etc. - - Returns: - OpName: _description_ - """ - op = builtin_function.__name__ # add, sub, etc. - module = builtin_function.__module__ # _operators or math - return cls.from_qualified_name(module + "::" + op) - - def qualified_name(self) -> str: - return f"{self.namespace}::{self.op_name}.{self.overload}" From abeae997a35b1920a45be9c26eff7474f2c6c5dd Mon Sep 17 00:00:00 2001 From: Andrey Talman Date: Tue, 15 Jul 2025 21:08:25 +0000 Subject: [PATCH 078/457] Use brew suggested miniconda install command (#158347) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Use ```brew install --cask miniconda``` as specified by https://formulae.brew.sh/cask/miniconda Forward fix After: https://github.com/pytorch/pytorch/pull/156898#issuecomment-3074207175 Seeing in CI: ``` Run if [[ -n "$REINSTALL_BREW_MINICONDA" ]]; then ==> Caveats Please run the following to setup your shell: conda init "$(basename "${SHELL}")" Alternatively, manually add the following to your shell init: eval "$(conda "shell.$(basename "${SHELL}")" hook)" ==> Downloading https://repo.anaconda.com/miniconda/Miniconda3-py313_25.5.1-0-MacOSX-arm64.sh Already downloaded: /Users/ec2-user/Library/Caches/Homebrew/downloads/2e356e8b147647692e4da77ce4c0c14eefee65ec86f29cc7e8c21a26ac9397ca--Miniconda3-py313_25.5.1-0-MacOSX-arm64.sh ==> Installing Cask miniconda ==> Running installer script 'Miniconda3-py313_25.5.1-0-MacOSX-arm64.sh' PREFIX=/opt/homebrew/Caskroom/miniconda/base Unpacking payload ... entry_point.py:256: DeprecationWarning: Python 3.14 will, by default, filter extracted tar archives and reject files or modify their metadata. Use the filter argument to control this behavior. entry_point.py:256: DeprecationWarning: Python 3.14 will, by default, filter extracted tar archives and reject files or modify their metadata. Use the filter argument to control this behavior. Installing base environment... Preparing transaction: ...working... done Executing transaction: ...working... done entry_point.py:256: DeprecationWarning: Python 3.14 will, by default, filter extracted tar archives and reject files or modify their metadata. Use the filter argument to control this behavior. installation finished. ==> Linking Binary 'conda' to '/opt/homebrew/bin/conda' 🍺 miniconda was successfully installed! ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/158347 Approved by: https://github.com/seemethere --- .github/workflows/_mac-test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/_mac-test.yml b/.github/workflows/_mac-test.yml index 8498ba5a09323..550053de73256 100644 --- a/.github/workflows/_mac-test.yml +++ b/.github/workflows/_mac-test.yml @@ -281,7 +281,7 @@ jobs: continue-on-error: true run: | if [[ -n "$REINSTALL_BREW_MINICONDA" ]]; then - brew install miniconda + brew install --cask miniconda fi - name: Clean up disk space From 05dfd312cfbfdecc6cb1e7d1d0bb4ee18370ae7e Mon Sep 17 00:00:00 2001 From: Yiming Zhou Date: Tue, 15 Jul 2025 21:14:14 +0000 Subject: [PATCH 079/457] [3/n] Remove references to TorchScript in PyTorch docs (#158315) Summary: - cpp_index.rst - fx.md - jit_builtin_functions.rst - jit_python_reference.md - jit_unsupported.md cpu_threading large_scale_deployment Test Plan: CI Rollback Plan: Differential Revision: D78309320 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158315 Approved by: https://github.com/svekars, https://github.com/zhxchen17 --- docs/source/cpp_index.rst | 18 +- docs/source/fx.md | 3 +- docs/source/jit.rst | 2 +- docs/source/jit_builtin_functions.rst | 6 +- docs/source/jit_python_reference.md | 432 +---------- docs/source/jit_unsupported.md | 79 +- docs/source/notes/cpu_threading_runtimes.svg | 208 ------ .../cpu_threading_torchscript_inference.rst | 158 +--- .../cpu_threading_torchscript_inference.svg | 681 ------------------ docs/source/notes/large_scale_deployments.rst | 42 +- docs/source/package.md | 15 - 11 files changed, 18 insertions(+), 1626 deletions(-) delete mode 100644 docs/source/notes/cpu_threading_runtimes.svg delete mode 100644 docs/source/notes/cpu_threading_torchscript_inference.svg diff --git a/docs/source/cpp_index.rst b/docs/source/cpp_index.rst index 23302286f0e3c..37571b9c60bc2 100644 --- a/docs/source/cpp_index.rst +++ b/docs/source/cpp_index.rst @@ -7,20 +7,6 @@ C++ PyTorch provides several features for working with C++, and it’s best to choose from them based on your needs. At a high level, the following support is available: -TorchScript C++ API --------------------- -`TorchScript `__ allows PyTorch models defined in Python to be serialized and then loaded and run in C++ capturing the model code via compilation or tracing its execution. You can learn more in the `Loading a TorchScript Model in C++ tutorial `__. This means you can define your models in Python as much as possible, but subsequently export them via TorchScript for doing no-Python execution in production or embedded environments. The TorchScript C++ API is used to interact with these models and the TorchScript execution engine, including: - -* Loading serialized TorchScript models saved from Python -* Doing simple model modifications if needed (e.g. pulling out submodules) -* Constructing the input and doing preprocessing using C++ Tensor API - -Extending PyTorch and TorchScript with C++ Extensions ------------------------------------------------------- -TorchScript can be augmented with user-supplied code through custom operators and custom classes. -Once registered with TorchScript, these operators and classes can be invoked in TorchScript code run from -Python or from C++ as part of a serialized TorchScript model. The `Extending TorchScript with Custom C++ Operators `__ tutorial walks through interfacing TorchScript with OpenCV. In addition to wrapping a function call with a custom operator, C++ classes and structs can be bound into TorchScript through a pybind11-like interface which is explained in the `Extending TorchScript with Custom C++ Classes `__ tutorial. - Tensor and Autograd in C++ --------------------------- Most of the tensor and autograd operations in PyTorch Python API are also available in the C++ API. These include: @@ -31,9 +17,7 @@ Most of the tensor and autograd operations in PyTorch Python API are also availa Authoring Models in C++ ------------------------ -The "author in TorchScript, infer in C++" workflow requires model authoring to be done in TorchScript. -However, there might be cases where the model has to be authored in C++ (e.g. in workflows where a Python -component is undesirable). To serve such use cases, we provide the full capability of authoring and training a neural net model purely in C++, with familiar components such as ``torch::nn`` / ``torch::nn::functional`` / ``torch::optim`` that closely resemble the Python API. +We provide the full capability of authoring and training a neural net model purely in C++, with familiar components such as ``torch::nn`` / ``torch::nn::functional`` / ``torch::optim`` that closely resemble the Python API. * For an overview of the PyTorch C++ model authoring and training API, please see: https://pytorch.org/cppdocs/frontend.html * For a detailed tutorial on how to use the API, please see: https://pytorch.org/tutorials/advanced/cpp_frontend.html diff --git a/docs/source/fx.md b/docs/source/fx.md index 8b60c80649661..831534606abe0 100644 --- a/docs/source/fx.md +++ b/docs/source/fx.md @@ -44,8 +44,7 @@ Your transform will take in a {class}`torch.nn.Module`, acquire a {class}`Graph` from it, do some modifications, and return a new {class}`torch.nn.Module`. You should think of the {class}`torch.nn.Module` that your FX transform returns as identical to a regular {class}`torch.nn.Module` -- you can pass it to another -FX transform, you can pass it to TorchScript, or you can -run it. Ensuring that the inputs and outputs of your FX transform are a +FX transform, or you can run it. Ensuring that the inputs and outputs of your FX transform are a {class}`torch.nn.Module` will allow for composability. ```{note} diff --git a/docs/source/jit.rst b/docs/source/jit.rst index c5ba9063a50c8..31c5c4dbf8249 100644 --- a/docs/source/jit.rst +++ b/docs/source/jit.rst @@ -203,7 +203,7 @@ See :ref:`jit_unsupported` for a list of unsupported PyTorch functions and modul Python Functions and Modules ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Many of Python's `built-in functions `_ are supported in TorchScript. -The :any:`math` module is also supported (see :ref:`math-module` for details), but no other Python modules +The :any:`math` module is also supported, but no other Python modules (built-in or third party) are supported. diff --git a/docs/source/jit_builtin_functions.rst b/docs/source/jit_builtin_functions.rst index a6cdb8c478704..c08e0739266a9 100644 --- a/docs/source/jit_builtin_functions.rst +++ b/docs/source/jit_builtin_functions.rst @@ -3,8 +3,8 @@ TorchScript Builtins ==================== -This is a full reference of functions and Tensor methods accessible in TorchScript - -.. contents:: :local: +.. warning:: + TorchScript is deprecated, please use + `torch.export `__ instead. .. automodule:: torch.jit.supported_ops diff --git a/docs/source/jit_python_reference.md b/docs/source/jit_python_reference.md index 1d2b5c78a894f..edcd93bb7b0d9 100644 --- a/docs/source/jit_python_reference.md +++ b/docs/source/jit_python_reference.md @@ -2,431 +2,7 @@ # Python Language Reference Coverage -This is a 1:1 mapping of the features listed in https://docs.python.org/3/reference/ and their -support in TorchScript. The categorizations are as follows: - -```{list-table} -:widths: 40 40 20 -:header-rows: 1 - -* - Section - - Status - - Note -* - [1. Introduction](https://docs.python.org/3/reference/introduction.html) - - Not Relevant - - -* - [1.1. Alternate Implementations](https://docs.python.org/3/reference/introduction.html#alternate-implementations) - - Not Relevant - - -* - [1.2. Notation](https://docs.python.org/3/reference/introduction.html#notation) - - Not Relevant - - -* - [2. Lexical analysis](https://docs.python.org/3/reference/lexical_analysis.html#) - - Not Relevant - - -* - [2.1. Line structure](https://docs.python.org/3/reference/lexical_analysis.html#line-structure) - - Not Relevant - - -* - [2.1.1. Logical lines](https://docs.python.org/3/reference/lexical_analysis.html#logical-lines) - - Not Relevant - - -* - [2.1.2. Physical lines](https://docs.python.org/3/reference/lexical_analysis.html#physical-lines) - - Supported - - -* - [2.1.3. Comments](https://docs.python.org/3/reference/lexical_analysis.html#comments) - - Supported - - -* - [2.1.4. Encoding declarations](https://docs.python.org/3/reference/lexical_analysis.html#encoding-declarations) - - Not Supported - - TorchScript explicitly don't support unicode -* - [2.1.5. Explicit line joining](https://docs.python.org/3/reference/lexical_analysis.html#explicit-line-joining) - - Supported - - -* - [2.1.6. Implicit line joining](https://docs.python.org/3/reference/lexical_analysis.html#implicit-line-joining) - - Supported - - -* - [2.1.7. Blank lines](https://docs.python.org/3/reference/lexical_analysis.html#blank-lines) - - Supported - - -* - [2.1.8. Indentation](https://docs.python.org/3/reference/lexical_analysis.html#indentation) - - Supported - - -* - [2.1.9. Whitespace between tokens](https://docs.python.org/3/reference/lexical_analysis.html#whitespace-between-tokens) - - Not Relevant - - -* - [2.2. Other tokens](https://docs.python.org/3/reference/lexical_analysis.html#other-tokens) - - Not Relevant - - -* - [2.3. Identifiers and keywords](https://docs.python.org/3/reference/lexical_analysis.html#identifiers) - - Supported - - -* - [2.3.1. Keywords](https://docs.python.org/3/reference/lexical_analysis.html#keywords) - - Supported - - -* - [2.3.2. Reserved classes of identifiers](https://docs.python.org/3/reference/lexical_analysis.html#reserved-classes-of-identifiers) - - Supported - - -* - [2.4. Literals](https://docs.python.org/3/reference/lexical_analysis.html#literals) - - Not Relevant - - -* - [2.4.1. String and Bytes literals](https://docs.python.org/3/reference/lexical_analysis.html#string-and-bytes-literals) - - Supported - - -* - [2.4.2. String literal concatenation](https://docs.python.org/3/reference/lexical_analysis.html#string-literal-concatenation) - - Supported - - -* - [2.4.3. Formatted string literals](https://docs.python.org/3/reference/lexical_analysis.html#formatted-string-literals) - - Partially Supported - - -* - [2.4.4. Numeric literals](https://docs.python.org/3/reference/lexical_analysis.html#numeric-literals) - - Supported - - -* - [2.4.5. Integer literals](https://docs.python.org/3/reference/lexical_analysis.html#integer-literals) - - Supported - - -* - [2.4.6. Floating point literals](https://docs.python.org/3/reference/lexical_analysis.html#floating-point-literals) - - Supported - - -* - [2.4.7. Imaginary literals](https://docs.python.org/3/reference/lexical_analysis.html#imaginary-literals) - - Not Supported - - -* - [2.5. Operators](https://docs.python.org/3/reference/lexical_analysis.html#operators) - - Partially Supported - - Not supported: ``<<``, ``>>``, ``:=`` -* - [2.6. Delimiters](https://docs.python.org/3/reference/lexical_analysis.html#delimiters) - - Partially Supported - - Not supported: ``**=``, ``<<=``, ``>>=``, ``%=``, ``^=``, ``@=``, ``&=``, ``//=``, ``%`` operator for some types (e.g. ``str``\ ) -* - [3. Data model](https://docs.python.org/3/reference/datamodel.html#) - - Not Relevant - - -* - [3.1. Objects, values and types](https://docs.python.org/3/reference/datamodel.html#objects-values-and-types) - - Not Relevant - - -* - [3.2. The standard type hierarchy](https://docs.python.org/3/reference/datamodel.html#the-standard-type-hierarchy) - - Partially Supported - - Not supported: NotImplemented, Ellipsis, numbers.Complex, bytes, byte arrays, sets, frozen sets, generators, coroutines, async generators, modules, I/O objects, internal objects, slice objects ( though slicing is supported), classmethod -* - [3.3. Special method names](https://docs.python.org/3/reference/datamodel.html#special-method-names) - - Supported - - -* - [3.3.1. Basic customization](https://docs.python.org/3/reference/datamodel.html#basic-customization) - - Partially Supported - - Not supported: ``__new__`` , ``__del__`` , ``__bytes__`` , ``__format__`` , ``__hash__`` , -* - [3.3.2. Customizing attribute access](https://docs.python.org/3/reference/datamodel.html#customizing-attribute-access) - - Not Supported - - -* - [3.3.2.1. Customizing module attribute access](https://docs.python.org/3/reference/datamodel.html#customizing-module-attribute-access) - - Not Supported - - -* - [3.3.2.2. Implementing Descriptors](https://docs.python.org/3/reference/datamodel.html#implementing-descriptors) - - Not Supported - - -* - [3.3.2.3. Invoking Descriptors](https://docs.python.org/3/reference/datamodel.html#invoking-descriptors) - - Not Supported - - -* - [3.3.2.4. __slots__](https://docs.python.org/3/reference/datamodel.html#slots) - - Not Supported - - -* - [3.3.2.4.1. Notes on using __slots__](https://docs.python.org/3/reference/datamodel.html#notes-on-using-slots) - - Not Supported - - -* - [3.3.3. Customizing class creation](https://docs.python.org/3/reference/datamodel.html#customizing-class-creation) - - Not Supported - - -* - [3.3.3.1. Metaclasses](https://docs.python.org/3/reference/datamodel.html#metaclasses) - - Not Supported - - -* - [3.3.3.2. Resolving MRO entries](https://docs.python.org/3/reference/datamodel.html#resolving-mro-entries) - - Not Supported - - [`super()`` is not supported -* - [3.3.3.3. Determining the appropriate metaclass](https://docs.python.org/3/reference/datamodel.html#determining-the-appropriate-metaclass) - - Not relevant - - -* - [3.3.3.4. Preparing the class namespace](https://docs.python.org/3/reference/datamodel.html#preparing-the-class-namespace) - - Not relevant - - -* - [3.3.3.5. Executing the class body](https://docs.python.org/3/reference/datamodel.html#executing-the-class-body) - - Not relevant - - -* - [3.3.3.6. Creating the class object](https://docs.python.org/3/reference/datamodel.html#creating-the-class-object) - - Not relevant - - -* - [3.3.3.7. Uses for metaclasses](https://docs.python.org/3/reference/datamodel.html#uses-for-metaclasses) - - Not relevant - - -* - [3.3.4. Customizing instance and subclass checks](https://docs.python.org/3/reference/datamodel.html#customizing-instance-and-subclass-checks) - - Not Supported - - -* - [3.3.5. Emulating generic types](https://docs.python.org/3/reference/datamodel.html#emulating-generic-types) - - Not Supported - - -* - [3.3.6. Emulating callable objects](https://docs.python.org/3/reference/datamodel.html#emulating-callable-objects) - - Supported - - -* - [3.3.7. Emulating container types](https://docs.python.org/3/reference/datamodel.html#emulating-container-types) - - Partially Supported - - Some magic methods not supported (e.g. ``__iter__`` ) -* - [3.3.8. Emulating numeric types](https://docs.python.org/3/reference/datamodel.html#emulating-numeric-types) - - Partially Supported - - Magic methods with swapped operands not supported (``__r*__``) -* - [3.3.9. With Statement Context Managers](https://docs.python.org/3/reference/datamodel.html#with-statement-context-managers) - - Not Supported - - -* - [3.3.10. Special method lookup](https://docs.python.org/3/reference/datamodel.html#special-method-lookup) - - Not relevant - - -* - [3.4. Coroutines](https://docs.python.org/3/reference/datamodel.html#coroutines) - - Not Supported - - -* - [3.4.1. Awaitable Objects](https://docs.python.org/3/reference/datamodel.html#awaitable-objects) - - Not Supported - - -* - [3.4.2. Coroutine Objects](https://docs.python.org/3/reference/datamodel.html#coroutine-objects) - - Not Supported - - -* - [3.4.3. Asynchronous Iterators](https://docs.python.org/3/reference/datamodel.html#asynchronous-iterators) - - Not Supported - - -* - [3.4.4. Asynchronous Context Managers](https://docs.python.org/3/reference/datamodel.html#asynchronous-context-managers) - - Not Supported - - -* - [4. Execution model](https://docs.python.org/3/reference/executionmodel.html#) - - Not Relevant - - -* - [4.1. Structure of a program](https://docs.python.org/3/reference/executionmodel.html#structure-of-a-program) - - Not Relevant - - -* - [4.2. Naming and binding](https://docs.python.org/3/reference/executionmodel.html#naming-and-binding) - - Not Relevant - - Names are bound at compile time in TorchScript -* - [4.2.1. Binding of names](https://docs.python.org/3/reference/executionmodel.html#binding-of-names) - - Not Relevant - - See ``global`` and ``nonlocal`` statements section -* - [4.2.2. Resolution of names](https://docs.python.org/3/reference/executionmodel.html#resolution-of-names) - - Not Relevant - - See ``global`` and ``nonlocal`` statements section -* - [4.2.3. Builtins and restricted execution](https://docs.python.org/3/reference/executionmodel.html#builtins-and-restricted-execution) - - Not Relevant - - -* - [4.2.4. Interaction with dynamic features](https://docs.python.org/3/reference/executionmodel.html#interaction-with-dynamic-features) - - Not Supported - - Python values cannot be captured -* - [4.3. Exceptions](https://docs.python.org/3/reference/executionmodel.html#exceptions) - - Partially Supported - - See ``try`` and ``raise`` statement section -* - [5. The import system](https://docs.python.org/3/reference/import.html) - - Not Relevant - - -* - [6. Expressions](https://docs.python.org/3/reference/expressions.html#) - - Not Relevant - - See expressions section -* - [6.1. Arithmetic conversions](https://docs.python.org/3/reference/expressions.html#arithmetic-conversions) - - Supported - - -* - [6.2. Atoms](https://docs.python.org/3/reference/expressions.html#atoms) - - Not Relevant - - -* - [6.2.1. Identifiers (Names)](https://docs.python.org/3/reference/expressions.html#atom-identifiers) - - Supported - - -* - [6.2.2. Literals](https://docs.python.org/3/reference/expressions.html#literals) - - Partially Supported - - [`bytesliteral``\ , ``imagnumber`` not supported -* - [6.2.3. Parenthesized forms](https://docs.python.org/3/reference/expressions.html#parenthesized-forms) - - Supported - - -* - [6.2.4. Displays for lists, sets and dictionaries](https://docs.python.org/3/reference/expressions.html#displays-for-lists-sets-and-dictionaries) - - Partially Supported - - Not supported: comprehension ifs, async iterators -* - [6.2.5. List displays](https://docs.python.org/3/reference/expressions.html#list-displays) - - Supported - - -* - [6.2.6. Set displays](https://docs.python.org/3/reference/expressions.html#set-displays) - - Not Supported - - -* - [6.2.7. Dictionary displays](https://docs.python.org/3/reference/expressions.html#dictionary-displays) - - Supported - - dict() constructor with kwargs doesn't work, dict comprehensions, dictionary unpacking -* - [6.2.8. Generator expressions](https://docs.python.org/3/reference/expressions.html#generator-expressions) - - Not Supported - - -* - [6.2.9. Yield expressions](https://docs.python.org/3/reference/expressions.html#yield-expressions) - - Not Supported - - -* - [6.2.9.1. Generator-iterator methods](https://docs.python.org/3/reference/expressions.html#generator-iterator-methods) - - Not Supported - - -* - [6.2.9.2. Examples](https://docs.python.org/3/reference/expressions.html#examples) - - Not Supported - - -* - [6.2.9.3. Asynchronous generator functions](https://docs.python.org/3/reference/expressions.html#asynchronous-generator-functions) - - Not Supported - - -* - [6.2.9.4. Asynchronous generator-iterator methods](https://docs.python.org/3/reference/expressions.html#asynchronous-generator-iterator-methods) - - Not Supported - - -* - [6.3. Primaries](https://docs.python.org/3/reference/expressions.html#primaries) - - Supported - - -* - [6.3.1. Attribute references](https://docs.python.org/3/reference/expressions.html#attribute-references) - - Supported - - -* - [6.3.2. Subscriptions](https://docs.python.org/3/reference/expressions.html#subscriptions) - - Supported - - -* - [6.3.3. Slicings](https://docs.python.org/3/reference/expressions.html#slicings) - - Partially Supported - - Tuple slicing with stride is not supported -* - [6.3.4. Calls](https://docs.python.org/3/reference/expressions.html#calls) - - Partially Supported - - Args unpack / kwargs unpack is not supported -* - [6.4. Await expression](https://docs.python.org/3/reference/expressions.html#await-expression) - - Not Supported - - -* - [6.5. The power operator](https://docs.python.org/3/reference/expressions.html#the-power-operator) - - Supported - - -* - [6.6. Unary arithmetic and bitwise operations](https://docs.python.org/3/reference/expressions.html#unary-arithmetic-and-bitwise-operations) - - Partially Supported - - Some bitwise operators are not implemented for primitive types (e.g. ``~x`` where ``x`` is an ``int`` is not currently supported) -* - [6.7. Binary arithmetic operations](https://docs.python.org/3/reference/expressions.html#binary-arithmetic-operations) - - Partially Supported - - See delimiters section -* - [6.8. Shifting operations](https://docs.python.org/3/reference/expressions.html#shifting-operations) - - Not Supported - - -* - [6.9. Binary bitwise operations](https://docs.python.org/3/reference/expressions.html#binary-bitwise-operations) - - Supported - - -* - [6.10. Comparisons](https://docs.python.org/3/reference/expressions.html#comparisons) - - Supported - - -* - [6.10.1. Value comparisons](https://docs.python.org/3/reference/expressions.html#value-comparisons) - - Partially Supported - - Dictionary equality checks are not currently supported -* - [6.10.2. Membership test operations](https://docs.python.org/3/reference/expressions.html#membership-test-operations) - - Partially Supported - - Not supported for TorchScript classes -* - [6.10.3. Identity comparisons](https://docs.python.org/3/reference/expressions.html#is-not) - - Supported - - -* - [6.11. Boolean operations](https://docs.python.org/3/reference/expressions.html#boolean-operations) - - Supported - - -* - [6.12. Conditional expressions](https://docs.python.org/3/reference/expressions.html#conditional-expressions) - - Supported - - -* - [6.13. Lambdas](https://docs.python.org/3/reference/expressions.html#lambda) - - Not Supported - - -* - [6.14. Expression lists](https://docs.python.org/3/reference/expressions.html#expression-lists) - - Partially Supported - - Iterable unpacking not supported -* - [6.15. Evaluation order](https://docs.python.org/3/reference/expressions.html#evaluation-order) - - Supported - - -* - [6.16. Operator precedence](https://docs.python.org/3/reference/expressions.html#operator-precedence) - - Supported - - -* - [7. Simple statements](https://docs.python.org/3/reference/simple_stmts.html#) - - Supported - - -* - [7.1. Expression statements](https://docs.python.org/3/reference/simple_stmts.html#expression-statements) - - Supported - - -* - [7.2. Assignment statements](https://docs.python.org/3/reference/simple_stmts.html#assignment-statements) - - Supported - - -* - [7.2.1. Augmented assignment statements](https://docs.python.org/3/reference/simple_stmts.html#augmented-assignment-statements) - - Partially Supported - - See delimiters section -* - [7.2.2. Annotated assignment statements](https://docs.python.org/3/reference/simple_stmts.html#annotated-assignment-statements) - - Supported - - -* - [7.3. The assert statement](https://docs.python.org/3/reference/simple_stmts.html#the-assert-statement) - - Partially Supported - - Exception message is not customizable -* - [7.4. The pass statement](https://docs.python.org/3/reference/simple_stmts.html#the-pass-statement) - - Supported - - -* - [7.5. The del statement](https://docs.python.org/3/reference/simple_stmts.html#the-del-statement) - - Not Supported - - -* - [7.6. The return statement](https://docs.python.org/3/reference/simple_stmts.html#the-return-statement) - - Supported - - Some other features of returning (e.g. behavior with try..finally) are unsupported -* - [7.7. The yield statement](https://docs.python.org/3/reference/simple_stmts.html#the-yield-statement) - - Not Supported - - -* - [7.8. The raise statement](https://docs.python.org/3/reference/simple_stmts.html#the-raise-statement) - - Partially Supported - - Exception message is not customizable -* - [7.9. The break statement](https://docs.python.org/3/reference/simple_stmts.html#the-break-statement) - - Supported - - Some other features of returning (e.g. behavior with try..finally) are unsupported -* - [7.10. The continue statement](https://docs.python.org/3/reference/simple_stmts.html#the-continue-statement) - - Supported - - Some other features of returning (e.g. behavior with try..finally) are unsupported -* - [7.11. The import statement](https://docs.python.org/3/reference/simple_stmts.html#the-import-statement) - - Not Supported - - -* - [7.11.1. Future statements](https://docs.python.org/3/reference/simple_stmts.html#future-statements) - - Not Supported - - -* - [7.12. The global statement](https://docs.python.org/3/reference/simple_stmts.html#the-global-statement) - - Not Supported - - -* - [7.13. The nonlocal statement](https://docs.python.org/3/reference/simple_stmts.html#the-nonlocal-statement) - - Not Supported - - -* - [8. Compound statements](https://docs.python.org/3/reference/compound_stmts.html#) - - Irrelevant - - -* - [8.1. The if statement](https://docs.python.org/3/reference/compound_stmts.html#the-if-statement) - - Supported - - -* - [8.2. The while statement](https://docs.python.org/3/reference/compound_stmts.html#the-while-statement) - - Partially Supported - - while..else is not supported -* - [8.3. The for statement](https://docs.python.org/3/reference/compound_stmts.html#the-for-statement) - - Partially Supported - - for..else is not supported -* - [8.4. The try statement](https://docs.python.org/3/reference/compound_stmts.html#the-try-statement) - - Not Supported - - -* - [8.5. The with statement](https://docs.python.org/3/reference/compound_stmts.html#the-with-statement) - - Partially Supported - - [`__exit__`` is always called with ``exc_type``, ``exc_value``, and ``traceback`` set to None, even if an exception was raised, and ``__exit__``'s return value is ignored. -* - [8.6. Function definitions](https://docs.python.org/3/reference/compound_stmts.html#function-definitions) - - Not Supported - - -* - [8.7. Class definitions](https://docs.python.org/3/reference/compound_stmts.html#class-definitions) - - Not Supported - - -* - [8.8. Coroutines](https://docs.python.org/3/reference/compound_stmts.html#coroutines) - - Not Supported - - -* - [8.8.1. Coroutine function definition](https://docs.python.org/3/reference/compound_stmts.html#coroutine-function-definition) - - Not Supported - - -* - [8.8.2. The async for statement](https://docs.python.org/3/reference/compound_stmts.html#the-async-for-statement) - - Not Supported - - -* - [8.8.3. The async with statement](https://docs.python.org/3/reference/compound_stmts.html#the-async-with-statement) - - Not Supported - - -* - [9. Top-level components](https://docs.python.org/3/reference/toplevel_components.html#) - - Not Relevant - - -* - [9.1. Complete Python programs](https://docs.python.org/3/reference/toplevel_components.html#complete-python-programs) - - Not Relevant - - -* - [9.2. File input](https://docs.python.org/3/reference/toplevel_components.html#file-input) - - Not Relevant - - -* - [9.3. Interactive input](https://docs.python.org/3/reference/toplevel_components.html#interactive-input) - - Not Relevant - - -* - [9.4. Expression input](https://docs.python.org/3/reference/toplevel_components.html#expression-input) - - Not Relevant - - -``` +:::{warning} +TorchScript is deprecated, please use +[torch.export](https://docs.pytorch.org/docs/stable/export.html) instead. +::: \ No newline at end of file diff --git a/docs/source/jit_unsupported.md b/docs/source/jit_unsupported.md index 79a51c1651f34..be3ddaec12a72 100644 --- a/docs/source/jit_unsupported.md +++ b/docs/source/jit_unsupported.md @@ -2,80 +2,11 @@ # TorchScript Unsupported PyTorch Constructs -## Torch and Tensor Unsupported Attributes - -TorchScript supports most methods defined on `torch` and `torch.Tensor`, but we do not have full coverage. -Here are specific known ops and categories of ops which have diverging behavior between -Python and TorchScript. If you encounter something else that is not supported please -file a GitHub issue. Deprecated ops are not listed below. +:::{warning} +TorchScript is deprecated, please use +[torch.export](https://docs.pytorch.org/docs/stable/export.html) instead. +::: ```{eval-rst} .. automodule:: torch.jit.unsupported_tensor_ops -``` - -### Functions Not Correctly Bound on Torch - -The following functions will fail if used in TorchScript, either because they -are not bound on `torch` or because Python expects a different schema than -TorchScript. - -- {func}`torch.tensordot` -- {func}`torch.nn.init.calculate_gain` -- {func}`torch.nn.init.eye_` -- {func}`torch.nn.init.dirac_` -- {func}`torch.nn.init.kaiming_normal_` -- {func}`torch.nn.init.orthogonal_` -- {func}`torch.nn.init.sparse` - -### Ops With Divergent Schemas Between Torch & Python - -The following categories of ops have divergent schemas: - -Functions which construct tensors from non-tensor inputs do not support the `requires_grad` -argument, except for `torch.tensor`. This covers the following ops: - -- {func}`torch.norm` -- {func}`torch.bartlett_window` -- {func}`torch.blackman_window` -- {func}`torch.empty` -- {func}`torch.empty_like` -- {func}`torch.empty_strided` -- {func}`torch.eye` -- {func}`torch.full` -- {func}`torch.full_like` -- {func}`torch.hamming_window` -- {func}`torch.hann_window` -- {func}`torch.linspace` -- {func}`torch.logspace` -- {func}`torch.normal` -- {func}`torch.ones` -- {func}`torch.rand` -- {func}`torch.rand_like` -- {func}`torch.randint_like` -- {func}`torch.randn` -- {func}`torch.randn_like` -- {func}`torch.randperm` -- {func}`torch.tril_indices` -- {func}`torch.triu_indices` -- {func}`torch.vander` -- {func}`torch.zeros` -- {func}`torch.zeros_like` - -The following functions require `dtype`, `layout`, `device` as parameters in TorchScript, -but these parameters are optional in Python. - -- {func}`torch.randint` -- {func}`torch.sparse_coo_tensor` -- {func}`torch.Tensor.to` - -## PyTorch Unsupported Modules and Classes - -TorchScript cannot currently compile a number of other commonly used PyTorch -constructs. Below are listed the modules that TorchScript does not support, and -an incomplete list of PyTorch classes that are not supported. For unsupported modules -we suggest using {meth}`torch.jit.trace`. - -- {class}`torch.nn.RNN` -- {class}`torch.nn.AdaptiveLogSoftmaxWithLoss` -- {class}`torch.autograd.Function` -- {class}`torch.autograd.enable_grad` +``` \ No newline at end of file diff --git a/docs/source/notes/cpu_threading_runtimes.svg b/docs/source/notes/cpu_threading_runtimes.svg deleted file mode 100644 index e36ec598f063c..0000000000000 --- a/docs/source/notes/cpu_threading_runtimes.svg +++ /dev/null @@ -1,208 +0,0 @@ - -image/svg+xml0102030400.51.01.52.02.5# ThreadsTime, s diff --git a/docs/source/notes/cpu_threading_torchscript_inference.rst b/docs/source/notes/cpu_threading_torchscript_inference.rst index e4e55dcf2bd35..8cac34c8c36fd 100644 --- a/docs/source/notes/cpu_threading_torchscript_inference.rst +++ b/docs/source/notes/cpu_threading_torchscript_inference.rst @@ -3,160 +3,6 @@ CPU threading and TorchScript inference ================================================= -PyTorch allows using multiple CPU threads during TorchScript model inference. -The following figure shows different levels of parallelism one would find in a -typical application: - -.. image:: cpu_threading_torchscript_inference.svg - :width: 75% - -One or more inference threads execute a model's forward pass on the given inputs. -Each inference thread invokes a JIT interpreter that executes the ops -of a model inline, one by one. A model can utilize a ``fork`` TorchScript -primitive to launch an asynchronous task. Forking several operations at once -results in a task that is executed in parallel. The ``fork`` operator returns a -``Future`` object which can be used to synchronize on later, for example: - -.. code-block:: python - - @torch.jit.script - def compute_z(x): - return torch.mm(x, self.w_z) - - @torch.jit.script - def forward(x): - # launch compute_z asynchronously: - fut = torch.jit._fork(compute_z, x) - # execute the next operation in parallel to compute_z: - y = torch.mm(x, self.w_y) - # wait for the result of compute_z: - z = torch.jit._wait(fut) - return y + z - - -PyTorch uses a single thread pool for the inter-op parallelism, this thread pool -is shared by all inference tasks that are forked within the application process. - -In addition to the inter-op parallelism, PyTorch can also utilize multiple threads -within the ops (`intra-op parallelism`). This can be useful in many cases, -including element-wise ops on large tensors, convolutions, GEMMs, embedding -lookups and others. - - -Build options -------------- - -PyTorch uses an internal ATen library to implement ops. In addition to that, -PyTorch can also be built with support of external libraries, such as MKL_ and MKL-DNN_, -to speed up computations on CPU. - -ATen, MKL and MKL-DNN support intra-op parallelism and depend on the -following parallelization libraries to implement it: - -* OpenMP_ - a standard (and a library, usually shipped with a compiler), widely used in external libraries; -* TBB_ - a newer parallelization library optimized for task-based parallelism and concurrent environments. - -OpenMP historically has been used by a large number of libraries. It is known -for a relative ease of use and support for loop-based parallelism and other primitives. - -TBB is used to a lesser extent in external libraries, but, at the same time, -is optimized for the concurrent environments. PyTorch's TBB backend guarantees that -there's a separate, single, per-process intra-op thread pool used by all of the -ops running in the application. - -Depending of the use case, one might find one or another parallelization -library a better choice in their application. - -PyTorch allows selecting of the parallelization backend used by ATen and other -libraries at the build time with the following build options: - -+------------+------------------------+-----------------------------+----------------------------------------+ -| Library | Build Option | Values | Notes | -+============+========================+=============================+========================================+ -| ATen | ``ATEN_THREADING`` | ``OMP`` (default), ``TBB`` | | -+------------+------------------------+-----------------------------+----------------------------------------+ -| MKL | ``MKL_THREADING`` | (same) | To enable MKL use ``BLAS=MKL`` | -+------------+------------------------+-----------------------------+----------------------------------------+ -| MKL-DNN | ``MKLDNN_CPU_RUNTIME`` | (same) | To enable MKL-DNN use ``USE_MKLDNN=1`` | -+------------+------------------------+-----------------------------+----------------------------------------+ - -It is recommended not to mix OpenMP and TBB within one build. - -Any of the ``TBB`` values above require ``USE_TBB=1`` build setting (default: OFF). -A separate setting ``USE_OPENMP=1`` (default: ON) is required for OpenMP parallelism. - -Runtime API ------------ - -The following API is used to control thread settings: - -+------------------------+-----------------------------------------------------------+---------------------------------------------------------+ -| Type of parallelism | Settings | Notes | -+========================+===========================================================+=========================================================+ -| Inter-op parallelism | ``at::set_num_interop_threads``, | Default number of threads: number of CPU cores. | -| | ``at::get_num_interop_threads`` (C++) | | -| | | | -| | ``set_num_interop_threads``, | | -| | ``get_num_interop_threads`` (Python, :mod:`torch` module) | | -+------------------------+-----------------------------------------------------------+ | -| Intra-op parallelism | ``at::set_num_threads``, | | -| | ``at::get_num_threads`` (C++) | | -| | ``set_num_threads``, | | -| | ``get_num_threads`` (Python, :mod:`torch` module) | | -| | | | -| | Environment variables: | | -| | ``OMP_NUM_THREADS`` and ``MKL_NUM_THREADS`` | | -+------------------------+-----------------------------------------------------------+---------------------------------------------------------+ - -For the intra-op parallelism settings, ``at::set_num_threads``, ``torch.set_num_threads`` always take precedence -over environment variables, ``MKL_NUM_THREADS`` variable takes precedence over ``OMP_NUM_THREADS``. - -Tuning the number of threads ----------------------------- - -The following simple script shows how a runtime of matrix multiplication changes with the number of threads: - -.. code-block:: python - - import timeit - runtimes = [] - threads = [1] + [t for t in range(2, 49, 2)] - for t in threads: - torch.set_num_threads(t) - r = timeit.timeit(setup = "import torch; x = torch.randn(1024, 1024); y = torch.randn(1024, 1024)", stmt="torch.mm(x, y)", number=100) - runtimes.append(r) - # ... plotting (threads, runtimes) ... - -Running the script on a system with 24 physical CPU cores (Xeon E5-2680, MKL and OpenMP based build) results in the following runtimes: - -.. image:: cpu_threading_runtimes.svg - :width: 75% - -The following considerations should be taken into account when tuning the number of intra- and inter-op threads: - -* When choosing the number of threads one needs to avoid `oversubscription` (using too many threads, leads to performance degradation). For example, in an application that uses a large application thread pool or heavily relies on - inter-op parallelism, one might find disabling intra-op parallelism as a possible option (i.e. by calling ``set_num_threads(1)``); - -* In a typical application one might encounter a trade off between `latency` (time spent on processing an inference request) and `throughput` (amount of work done per unit of time). Tuning the number of threads can be a useful - tool to adjust this trade off in one way or another. For example, in latency critical applications one might want to increase the number of intra-op threads to process each request as fast as possible. At the same time, parallel implementations - of ops may add an extra overhead that increases amount work done per single request and thus reduces the overall throughput. - .. warning:: - OpenMP does not guarantee that a single per-process intra-op thread - pool is going to be used in the application. On the contrary, two different application or inter-op - threads may use different OpenMP thread pools for intra-op work. - This might result in a large number of threads used by the application. - Extra care in tuning the number of threads is needed to avoid - oversubscription in multi-threaded applications in OpenMP case. - -.. note:: - Pre-built PyTorch releases are compiled with OpenMP support. - -.. note:: - ``parallel_info`` utility prints information about thread settings and can be used for debugging. - Similar output can be also obtained in Python with ``torch.__config__.parallel_info()`` call. - -.. _OpenMP: https://www.openmp.org/ -.. _TBB: https://github.com/intel/tbb -.. _MKL: https://software.intel.com/en-us/mkl -.. _MKL-DNN: https://github.com/intel/mkl-dnn + TorchScript is deprecated, please use + `torch.export `__ instead. diff --git a/docs/source/notes/cpu_threading_torchscript_inference.svg b/docs/source/notes/cpu_threading_torchscript_inference.svg deleted file mode 100644 index f09884cc5f274..0000000000000 --- a/docs/source/notes/cpu_threading_torchscript_inference.svg +++ /dev/null @@ -1,681 +0,0 @@ - -image/svg+xml -Inputs -Application Thread Pool - -Op -Op -Op -Inference thread -Fork -Op -Join - - -Inter -- -op parallelism -Intra -- -op parallelism - -ATen/Parallel -(e.g. at::parallel_for) - -MKL - -MKL -- -DNN - -... -OpenMP -TBB - - diff --git a/docs/source/notes/large_scale_deployments.rst b/docs/source/notes/large_scale_deployments.rst index 2829ba0e939b2..27380a68cf338 100644 --- a/docs/source/notes/large_scale_deployments.rst +++ b/docs/source/notes/large_scale_deployments.rst @@ -7,9 +7,6 @@ This note talks about several extension points and tricks that might be useful when running PyTorch within a larger system or operating multiple systems using PyTorch in a larger organization. -It doesn't cover topics of deploying models to production. Check -:mod:`torch.jit` or one of the corresponding tutorials. - The note assumes that you either build PyTorch from source in your organization or have an ability to statically link additional code to be loaded when PyTorch is used. Therefore, many of the hooks are exposed as C++ APIs that @@ -86,8 +83,7 @@ scripts, the callback fires only once for a given process for each of the APIs. ``c10::SetAPIUsageHandler`` can be used to register API usage instrumentation handler. Passed argument is going to be an "api key" identifying used point, for -example ``python.import`` for PyTorch extension import or -``torch.script.compile`` if TorchScript compilation was triggered. +example ``python.import`` for PyTorch extension import. .. code-block:: cpp @@ -99,42 +95,6 @@ Note for developers: new API trigger points can be added in code with ``C10_LOG_API_USAGE_ONCE("my_api")`` in C++ or ``torch._C._log_api_usage_once("my.api")`` in Python. -Attaching metadata to saved TorchScript models -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -TorchScript modules can be saved as an archive file that bundles serialized -parameters and module code as TorchScript (see :meth:`torch.jit.save`). It's -often convenient to bundle additional information together with the model, for -example, description of model producer or auxiliary artifacts. - -It can be achieved by passing the ``_extra_files`` argument to -:meth:`torch.jit.save` and ``torch::jit::load`` to store and retrieve -arbitrary binary blobs during saving process. Since TorchScript files are -regular ZIP archives, extra information gets stored as regular files inside -archive's ``extra/`` directory. - -There's also a global hook allowing to attach extra files to any TorchScript -archive produced in the current process. It might be useful to tag models with -producer metadata, akin to JPEG metadata produced by digital cameras. Example -usage might look like: - -.. code-block:: cpp - - SetExportModuleExtraFilesHook([](const Module&) { - ExtraFilesMap files; - files["producer_info.json"] = "{\"user\": \"" + getenv("USER") + "\"}"; - return files; - }); - - -Build environment considerations -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -TorchScript's compilation needs to have access to the original python files as -it uses python's ``inspect.getsource`` call. In certain production environments -it might require explicitly deploying ``.py`` files along with precompiled -``.pyc``. - Common extension points ^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/source/package.md b/docs/source/package.md index e337fedde3e6b..1b50f743d5793 100644 --- a/docs/source/package.md +++ b/docs/source/package.md @@ -416,21 +416,6 @@ with PackageExporter(f2, importer=(importer, sys_importer)) as exporter: exporter.save_pickle("model", "model.pkl", obj) ``` -### Package a TorchScript module? -To package a TorchScript model, use the same `save_pickle` and `load_pickle` APIs as you would with any other object. -Saving TorchScript objects that are attributes or submodules is supported as well with no extra work. - -```python -# save TorchScript just like any other object -with PackageExporter(file_name) as e: - e.save_pickle("res", "script_model.pkl", scripted_model) - e.save_pickle("res", "mixed_model.pkl", python_model_with_scripted_submodule) -# load as normal -importer = PackageImporter(file_name) -loaded_script = importer.load_pickle("res", "script_model.pkl") -loaded_mixed = importer.load_pickle("res", "mixed_model.pkl" -``` - ## Explanation ### `torch.package` Format Overview From ee0992871c99fc6a1e19eb839ab65391a168d2f8 Mon Sep 17 00:00:00 2001 From: dsashidh Date: Tue, 15 Jul 2025 21:17:20 +0000 Subject: [PATCH 080/457] Add test for user-managed weights with load_state_dict (#157496) Summary: Adds a unit test to verify that when 'user_managed=True' is passed to 'update_constant_buffer', the compiled AOTI model properly shares parameter storage with the eager model. The test specifically covers the following: 1. Passes model weights to the AOTI model with 'user_managed=True''. 2. Updates the eager model weights using 'load_state_dict()', which performs in-place 3. Asserts that the compiled AOTI model reflects the updated weights, confirming shared memory behavior. Fixes: #157474 Pull Request resolved: https://github.com/pytorch/pytorch/pull/157496 Approved by: https://github.com/desertfire --- test/inductor/test_aot_inductor.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 08799fd6db708..2ffcc0ead4954 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -5876,6 +5876,31 @@ def runner_call(*args, **kwargs): ) self.assertEqual(new_expected, new_output) + new_weights = { + "L__self___weight": torch.randn(N, K, device=self.device), + "L__self___bias": torch.randn(N, device=self.device), + } + + runner.update_constant_buffer(new_weights, True, False, True) + runner.swap_constant_buffer() + + model.weight = torch.nn.Parameter(new_weights["L__self___weight"]) + model.bias = torch.nn.Parameter(new_weights["L__self___bias"]) + + updated_state_dict = { + "weight": torch.ones_like(model.weight), + "bias": torch.zeros_like(model.bias), + } + + model.load_state_dict(updated_state_dict) + + new_output = runner_call(test_inputs) + expected_output = model(test_inputs) + torch.testing.assert_close(new_output, expected_output) + + with self.assertRaises(AssertionError): + torch.testing.assert_close(new_expected, new_output) + def test_cond_share_predicte(self): class Model(torch.nn.Module): def forward(self, predicate, x): From 144965ca9af478515736665b0577cded22fa692e Mon Sep 17 00:00:00 2001 From: dolpm <34420038+dolpm@users.noreply.github.com> Date: Tue, 15 Jul 2025 22:04:08 +0000 Subject: [PATCH 081/457] [BE][S538760] get rid of TORCH_CHECK_.* and CHECK macros (#158269) Summary: check will be crit, causing program to exit, which is quite dangerous Test Plan: CI Rollback Plan: Differential Revision: D78050595 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158269 Approved by: https://github.com/SherlockNoMad, https://github.com/henryoier --- test/cpp/nativert/test_itree.cpp | 16 +++--- torch/nativert/detail/ITree.cpp | 30 +++++----- torch/nativert/executor/ConstantFolder.cpp | 10 ++-- torch/nativert/executor/ExecutionFrame.h | 8 ++- torch/nativert/executor/Executor.cpp | 56 ++++++++++++------- torch/nativert/executor/GraphExecutorBase.cpp | 4 +- .../executor/memory/AliasAnalyzer.cpp | 4 +- .../nativert/executor/memory/AliasAnalyzer.h | 2 +- .../executor/memory/LayoutManager.cpp | 14 ++--- .../nativert/executor/memory/LayoutManager.h | 12 ++-- .../executor/memory/LayoutPlanner.cpp | 6 +- .../nativert/executor/memory/LayoutPlanner.h | 2 +- torch/nativert/graph/Graph.cpp | 38 ++++++------- torch/nativert/graph/GraphPasses.cpp | 2 +- torch/nativert/graph/Serialization.cpp | 8 ++- torch/nativert/graph/TensorMeta.cpp | 2 +- torch/nativert/graph/TensorMeta.h | 6 +- torch/nativert/kernels/C10Kernel.cpp | 6 +- .../nativert/kernels/CallTorchBindKernel.cpp | 16 +++--- torch/nativert/kernels/HigherOrderKernel.cpp | 14 ++--- torch/nativert/kernels/KernelFactory.cpp | 10 ++-- 21 files changed, 147 insertions(+), 119 deletions(-) diff --git a/test/cpp/nativert/test_itree.cpp b/test/cpp/nativert/test_itree.cpp index e0004f7db77e4..4748c11c3e17a 100644 --- a/test/cpp/nativert/test_itree.cpp +++ b/test/cpp/nativert/test_itree.cpp @@ -259,7 +259,7 @@ TEST(ITreeTest, NoContext) { c10::IValue(8), c10::IValue(9), }; - ASSERT_DEATH({ itreeUnflatten(flats, spec); }, "Check failed"); + EXPECT_THROW({ itreeUnflatten(flats, spec); }, c10::Error); } TEST(ITreeTest, TooManyContext) { @@ -304,7 +304,7 @@ TEST(ITreeTest, TooManyContext) { c10::IValue(8), c10::IValue(9), }; - ASSERT_DEATH({ itreeUnflatten(flats, spec); }, "Check failed"); + EXPECT_THROW({ itreeUnflatten(flats, spec); }, c10::Error); } TEST(ITreeTest, DoubleRegister) { @@ -375,7 +375,7 @@ TEST(ITreeTest, NotEnoughUnflatten) { c10::IValue(2), c10::IValue(7), }; - ASSERT_DEATH({ itreeUnflatten(flats, spec); }, "Check failed"); + EXPECT_THROW({ itreeUnflatten(flats, spec); }, c10::Error); } TEST(ITreeTest, TooManyUnflatten) { @@ -449,7 +449,7 @@ TEST(ITreeTest, TooManyUnflatten) { c10::IValue(2), c10::IValue(7), }; - ASSERT_DEATH({ itreeUnflatten(flats, spec); }, "Check failed"); + EXPECT_THROW({ itreeUnflatten(flats, spec); }, c10::Error); } TEST(ITreeTest, Flatten) { @@ -908,8 +908,8 @@ TEST(ITreeTest, UnmatchedDictFlatten) { list.push_back(std::move(tup)); list.push_back(c10::IValue(2)); list.push_back(std::move(dict)); - ASSERT_DEATH( - { itreeFlatten(c10::IValue{std::move(list)}, spec); }, "Check failed"); + EXPECT_THROW( + { itreeFlatten(c10::IValue{std::move(list)}, spec); }, c10::Error); } TEST(ITreeTest, DictFlattenTest) { @@ -1025,8 +1025,8 @@ TEST(ITreeTest, UnmatchedTupleFlatten) { list.push_back(std::move(tup)); list.push_back(c10::IValue(2)); list.push_back(std::move(dict)); - ASSERT_DEATH( - { itreeFlatten(c10::IValue{std::move(list)}, spec); }, "Check failed"); + EXPECT_THROW( + { itreeFlatten(c10::IValue{std::move(list)}, spec); }, c10::Error); } TEST(ITreeTest, ToAtenType) { diff --git a/torch/nativert/detail/ITree.cpp b/torch/nativert/detail/ITree.cpp index 123ee4498d06f..cd24ca78320fb 100644 --- a/torch/nativert/detail/ITree.cpp +++ b/torch/nativert/detail/ITree.cpp @@ -46,7 +46,7 @@ class PytreeNodeRegistry { const ITreeSpec& spec, std::vector& ivalues) { const auto& tuple = nested.toTupleRef().elements(); - TORCH_CHECK_EQ(tuple.size(), spec.children().size()); + TORCH_CHECK(tuple.size() == spec.children().size()); for (size_t i = 0; i < tuple.size(); i++) { itreeFlatten(tuple[i], spec.children(i), ivalues); } @@ -60,7 +60,7 @@ class PytreeNodeRegistry { const c10::IValue& nested, const ITreeSpec& spec) { const auto& tuple = nested.toTupleRef().elements(); - TORCH_CHECK_EQ(tuple.size(), spec.children().size()); + TORCH_CHECK(tuple.size() == spec.children().size()); for (size_t i = 0; i < tuple.size(); i++) { ivalueApply(fn, tuple[i], spec.children(i)); } @@ -119,7 +119,7 @@ class PytreeNodeRegistry { const auto& contextKeys = spec.contextKeys(); // allow the dict size less than the spec, missing key will be // filled with empty tensor - TORCH_CHECK_LE(dict.size(), contextKeys.size()); + TORCH_CHECK(dict.size() <= contextKeys.size()); size_t i = 0; for (const auto& key : contextKeys) { auto it = dict.find(key); @@ -143,7 +143,7 @@ class PytreeNodeRegistry { c10::Dict dict( c10::AnyType::get(), c10::AnyType::get()); TORCH_CHECK(obj.is_array()); - TORCH_CHECK_EQ(obj.size(), flats.size()); + TORCH_CHECK(obj.size() == flats.size()); dict.reserve(flats.size()); for (size_t i = 0; i < flats.size(); i++) { dict.insert(dynamicToIValue(obj[i]), std::move(flats[i])); @@ -200,7 +200,7 @@ ITreeSpec makeITreeSpec( TORCH_CHECK(obj.is_object()); TORCH_CHECK(obj.find("type") != obj.end()); if (obj["type"].is_null()) { - TORCH_CHECK_EQ(obj["children_spec"].size(), 0); + TORCH_CHECK(obj["children_spec"].empty()); TORCH_CHECK(obj["context"].is_null()); const Value* value = values[start]; @@ -244,11 +244,11 @@ ITreeSpec itreeSpecLoads( const std::vector& values) { const auto obj = nlohmann::json::parse(json); TORCH_CHECK(obj.is_array()); - TORCH_CHECK_EQ(obj.size(), 2); - TORCH_CHECK_EQ(obj[0].get(), kDefaultTreeSpecSerializationProtocol); + TORCH_CHECK(obj.size() == 2); + TORCH_CHECK(obj[0].get() == kDefaultTreeSpecSerializationProtocol); auto result = makeITreeSpec(obj[1], values, 0); - TORCH_CHECK_EQ(result.numIValues(), values.size()); + TORCH_CHECK(result.numIValues() == values.size()); return result; } @@ -256,7 +256,7 @@ c10::IValue itreeUnflatten( std::vector ivalues, const ITreeSpec& spec) { RECORD_USER_SCOPE("nativert::itreeUnflatten"); - TORCH_CHECK_EQ(ivalues.size(), spec.numIValues()); + TORCH_CHECK(ivalues.size() == spec.numIValues()); if (spec.isIValue()) { return std::move(ivalues[0]); } @@ -299,20 +299,20 @@ std::vector itreeFlattenFromArgs( const ITreeSpec& spec) { RECORD_USER_SCOPE("nativert::itreeFlattenFromArgs"); TORCH_CHECK(!spec.isIValue()); - TORCH_CHECK_EQ(spec.children().size(), 2); + TORCH_CHECK(spec.children().size() == 2); std::vector ivalues; ivalues.reserve(spec.numIValues()); const auto& specArgs = spec.children(0); TORCH_CHECK(!specArgs.isIValue()); - TORCH_CHECK_EQ(specArgs.children().size(), args.size()); + TORCH_CHECK(specArgs.children().size() == args.size()); for (size_t i = 0; i < args.size(); i++) { itreeFlatten(args[i], specArgs.children(i), ivalues); } const auto& specKwargs = spec.children(1); TORCH_CHECK(!specKwargs.isIValue()); - TORCH_CHECK_EQ(specKwargs.context().size(), kwargs.size()); + TORCH_CHECK(specKwargs.context().size() == kwargs.size()); for (size_t i = 0; i < specKwargs.context().size(); i++) { itreeFlatten( kwargs.at(specKwargs.context()[i].get_ref()), @@ -329,11 +329,11 @@ void ivalueApplyFromArgs( const ITreeSpec& spec) { RECORD_USER_SCOPE("nativert::ivalueApplyFromArgs"); TORCH_CHECK(!spec.isIValue()); - TORCH_CHECK_EQ(spec.children().size(), 2); + TORCH_CHECK(spec.children().size() == 2); const auto& specArgs = spec.children(0); TORCH_CHECK(!specArgs.isIValue()); - TORCH_CHECK_EQ(specArgs.children().size(), args.size()); + TORCH_CHECK(specArgs.children().size() == args.size()); for (size_t i = 0; i < args.size(); i++) { ivalueApply(fn, args[i], specArgs.children(i)); } @@ -342,7 +342,7 @@ void ivalueApplyFromArgs( TORCH_CHECK(!specKwargs.isIValue()); const auto& ctx = specKwargs.context(); - TORCH_CHECK_EQ(ctx.size(), kwargs.size()); + TORCH_CHECK(ctx.size() == kwargs.size()); for (size_t i = 0; i < ctx.size(); i++) { ivalueApply( diff --git a/torch/nativert/executor/ConstantFolder.cpp b/torch/nativert/executor/ConstantFolder.cpp index 7db1fd736243f..13d253394805b 100644 --- a/torch/nativert/executor/ConstantFolder.cpp +++ b/torch/nativert/executor/ConstantFolder.cpp @@ -24,8 +24,9 @@ namespace torch::nativert { void ConstantFolder::unlinkConstants( std::vector>& kernels) { - TORCH_CHECK_EQ(kernels.size(), graph_.nodes().size()) - << "graph node count and kernel count should be equal"; + TORCH_CHECK( + kernels.size() == graph_.nodes().size(), + "graph node count and kernel count should be equal"); unlinked_ = true; @@ -135,8 +136,9 @@ void ConstantFolder::unlinkConstants( */ void ConstantFolder::evaluate(Weights& weights) { - CHECK(unlinked_) - << "cannot evaluate weights for a graph whose constants have not been unlinked via ConstFolder::unlinkConstants"; + TORCH_CHECK( + unlinked_, + "cannot evaluate weights for a graph whose constants have not been unlinked via ConstFolder::unlinkConstants"); weights.validateAllWeightsLoaded(); diff --git a/torch/nativert/executor/ExecutionFrame.h b/torch/nativert/executor/ExecutionFrame.h index 945c3b0c5036d..4cf02054bc5e4 100644 --- a/torch/nativert/executor/ExecutionFrame.h +++ b/torch/nativert/executor/ExecutionFrame.h @@ -124,8 +124,10 @@ class ExecutionFrame { } c10::intrusive_ptr getWork(int64_t workId) const { - CHECK(work_.find(workId) != work_.end()) - << "Couldn't find work with Id: " << workId; + TORCH_CHECK( + work_.find(workId) != work_.end(), + "Couldn't find work with Id: ", + workId); return work_.at(workId); } @@ -151,7 +153,7 @@ class ExecutionFrame { private: bool isOutputMovable(size_t idx) const { - TORCH_CHECK_LT(idx, moveable_output_mask_.size()); + TORCH_CHECK(idx < moveable_output_mask_.size()); return moveable_output_mask_[idx]; } diff --git a/torch/nativert/executor/Executor.cpp b/torch/nativert/executor/Executor.cpp index eb25342f65df2..a90b93bd17c7d 100644 --- a/torch/nativert/executor/Executor.cpp +++ b/torch/nativert/executor/Executor.cpp @@ -147,15 +147,23 @@ void validateInput( const std::string& inputName, const at::Tensor& inputTensor, const torch::nativert::TensorMeta& tensorValueMeta) { - CHECK(inputTensor.dtype() == tensorValueMeta.dtype()) - << "Input tensor dtype mismatch for " << inputName << ", expecting " - << c10::toString(tensorValueMeta.dtype()) << " but got " - << inputTensor.dtype().name(); - - CHECK(inputTensor.device() == tensorValueMeta.device()) - << "Input tensor device mismatch for " << inputName << ", expecting " - << tensorValueMeta.device().str() << " but got " - << inputTensor.device().str(); + TORCH_CHECK( + inputTensor.dtype() == tensorValueMeta.dtype(), + "Input tensor dtype mismatch for ", + inputName, + ", expecting ", + c10::toString(tensorValueMeta.dtype()), + " but got ", + inputTensor.dtype().name()); + + TORCH_CHECK( + inputTensor.device() == tensorValueMeta.device(), + "Input tensor device mismatch for ", + inputName, + ", expecting ", + tensorValueMeta.device().str(), + " but got ", + inputTensor.device().str()); } } // namespace @@ -169,8 +177,11 @@ void Executor::validateInputs(const std::vector& inputs) const { if (actualInput.isTensor()) { const auto& inputName = std::string(inputValues[i]->name()); auto it = tensorValuesMeta.find(inputName); - CHECK(it != tensorValuesMeta.end()) - << "Couldn't find " << inputName << " in tensorValuesMeta"; + TORCH_CHECK( + it != tensorValuesMeta.end(), + "Couldn't find ", + inputName, + " in tensorValuesMeta"); validateInput(inputName, actualInput.toTensor(), it->second); } } @@ -291,15 +302,17 @@ void Executor::returnExecutorFrameToPool( // Create an entry with used=true if (C10_UNLIKELY(!clearingInProgress_)) { - CHECK(executionFrames_.writeIfNotFull(std::move(frame))) - << "ExecutionFrame pool full"; + TORCH_CHECK( + executionFrames_.writeIfNotFull(std::move(frame)), + "ExecutionFrame pool full"); } else { ExecutionFrameEntry frameEntry; frameEntry.used = true; frameEntry.frame = std::move(frame); - CHECK(clearedExecutionFrames_.writeIfNotFull(std::move(frameEntry))) - << "Cleared ExecutionFrame pool full"; + TORCH_CHECK( + clearedExecutionFrames_.writeIfNotFull(std::move(frameEntry)), + "Cleared ExecutionFrame pool full"); } } catch (...) { sem_.release(); @@ -326,7 +339,7 @@ std::vector Executor::execute( std::optional> outputs; const auto userInputs = graph_->userInputs(); const auto& tensorValuesMeta = graph_->tensorValuesMeta(); - TORCH_CHECK_EQ(userInputs.size(), inputTreeSpec.numIValues()); + TORCH_CHECK(userInputs.size() == inputTreeSpec.numIValues()); auto executionFrameFillUserInputs = [&](const c10::IValue& leaf, const Value* value) { @@ -334,8 +347,11 @@ std::vector Executor::execute( if (executorConfig_.validateInputs && leaf.isTensor()) { const auto& inputName = std::string(value->name()); auto it = tensorValuesMeta.find(inputName); - CHECK(it != tensorValuesMeta.end()) - << "Couldn't find " << inputName << " in tensorValuesMeta"; + TORCH_CHECK( + it != tensorValuesMeta.end(), + "Couldn't find ", + inputName, + " in tensorValuesMeta"); validateInput(inputName, leaf.toTensor(), it->second); } executionFrame->setBorrowedIValue( @@ -357,8 +373,8 @@ ProfileMetrics Executor::benchmarkIndividualNodes( const std::vector>& inputsList, const uint32_t warmupRuns, const uint32_t mainRuns) { - CHECK(!inputsList.empty()) << "Need at least one input to benchmark"; - CHECK(warmupRuns >= 1 && mainRuns >= 1) << "Need at least one run"; + TORCH_CHECK(!inputsList.empty(), "Need at least one input to benchmark"); + TORCH_CHECK(warmupRuns >= 1 && mainRuns >= 1, "Need at least one run"); for (const auto& inputs : inputsList) { if (executorConfig_.validateInputs) { diff --git a/torch/nativert/executor/GraphExecutorBase.cpp b/torch/nativert/executor/GraphExecutorBase.cpp index 7796575aad291..1c85e27253169 100644 --- a/torch/nativert/executor/GraphExecutorBase.cpp +++ b/torch/nativert/executor/GraphExecutorBase.cpp @@ -20,7 +20,7 @@ void GraphExecutorBase::fillUserInputs( std::vector inputs) { RECORD_USER_SCOPE("Executor::fillUserInputs"); const auto& inputValues = graph_.userInputs(); - TORCH_CHECK_EQ(inputValues.size(), inputs.size()); + TORCH_CHECK(inputValues.size() == inputs.size()); // load user input tensor into execution frame for (size_t i = 0; i < inputValues.size(); i++) { @@ -78,7 +78,7 @@ ProfileMetrics GraphExecutorBase::benchmarkIndividualNodes( for (auto inputs : inputsList) { const auto& inputValues = graph_.userInputs(); - TORCH_CHECK_EQ(inputValues.size(), inputs.size()); + TORCH_CHECK(inputValues.size() == inputs.size()); for (size_t j = 0; j < inputValues.size(); j++) { executionFrame.setIValue(inputValues[j]->id(), std::move(inputs[j])); } diff --git a/torch/nativert/executor/memory/AliasAnalyzer.cpp b/torch/nativert/executor/memory/AliasAnalyzer.cpp index 0bef32545d14b..86de7bc3d6fb6 100644 --- a/torch/nativert/executor/memory/AliasAnalyzer.cpp +++ b/torch/nativert/executor/memory/AliasAnalyzer.cpp @@ -66,13 +66,13 @@ bool /* applied */ AliasAnalyzer::update_aliases_if_packed_listunpack( } const auto& list_elems = list->getListElements(); - TORCH_CHECK_EQ(list_elems.size(), node.numOutputs()); + TORCH_CHECK(list_elems.size() == node.numOutputs()); for (const auto j : c10::irange(node.numOutputs())) { const Value* input = list_elems.at(j); const Value* output = node.outputs().at(j); - TORCH_CHECK_NE(input, output); + TORCH_CHECK(input != output); create_or_update_lifetime(input, i); create_or_update_lifetime(output, i); diff --git a/torch/nativert/executor/memory/AliasAnalyzer.h b/torch/nativert/executor/memory/AliasAnalyzer.h index 4fd3b1261b3d7..4b0d827453b0f 100644 --- a/torch/nativert/executor/memory/AliasAnalyzer.h +++ b/torch/nativert/executor/memory/AliasAnalyzer.h @@ -42,7 +42,7 @@ class AliasAnalyzer { } const std::vector& alive_values_at_time(size_t time) const { - TORCH_CHECK_LT(time, alive_values_at_time_.size()); + TORCH_CHECK(time < alive_values_at_time_.size()); return alive_values_at_time_[time]; } diff --git a/torch/nativert/executor/memory/LayoutManager.cpp b/torch/nativert/executor/memory/LayoutManager.cpp index a75070095caf7..827e8cd057817 100644 --- a/torch/nativert/executor/memory/LayoutManager.cpp +++ b/torch/nativert/executor/memory/LayoutManager.cpp @@ -140,8 +140,8 @@ void LayoutManager::ensure_managed_storages(bool allocate) { } void LayoutManager::populate_tensor_values() { - CHECK(planned_tensors_.empty()); - CHECK(unplanned_ivalues_.empty()); + TORCH_CHECK(planned_tensors_.empty()); + TORCH_CHECK(unplanned_ivalues_.empty()); const auto& value_ids = planner_.get_planned_values(); planned_tensors_.resize(value_ids.size()); @@ -222,8 +222,8 @@ void LayoutManager::assert_no_overlapping_storages( return; } auto& alloc = plan.allocations[value_to_vector_idx_map_.at(value_id)]; - TORCH_CHECK_GE(alloc_start, alloc.offset); - TORCH_CHECK_LT(alloc_end, alloc.offset + alloc.size); + TORCH_CHECK(alloc_start >= alloc.offset); + TORCH_CHECK(alloc_end < alloc.offset + alloc.size); intervals.emplace(alloc_start, alloc_end); }; @@ -254,8 +254,8 @@ void LayoutManager::assert_no_overlapping_storages( // sanity check lifetimes to ensure this // value ~should~ be alive at this point const auto& lt = alias_analyzer.lifetime(v); - TORCH_CHECK_GE(graph_node_idx, lt.start); - TORCH_CHECK_LE(graph_node_idx, lt.end); + TORCH_CHECK(graph_node_idx >= lt.start); + TORCH_CHECK(graph_node_idx <= lt.end); const auto interval = try_get_interval(v->id()); if (C10_UNLIKELY(!interval.has_value())) { @@ -314,7 +314,7 @@ void LayoutManager::assert_no_overlapping_storages( auto it = intervals.begin(); size_t prev_end = it->second; while (++it != intervals.end()) { - TORCH_CHECK_LT(prev_end, it->first /* cur_start */); + TORCH_CHECK(prev_end < it->first /* cur_start */); prev_end = it->second; } } diff --git a/torch/nativert/executor/memory/LayoutManager.h b/torch/nativert/executor/memory/LayoutManager.h index 347c51fe2edec..d98700e7f0215 100644 --- a/torch/nativert/executor/memory/LayoutManager.h +++ b/torch/nativert/executor/memory/LayoutManager.h @@ -40,8 +40,8 @@ struct ContiguousLayoutBuffer { void* get_ptr_with_offset(size_t offset) { void* raw_ptr = data_ptr_.get(); - TORCH_CHECK_NOTNULL(raw_ptr); - TORCH_CHECK_LE(offset, size_); + TORCH_CHECK(raw_ptr != nullptr); + TORCH_CHECK(offset <= size_); return reinterpret_cast( reinterpret_cast(raw_ptr) + offset); } @@ -61,7 +61,7 @@ struct ContiguousLayoutBuffer { void clear(size_t size) { VLOG(1) << "clearing first " << size << "bytes of layout buffer of size " << size_; - TORCH_CHECK_LE(size, size_); + TORCH_CHECK(size <= size_); std::memset(data_ptr_.get(), 0, size); } @@ -126,8 +126,8 @@ struct ContiguousStorageImplBuffer { } c10::StorageImpl& at(size_t i) { - TORCH_CHECK_LT(i, size_) - << "requested storage index " << i << " out of bounds " << size_; + TORCH_CHECK( + i < size_, "requested storage index ", i, " out of bounds ", size_); return buffer_[i]; } @@ -138,7 +138,7 @@ struct ContiguousStorageImplBuffer { } c10::StorageImpl& to_managed(at::StorageImpl& s) { - TORCH_CHECK_LT(size_, capacity_); + TORCH_CHECK(size_ < capacity_); return *(new (&buffer_[size_++]) at::StorageImpl( at::StorageImpl::use_byte_size_t(), static_cast(s.nbytes()), diff --git a/torch/nativert/executor/memory/LayoutPlanner.cpp b/torch/nativert/executor/memory/LayoutPlanner.cpp index 5fb0b8fced6f7..ead887bbe4700 100644 --- a/torch/nativert/executor/memory/LayoutPlanner.cpp +++ b/torch/nativert/executor/memory/LayoutPlanner.cpp @@ -133,7 +133,7 @@ LayoutPlanner::LayoutPlanner( } } - TORCH_CHECK_NOTNULL(algorithm_); + TORCH_CHECK(algorithm_ != nullptr, "algorithm can't be null"); initialize_vectors(value_to_allocation_spec); @@ -159,7 +159,9 @@ void LayoutPlanner::initialize_vectors( size_t i = 0; for (auto& [v, spec] : value_to_allocation_spec) { - TORCH_CHECK_LE(spec.lifetime.start, spec.lifetime.end); + TORCH_CHECK( + spec.lifetime.start <= spec.lifetime.end, + "lifetime start must be before lifetime end"); planned_values_[i] = v->id(); planned_values_historical_max_nbytes_[i] = spec.size; diff --git a/torch/nativert/executor/memory/LayoutPlanner.h b/torch/nativert/executor/memory/LayoutPlanner.h index 83a2386c6dacf..10dcf906bef3e 100644 --- a/torch/nativert/executor/memory/LayoutPlanner.h +++ b/torch/nativert/executor/memory/LayoutPlanner.h @@ -73,7 +73,7 @@ class LayoutPlanner { } bool is_managed(ValueId id) { - TORCH_CHECK_LT(static_cast(id), managed_values_.size()); + TORCH_CHECK(static_cast(id) < managed_values_.size()); return managed_values_[id]; } diff --git a/torch/nativert/graph/Graph.cpp b/torch/nativert/graph/Graph.cpp index 3cc7f678fcff0..bce01f278a572 100644 --- a/torch/nativert/graph/Graph.cpp +++ b/torch/nativert/graph/Graph.cpp @@ -568,7 +568,7 @@ void Graph::lint() const { } } for (const auto& node : nodes()) { - TORCH_CHECK_EQ(node.owningGraph(), this); + TORCH_CHECK(node.owningGraph() == this); } // Check that every list type is either produced by a prim.ListPack or // immediately consumed by a prim.ListUnpack. We make use of this invariant @@ -668,7 +668,7 @@ void Graph::applyDevicePlacement(const Placement& placement) { } Node* Graph::nodeAfter(Node* n) { - TORCH_CHECK_EQ(n->owningGraph(), this); + TORCH_CHECK(n->owningGraph() == this); if (n == outputNode_) { return nullptr; } @@ -677,7 +677,7 @@ Node* Graph::nodeAfter(Node* n) { } const Node* Graph::nodeAfter(const Node* n) const { - TORCH_CHECK_EQ(n->owningGraph(), this); + TORCH_CHECK(n->owningGraph() == this); if (n == outputNode_) { return nullptr; } @@ -686,7 +686,7 @@ const Node* Graph::nodeAfter(const Node* n) const { } Node* Graph::nodeBefore(Node* n) { - TORCH_CHECK_EQ(n->owningGraph(), this); + TORCH_CHECK(n->owningGraph() == this); if (n == inputNode_) { return nullptr; } @@ -695,7 +695,7 @@ Node* Graph::nodeBefore(Node* n) { } const Node* Graph::nodeBefore(const Node* n) const { - TORCH_CHECK_EQ(n->owningGraph(), this); + TORCH_CHECK(n->owningGraph() == this); if (n == inputNode_) { return nullptr; } @@ -704,8 +704,7 @@ const Node* Graph::nodeBefore(const Node* n) const { } void Graph::removeNode(Node* n) { - TORCH_CHECK_EQ(n->owningGraph(), this) - << "Node does not belong to this graph!"; + TORCH_CHECK(n->owningGraph() == this, "Node does not belong to this graph!"); for (auto* outputVal : n->outputs()) { TORCH_CHECK( @@ -747,8 +746,7 @@ std::vector Graph::insertGraph( const Graph& subgraph, std::vector inputs, std::unordered_map& valueMap) { - TORCH_CHECK_EQ(subgraph.inputs().size(), inputs.size()) - << "Input size mismatch"; + TORCH_CHECK(subgraph.inputs().size() == inputs.size(), "Input size mismatch"); for (auto i : c10::irange(subgraph.inputs().size())) { valueMap[subgraph.inputs()[i]] = inputs[i]; } @@ -854,7 +852,7 @@ void Node::addOutput() { } Value* Node::addOutput(const Type& type) { - TORCH_CHECK_EQ(type, Type::Kind::None); + TORCH_CHECK(type == Type::Kind::None); Value* v = owningGraph_->addValue(std::nullopt, type, this); outputs_.push_back(v); return v; @@ -893,9 +891,9 @@ std::vector Value::getListElements() const { ret.push_back(tv.value); } } else { - TORCH_CHECK_EQ(users().size(), 1); + TORCH_CHECK(users().size() == 1); const auto listUnpack = users()[0]; - TORCH_CHECK_EQ(listUnpack->target(), "prim.ListUnpack"); + TORCH_CHECK(listUnpack->target() == "prim.ListUnpack"); for (const auto v : listUnpack->outputs()) { ret.push_back(v); } @@ -1070,17 +1068,17 @@ std::ostream& operator<<(std::ostream& out, const Graph& graph) { c10::Device convertDevice(std::string_view symbol) { // Symbol looks like `Device{cuda:1}` const auto typeStart = symbol.find('{') + 1; - TORCH_CHECK_LT(typeStart, symbol.size()); + TORCH_CHECK(typeStart < symbol.size()); const auto typeEnd = symbol.find(':'); - TORCH_CHECK_NE(typeEnd, std::string_view::npos); + TORCH_CHECK(typeEnd != std::string_view::npos); const auto type = symbol.substr(typeStart, typeEnd - typeStart); const auto indexStart = typeEnd + 1; - TORCH_CHECK_LT(indexStart, symbol.size()); + TORCH_CHECK(indexStart < symbol.size()); const auto indexEnd = symbol.find('}'); - TORCH_CHECK_NE(indexEnd, std::string_view::npos); + TORCH_CHECK(indexEnd != std::string_view::npos); const auto index = symbol.substr(indexStart, indexEnd - indexStart); @@ -1099,7 +1097,7 @@ c10::Device convertDevice(std::string_view symbol) { Constant convertAtomicConstant(std::string_view symbol) { if (c10::starts_with(symbol, "\"")) { // chop off the outer quotes and return the string - TORCH_CHECK_GE(symbol.size(), 2); + TORCH_CHECK(symbol.size() >= 2); symbol.remove_prefix(1); symbol.remove_suffix(1); return std::string(symbol); @@ -1178,8 +1176,8 @@ Constant convertListConstant(std::string_view source) { TORCH_CHECK(false, "constant lists only support int, float, bool"); } } else { - TORCH_CHECK_EQ(type.index(), val.index()) - << "lists must have all the same type"; + TORCH_CHECK( + type.index() == val.index(), "lists must have all the same type"); } values.push_back(std::move(val)); if (source.at(curPos) == ']') { @@ -1306,7 +1304,7 @@ bool Parser::nextIf(char expected) { } void Parser::parseGraphInputs() { - TORCH_CHECK_EQ(curPos_, 0); + TORCH_CHECK(curPos_ == 0); expect("graph"); const auto inputs = parseList( '(', ')', [&]() { return parseAtomicSymbol(); }); diff --git a/torch/nativert/graph/GraphPasses.cpp b/torch/nativert/graph/GraphPasses.cpp index 327f32185e910..981a63815db2f 100644 --- a/torch/nativert/graph/GraphPasses.cpp +++ b/torch/nativert/graph/GraphPasses.cpp @@ -101,7 +101,7 @@ std::string selectScalarOverloadName(const Node& node) { "floor_divide_out", "_conj"}; std::vector atoms = c10::split(node.target(), '.'); - TORCH_CHECK_GE(atoms.size(), 3); + TORCH_CHECK(atoms.size() >= 3); std::string ns = std::string{atoms[atoms.size() - 3]}; std::string opName = std::string{atoms[atoms.size() - 2]}; diff --git a/torch/nativert/graph/Serialization.cpp b/torch/nativert/graph/Serialization.cpp index d32e7fe728436..4c45edd1f5751 100644 --- a/torch/nativert/graph/Serialization.cpp +++ b/torch/nativert/graph/Serialization.cpp @@ -422,9 +422,11 @@ std::unique_ptr jsonToSubgraph( } auto it = jsonTensorValue.find(inputName); - CHECK(it != jsonTensorValue.end()) - << "Missing tensor metadata for " << inputName - << "in thriftGraph.tensorValue"; + TORCH_CHECK( + it != jsonTensorValue.end(), + "Missing tensor metadata for ", + inputName, + "in thriftGraph.tensorValue"); weightsTensorMeta[weightName] = it->second; } graph->setWeightsMeta(weightsTensorMeta); diff --git a/torch/nativert/graph/TensorMeta.cpp b/torch/nativert/graph/TensorMeta.cpp index d7d83710a5a35..81625dca116f9 100644 --- a/torch/nativert/graph/TensorMeta.cpp +++ b/torch/nativert/graph/TensorMeta.cpp @@ -106,7 +106,7 @@ TensorMeta::TensorMeta(const torch::_export::TensorMeta& tensorMeta) torch::_export::SymInt::Tag::AS_INT) { storage_offset_ = tensorMeta.get_storage_offset().get_as_int(); } else { - CHECK(false) << "SymInt not supported yet"; + TORCH_CHECK(false, "SymInt not supported yet"); } for (const auto& size : tensorMeta.get_sizes()) { diff --git a/torch/nativert/graph/TensorMeta.h b/torch/nativert/graph/TensorMeta.h index 5b0c90474a097..585383a95b5fd 100644 --- a/torch/nativert/graph/TensorMeta.h +++ b/torch/nativert/graph/TensorMeta.h @@ -25,12 +25,12 @@ class TensorMeta { explicit TensorMeta(const torch::_export::TensorMeta& tensorMeta); c10::IntArrayRef sizes() const { - CHECK(!hasSymbolicShape_) << "TensorMeta has symbolic shape"; + TORCH_CHECK(!hasSymbolicShape_, "TensorMeta has symbolic shape"); return sizes_; } c10::IntArrayRef strides() const { - CHECK(!hasSymbolicShape_) << "TensorMeta has symbolic shape"; + TORCH_CHECK(!hasSymbolicShape_, "TensorMeta has symbolic shape"); return strides_; } @@ -55,7 +55,7 @@ class TensorMeta { } int64_t numel() const { - CHECK(!hasSymbolicShape_) << "TensorMeta has symbolic shape"; + TORCH_CHECK(!hasSymbolicShape_, "TensorMeta has symbolic shape"); return numel_; } diff --git a/torch/nativert/kernels/C10Kernel.cpp b/torch/nativert/kernels/C10Kernel.cpp index 450042e7c92d3..3c207e5708a39 100644 --- a/torch/nativert/kernels/C10Kernel.cpp +++ b/torch/nativert/kernels/C10Kernel.cpp @@ -49,8 +49,10 @@ void C10Kernel::computeInternal(ExecutionFrame& executionFrame) const { // these are named I don't think it will ever happen in practice. We need to // enforce it though. const auto& outputValues = node_->outputs(); - TORCH_CHECK_EQ(outputValues.size(), stack.size()) - << "Output size mismatch for " << node_->toString(); + TORCH_CHECK( + outputValues.size() == stack.size(), + "Output size mismatch for ", + node_->toString()); for (auto&& [i, actualOutput] : c10::enumerate(stack)) { executionFrame.setIValue(outputValues[i]->id(), std::move(actualOutput)); } diff --git a/torch/nativert/kernels/CallTorchBindKernel.cpp b/torch/nativert/kernels/CallTorchBindKernel.cpp index 5e8c9cf6be759..c3643cbce1da5 100644 --- a/torch/nativert/kernels/CallTorchBindKernel.cpp +++ b/torch/nativert/kernels/CallTorchBindKernel.cpp @@ -8,7 +8,7 @@ namespace torch::nativert { CallTorchBindKernel::CallTorchBindKernel(const Node* node) : OpKernel(node) { const Value* customObjValue = node_->inputs()[0].value; - CHECK(customObjValue->type() == Type::Kind::CustomObj); + TORCH_CHECK(customObjValue->type() == Type::Kind::CustomObj); customClassName_ = customObjValue->type().classFqn(); customClassType_ = torch::jit::getCustomClass(customClassName_); @@ -16,16 +16,18 @@ CallTorchBindKernel::CallTorchBindKernel(const Node* node) : OpKernel(node) { // sample schema // torch.ops.higher_order.call_torchbind(arg1_1, 'add_tensor', arg0_1); - CHECK(node->attributes().size() == 1) - << "Expects higher_order.call_torchbind to only have a single attribute, methodName"; + TORCH_CHECK( + node->attributes().size() == 1, + "Expects higher_order.call_torchbind to only have a single attribute, methodName"); const auto& attr = node->attributes()[0]; - CHECK(std::holds_alternative(attr.value)) - << "method should be a string"; + TORCH_CHECK( + std::holds_alternative(attr.value), + "method should be a string"); methodName_ = std::get(attr.value); method_ = customClassType_->findMethod(methodName_); - CHECK(method_ != nullptr) << "method not found: " << methodName_; + TORCH_CHECK(method_ != nullptr, "method not found: ", methodName_); } void CallTorchBindKernel::computeInternal( @@ -42,7 +44,7 @@ void CallTorchBindKernel::computeInternal( // set outputs const auto& outputs = node_->outputs(); - TORCH_CHECK_EQ(outputs.size(), stack.size()); + TORCH_CHECK(outputs.size() == stack.size()); for (auto&& [i, outputValue] : c10::enumerate(stack)) { executionFrame.setIValue(outputs[i]->id(), std::move(outputValue)); } diff --git a/torch/nativert/kernels/HigherOrderKernel.cpp b/torch/nativert/kernels/HigherOrderKernel.cpp index a1f1393c01882..370339c82f820 100644 --- a/torch/nativert/kernels/HigherOrderKernel.cpp +++ b/torch/nativert/kernels/HigherOrderKernel.cpp @@ -11,28 +11,28 @@ HigherOrderKernel::HigherOrderKernel( std::vector> graphExecutors) : OpKernel(node), graphExecutors_(std::move(graphExecutors)) { static constexpr std::string_view prefix = "torch.ops.higher_order."; - CHECK(c10::starts_with(node->target(), prefix)); + TORCH_CHECK(c10::starts_with(node->target(), prefix)); auto opName = node->target().substr(prefix.size()); if (opName == "cond") { opType_ = OpType::COND; // Checking torch.cond schema is as expected: // torch.cond(Tensor predicate, Graph graph1, Graph graph2, Tensor[] args) // -> Tensor[] - TORCH_CHECK_EQ(node_->attributes().size(), 2); - TORCH_CHECK_EQ(node_->inputs().size(), 2); + TORCH_CHECK(node_->attributes().size() == 2); + TORCH_CHECK(node_->inputs().size() == 2); } else if (opName == "while_loop") { opType_ = OpType::WHILE_LOOP; // Checking torch.while_loop schema is as expected: // torch.while_loop(Graph cond, Graph body, Tensor[] args, Tensor[] // additional) -> Tensor[] - TORCH_CHECK_EQ(node_->attributes().size(), 2); - TORCH_CHECK_EQ(node_->inputs().size(), 2); + TORCH_CHECK(node_->attributes().size() == 2); + TORCH_CHECK(node_->inputs().size() == 2); } else if (opName == "run_const_graph") { opType_ = OpType::RUN_CONST_GRAPH; // Checking torch.run_const_graph schema is as expected: // torch.run_const_graph(Graph graph, Tensor[] args) -> Tensor[] - TORCH_CHECK_GE(node_->attributes().size(), 1); - TORCH_CHECK_EQ(node_->inputs().size(), 1); + TORCH_CHECK(!node_->attributes().empty()); + TORCH_CHECK(node_->inputs().size() == 1); } else { throw std::runtime_error( fmt::format("Unknown higher order op: {}", opName)); diff --git a/torch/nativert/kernels/KernelFactory.cpp b/torch/nativert/kernels/KernelFactory.cpp index 0720c28a7b6a5..db055a6cf220c 100644 --- a/torch/nativert/kernels/KernelFactory.cpp +++ b/torch/nativert/kernels/KernelFactory.cpp @@ -215,10 +215,12 @@ ExecutionKernels KernelFactory::initializeNodeKernels( const auto& subgraph = std::get>(attr.value); auto executionKernels = initializeNodeKernels( *subgraph, weights, executorConfig, placement); - CHECK(executionKernels.delegateExecutors.empty()) - << "HigherOrderKernel does not support delegates"; - CHECK(executionKernels.constFoldingExecutions.empty()) - << "HigherOrderKernel does not support const folding"; + TORCH_CHECK( + executionKernels.delegateExecutors.empty(), + "HigherOrderKernel does not support delegates"); + TORCH_CHECK( + executionKernels.constFoldingExecutions.empty(), + "HigherOrderKernel does not support const folding"); if (executorConfig.maxParallelOps > 1) { graphExecutors.emplace_back( std::unique_ptr(new ParallelGraphExecutor( From 651b4a68f2a60d55d266e40776709247ef347d68 Mon Sep 17 00:00:00 2001 From: Yidi Wu Date: Tue, 15 Jul 2025 11:22:51 -0700 Subject: [PATCH 082/457] [hop][dynamo] track run-ahead sym variables in side effects (#158273) Before the PR, for code like this: ``` class Example2(torch.nn.Module): def forward(self, x, trigger, target): return torch.cond( trigger == 1, lambda: x + target, lambda: x * target, (), ) m = Example2() x = torch.randn(2) trigger = 0 target = 2 args = (x, trigger, target) ep = torch.export.export( m, args, dynamic_shapes=(None, Dim.DYNAMIC, Dim.DYNAMIC) ) ``` dynamo will wrap "target" (i.e. a symInt) twice, once when we speculate the first lambda and find target is a symint and decides to wrap it up, creating a new SymNodeVariable and a placeholder input to the top-level graph. The second time happens when we speculate the second lambda. Tensors are de-duplicated by checking tracked side effects to make sure object with the same id (though different sources) is mapped to the same TensorVaraible. For symints, two things are missing: 1. it's not in the _can_lift_attrs_to_input list (the change in builder.py) 2. it's not in the tracked by runahead_side_effects, so when speculate_subgraph finishes, they're discarded (the change in side_effects.py) Note: the auto lifting mechanism for HOPs happens at proxy level when we trace the subgraph, which is after SymNodeVariable are created (they're created when realizing the args and bind them to subgraph). At that time, builder has created two unique SymNodeVariable for the same symint so the auto lifting in hops cannot de-dup them. Differential Revision: [D78298163](https://our.internmc.facebook.com/intern/diff/D78298163) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158273 Approved by: https://github.com/avikchaudhuri, https://github.com/zou3519 --- test/export/test_export.py | 40 +++++++++++++++++++++ torch/_dynamo/side_effects.py | 4 +-- torch/_dynamo/variables/builder.py | 12 ++++++- torch/_dynamo/variables/higher_order_ops.py | 6 ++-- 4 files changed, 57 insertions(+), 5 deletions(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index 2ded21ec87e06..b772667de105e 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -249,6 +249,10 @@ def is_training_ir_test(test_name): ) +def is_training_ir_strict_test(test_name): + return test_name.endswith(TRAINING_IR_DECOMP_STRICT_SUFFIX) + + def is_cpp_runtime_test(test_name): return test_name.endswith(CPP_RUNTIME_STRICT_SUFFIX) or test_name.endswith( CPP_RUNTIME_NONSTRICT_SUFFIX @@ -1583,6 +1587,42 @@ def forward(self): ) self.assertEqual(m(*args), ep.module()(*args)) + @testing.expectedFailureCppSerDes # AssertionError: 0 not in VR[2, int_oo] + @testing.expectedFailureSerDer # AssertionError: 0 not in VR[2, int_oo] + @testing.expectedFailureSerDerNonStrict # AssertionError: 0 not in VR[2, int_oo] + def test_cond_access_identical_symint_closure(self): + class Example2(torch.nn.Module): + def forward(self, x, trigger, target): + return torch.cond( + trigger == 1, + lambda: x + target, + lambda: x * target, + (), + ) + + m = Example2() + x = torch.randn(2) + trigger = 0 + target = 2 + args = (x, trigger, target) + ep = export(m, args, dynamic_shapes=(None, Dim.DYNAMIC, Dim.DYNAMIC)) + if is_training_ir_strict_test(self._testMethodName): + # In strict mode export's result capturing compiler, we create + # 2 new symints when re-fakifying the symint inputs. + # Then in run_decompositions, ep.range_constraints was updated + # where it checks the var_to_range and put the two newly added ones into the range_constraints. + self.assertExpectedInline( + str(tuple(ep.range_constraints.values())), + """(VR[0, int_oo], VR[0, int_oo], VR[-int_oo, int_oo], VR[-int_oo, int_oo])""", + ) + else: + self.assertExpectedInline( + str(tuple(ep.range_constraints.values())), + """(VR[0, int_oo], VR[0, int_oo])""", + ) + + self.assertEqual(m(*args), ep.module()(*args)) + def test_cond_branches_return_same_int(self): class M(torch.nn.Module): def forward(self, x): diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index a109d11e473de..ab7c7561a88c8 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -536,7 +536,7 @@ def track_save_for_backward(self, ctx, args): assert isinstance(ctx, variables.AutogradFunctionContextVariable) self.save_for_backward.append((ctx, args)) - def track_tensor_variables_from_runahead_side_effects(self, other): + def track_runahead_tensor_and_symvar_side_effects(self, other): # In higher order ops we want to keep track of tensors seen in the # speculate_subgraph so that we don't lift them again as a new input in # other speculate_subgraph or in the root tracer. @@ -544,7 +544,7 @@ def track_tensor_variables_from_runahead_side_effects(self, other): other_id = id(other_item) other_variable = other.id_to_variable[other_id] if other_id not in self.id_to_variable and isinstance( - other_variable, variables.TensorVariable + other_variable, (variables.TensorVariable, variables.SymNodeVariable) ): self.track_object_existing(other_item, other_variable) diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 8ae5a4bd6cee7..52f2bef5677a4 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -445,8 +445,18 @@ def __call__(self, value): if vt.source is None: vt.source = self.source + def _is_deduplicable_sym_variable(value, vt): + # Constants like 0, 1, 2, etc. can be unspecialized as SymNodeVariables sometimes, but we + # should NOT track them. If we use a single SymNodeVariable instance to track them + # across multiple uses, then guards created for one usage will incorrectly apply to + # all other usages of that constant, leading to unnecessary recompilations. + return is_torch_sym(value) and isinstance(vt, SymNodeVariable) + if ( - self._can_lift_attrs_to_inputs(vt) + ( + self._can_lift_attrs_to_inputs(vt) + or _is_deduplicable_sym_variable(value, vt) + ) and value not in self.tx.output.side_effects and not is_wrapper_or_member_descriptor(value) ): diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py index 82dd2eb4caea7..7064d63945ebb 100644 --- a/torch/_dynamo/variables/higher_order_ops.py +++ b/torch/_dynamo/variables/higher_order_ops.py @@ -708,7 +708,7 @@ def speculate_subgraph( if restore_side_effects: new_side_effects = tx.output.side_effects.clone() - prev_side_effects.track_tensor_variables_from_runahead_side_effects( + prev_side_effects.track_runahead_tensor_and_symvar_side_effects( new_side_effects ) tx.output.side_effects = prev_side_effects @@ -991,7 +991,9 @@ def call_function( f"{operands.python_type()}", ) operands_seq = operands.unpack_var_sequence(tx) - if not only_consist_of(operands, (TensorVariable, ConstantVariable)): + if not only_consist_of( + operands, (TensorVariable, ConstantVariable, SymNodeVariable) + ): unimplemented( "Expect operands to be a tuple of pytrees that only consists of tensor leaves." ) From 8554c8007ddaa8029e7e01bb1af12f358bf597c2 Mon Sep 17 00:00:00 2001 From: Xuan Zhang Date: Tue, 15 Jul 2025 13:13:49 -0700 Subject: [PATCH 083/457] [PT2][fusion] ban fusions with large accumulated reads (#157563) **Problem:** Fusion can accumulate large amount of reads, which leads to significant increase in peak memory utilization. Imagine we have the following code snippet ``` total = torch.rand(N, N) for _ in range(r): x = torch.rand(N, N) total = total + x ``` The default execution is memory efficient as only two tensors of size N-by-N is in memory at any given time. However, with fusion, the additions are fused into a single operation and the execution becomes something like: ``` x_1 = torch.rand(N, N) x_2 = torch.rand(N, N) ... x_r = torch.rand(N, N) total = x_1 + x_2 + ... + x_r ``` Though this is run-time efficient, in the case of large `N` and/or large `r`, this is not memory efficient. [internal only] see [post](https://fb.workplace.com/groups/1075192433118967/permalink/1703374333634104/) for additional details **Solution:** Our proposed solution is to ban fusions in case where a large amount of reads are accumulated. This is in addition to some existing logics during torch compile. * During lowering (i.e., `ir.py`), the config `realize_acc_reads_threshold`, which is default to be 8, controls _the number of_ buffers can be accumulated for a single operator. However, this is oblivious to the size of the buffers. Hence, we additionally introduce a config `realize_acc_reads_size_threshold` to control _the amount of buffers_ in size that can be accumulated. * During scheduling (i.e., `scheduler.py`), additional fusion will be performed and thus we also need to capture such pattern there. The decisions are implemented under `choices.py`. **Results:** For a small example similar to be one in the test case (but with larger `N` and higher number of loop repeats), the memory snapshot before and after are shown below. Note the snapshot on the right is zoomed out so that the y-axis of the two snapshots match. image Pull Request resolved: https://github.com/pytorch/pytorch/pull/157563 Approved by: https://github.com/jansel, https://github.com/mlazos --- .../pr_time_benchmarks/expected_results.csv | 38 ++++++------- test/inductor/test_inplace_padding.py | 2 + test/inductor/test_memory.py | 53 +++++++++++++++++++ test/inductor/test_online_softmax.py | 8 ++- torch/_inductor/choices.py | 4 ++ torch/_inductor/config.py | 1 + torch/_inductor/graph.py | 21 ++++++++ torch/_inductor/ir.py | 11 ++++ torch/_inductor/memory.py | 13 +---- torch/_inductor/scheduler.py | 29 +++++----- 10 files changed, 131 insertions(+), 49 deletions(-) diff --git a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv index edc9d0f73d161..1b86e02b8afda 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv +++ b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv @@ -1,4 +1,4 @@ -add_loop_eager,compile_time_instruction_count,3017000000,0.015 +add_loop_eager,compile_time_instruction_count,2994000000,0.015 @@ -6,15 +6,15 @@ add_loop_eager_dynamic,compile_time_instruction_count,4352000000,0.025 -add_loop_inductor,compile_time_instruction_count,29490000000,0.015 +add_loop_inductor,compile_time_instruction_count,33260000000,0.015 -add_loop_inductor_dynamic_gpu,compile_time_instruction_count,38760000000,0.025 +add_loop_inductor_dynamic_gpu,compile_time_instruction_count,42900000000,0.025 -add_loop_inductor_gpu,compile_time_instruction_count,26000000000,0.015 +add_loop_inductor_gpu,compile_time_instruction_count,29880000000,0.015 @@ -22,51 +22,51 @@ basic_modules_ListOfLinears_eager,compile_time_instruction_count,947600000,0.015 -basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18490000000,0.015 +basic_modules_ListOfLinears_inductor,compile_time_instruction_count,17940000000,0.015 -basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17020000000,0.015 +basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17210000000,0.015 -basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,10297683772,0.2 +basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,10980000000,0.2 -update_hint_regression,compile_time_instruction_count,1673000000,0.02 +update_hint_regression,compile_time_instruction_count,1688000000,0.02 -sum_floordiv_regression,compile_time_instruction_count,986800000,0.015 +sum_floordiv_regression,compile_time_instruction_count,992700000,0.015 -symint_sum,compile_time_instruction_count,3166000000,0.015 +symint_sum,compile_time_instruction_count,3187000000,0.015 -symint_sum_loop,compile_time_instruction_count,4202000000,0.015 +symint_sum_loop,compile_time_instruction_count,4225000000,0.015 -aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2103000000,0.015 +aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2122000000,0.015 -aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,6004000000,0.015 +aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,6040000000,0.015 -aotdispatcher_partitioner_cpu,compile_time_instruction_count,8783000000,0.015 +aotdispatcher_partitioner_cpu,compile_time_instruction_count,8894000000,0.015 -aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1940000000,0.015 +aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1952000000,0.015 -aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3885000000,0.015 +aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3905000000,0.015 @@ -74,15 +74,15 @@ aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10470000000,0 -mm_loop_inductor_gpu,compile_time_instruction_count,4324000000,0.015 +mm_loop_inductor_gpu,compile_time_instruction_count,4406000000,0.015 -mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,8116000000,0.015 +mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,8274000000,0.015 -basic_NestedModule_eager,compile_time_instruction_count,8152524390,0.015 +basic_NestedModule_eager,compile_time_instruction_count,8193000000,0.015 diff --git a/test/inductor/test_inplace_padding.py b/test/inductor/test_inplace_padding.py index 80cb86ec417d4..0207134dc7013 100644 --- a/test/inductor/test_inplace_padding.py +++ b/test/inductor/test_inplace_padding.py @@ -9,6 +9,7 @@ from torch._inductor.test_case import run_tests, TestCase from torch._inductor.utils import run_and_get_code from torch.testing import FileCheck +from torch.testing._internal.common_utils import serialTest from torch.testing._internal.inductor_utils import ( GPU_TYPE, HAS_GPU, @@ -209,6 +210,7 @@ def f(x, y): self.assertEqual(num_inplace_padding(), 0) + @serialTest() @requires_cuda_with_enough_memory(2e10) @inductor_config.patch(force_shape_pad=True) def test_linear_and_cel(self): diff --git a/test/inductor/test_memory.py b/test/inductor/test_memory.py index eaff539f7a493..3e23442b38ec7 100644 --- a/test/inductor/test_memory.py +++ b/test/inductor/test_memory.py @@ -8,6 +8,7 @@ from torch._inductor import config, memory from torch._inductor.test_case import TestCase from torch._inductor.utils import run_and_get_triton_code +from torch.testing._internal.common_utils import serialTest from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU @@ -306,6 +307,58 @@ def f(a, b, c): expected_bound = a.size(0) * c.size(1) * a.dtype.itemsize * 2 self.assertLess(peak_mem, expected_bound) + @serialTest() + def test_fusion_acc_large_reads(self): + def f(x, y, z): + res = torch.zeros_like(x[0]) + for i in range(4): + temp = torch.matmul(x, y) + z + res = res + temp + return res + + N = 128 + x = torch.rand(N, N, dtype=torch.float32, device=GPU_TYPE) + y = torch.rand(N, N, dtype=torch.float32, device=GPU_TYPE) + z = torch.rand(N, N, dtype=torch.float32, device=GPU_TYPE) + + # CASE 1: no restriction on the amount of accumulation + with config.patch({"realize_acc_reads_size_threshold": float("inf")}): + f_compiled = torch.compile(f) + code = run_and_get_triton_code(f_compiled, x, y, z) + ( + FileCheck() + .check("triton_poi_fused_add_0.run(buf4, arg2_1, buf1, buf2, buf3") + .run(code) + ) + + # CASE 2: for tensors with the same size as x (which is 4 * N**2 bytes) + # at most 12 / 4 = 3 reads can be accumulated during fusion + with config.patch({"realize_acc_reads_size_threshold": 12 * N**2}): + f_compiled = torch.compile(f) + code = run_and_get_triton_code(f_compiled, x, y, z) + ( + FileCheck() + .check("triton_poi_fused_add_0.run(buf3, arg2_1, buf1, buf2,") + .check("triton_poi_fused_add_1.run(buf5, buf4, arg2_1,") + .run(code) + ) + + # CASE 3: no such fusion allowed + with config.patch({"realize_acc_reads_size_threshold": N**2}): + f_compiled = torch.compile(f) + code = run_and_get_triton_code(f_compiled, x, y, z) + ( + FileCheck() + .check("triton_poi_fused_add_0.run(buf1, arg2_1,") + .check("triton_poi_fused_add_0.run(buf3, arg2_1,") + .check("triton_poi_fused_add_0.run(buf4, buf3,") + .check("triton_poi_fused_add_0.run(buf6, arg2_1,") + .check("triton_poi_fused_add_0.run(buf7, buf6,") + .check("triton_poi_fused_add_0.run(buf9, arg2_1,") + .check("triton_poi_fused_add_0.run(buf10, buf9,") + .run(code) + ) + if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/test/inductor/test_online_softmax.py b/test/inductor/test_online_softmax.py index 798d86b0dd617..37959c241113f 100644 --- a/test/inductor/test_online_softmax.py +++ b/test/inductor/test_online_softmax.py @@ -13,6 +13,7 @@ instantiate_parametrized_tests, IS_LINUX, parametrize, + serialTest, ) from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA @@ -77,12 +78,17 @@ def f(x): out, source_codes = run_and_get_code(f, x) return source_codes[0] + @serialTest() def test_codegen_3pass_softmax_due_to_disable(self): - with inductor_config.patch(online_softmax=False): + with inductor_config.patch( + online_softmax=False, + realize_acc_reads_size_threshold=float("inf"), + ): wrapper_code = self.get_softmax_wrapper() self.assertEqual(wrapper_code.count("for r0_offset in"), 3) + @serialTest() @parametrize("V", [2048, 50304]) @parametrize("use_log_softmax", [False, True]) def test_codegen_online_softmax(self, use_log_softmax, V): diff --git a/torch/_inductor/choices.py b/torch/_inductor/choices.py index b7bab02da5e4b..9096ba6dd0393 100644 --- a/torch/_inductor/choices.py +++ b/torch/_inductor/choices.py @@ -365,6 +365,10 @@ def can_fuse( WhyNoFuse(node1, node2)("Fusion will increase peak memory") return False + if scheduler.fusion_accumulate_large_reads(node1, node2): + WhyNoFuse(node1, node2)("Fusion accumulate large amount of reads") + return False + return True @staticmethod diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 2e189c102db34..5c7a53683db3b 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -574,6 +574,7 @@ def use_autoheuristic(name: str) -> bool: # Threshold to prevent excessive accumulation of ops in one buffer during lowering realize_acc_reads_threshold = 8 +realize_acc_reads_size_threshold = 3 * (1024**3) # fallback to eager for random/dropout, this is slow but useful for debugging fallback_random = False diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index e2cc101533f28..ac299d5b0c2d0 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -123,6 +123,7 @@ from torch.fx.graph import Graph from .codegen.wrapper import PythonWrapperCodegen + from .dependencies import Dep from .scheduler import BaseSchedulerNode CompiledModule = Union[ModuleType, FileBackedGraphModule] @@ -485,6 +486,9 @@ def __init__( self.bw_donated_idxs = get_donated_idxs() + # Cache for dep size hints to avoid expensive recomputation + self.dep_size_hint_cache: dict[Dep, int] = {} + def freeze_runtime_asserts(self) -> None: self._shape_env.freeze_runtime_asserts() @@ -570,6 +574,23 @@ def has_feature( assert isinstance(feature, BackendFeature), feature return feature in self.get_backend_features(get_device_type(device)) + def get_dep_size_hint(self, dep: Dep) -> int: + """ + Get the size hint for a dependency with caching to avoid expensive recomputation. + """ + if dep not in self.dep_size_hint_cache: + res = 0 + try: + if not dep.has_unbacked_symbols(): + res = dep.numbytes_hint() + except KeyError: + # In at least one test (test/inductor/test_torchbind.py) we + # create a StarDep that doesn't exist in the graph and calling + # `has_unbacked_symbols()` throws an error. + pass + self.dep_size_hint_cache[dep] = res + return self.dep_size_hint_cache[dep] + def get_current_device_or_throw(self) -> torch.device: if device := self.current_device: return device diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 1edbb214ae2ad..d6dd82aa52f2d 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -7829,6 +7829,10 @@ def create(data: IRNode) -> Union[TensorBox, ShapeAsConstantBuffer]: class StorageBox(MutableBox): + """ + StorageBox allow in-place mutation of Tensors + """ + def is_input_buffer(self) -> bool: if isinstance(self.data, (InputBuffer, ReinterpretView)): return self.data.get_name() in V.graph.graph_inputs @@ -7878,10 +7882,17 @@ def realize_hint(self) -> None: ): self.realize() + def has_accumulated_enough_reads_by_size(self) -> bool: + return ( + sum(V.graph.get_dep_size_hint(dep) for dep in self.get_reads()) + > config.realize_acc_reads_size_threshold + ) + def has_exceeded_max_reads(self) -> bool: return isinstance(self.data, Pointwise) and ( self.num_reads() > config.realize_acc_reads_threshold or self.has_large_inner_fn() + or self.has_accumulated_enough_reads_by_size() ) def should_realize_on_reuse(self, users: int) -> bool: diff --git a/torch/_inductor/memory.py b/torch/_inductor/memory.py index 5601bc4adcee4..d287208419a9f 100644 --- a/torch/_inductor/memory.py +++ b/torch/_inductor/memory.py @@ -78,19 +78,8 @@ def get_freeable_input_buf( A dictionary containing all freeble input buffers, keyed by their names. """ - # this function is copied from torch/_inductor/scheduler.py - # TODO: would be nice to remove the try/except block for both places def _dep_size_hint(dep: Dep) -> int: - res = 0 - try: - if not dep.has_unbacked_symbols(): - res = dep.numbytes_hint() - except KeyError: - # In at least one test (test/inductor/test_torchbind.py) we - # create a StarDep that doesn't exist in the graph and calling - # `has_unbacked_symbols()` throws an error. - pass - return res + return V.graph.get_dep_size_hint(dep) # get freeable input buffers' successor nodes and their sizes # note that different deps can have the same name, so we use name as keys diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 5c7a16d25bc64..34f15869085f0 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -2051,15 +2051,12 @@ class Scheduler: optimizations such as fusion, reorder, and graph partition. """ - __dep_size_hint_cache: dict[Dep, int] - def __init__(self, nodes: list[ir.Operation]) -> None: with dynamo_timed("Scheduler.__init__"): self._init(nodes) def _init(self, nodes: list[ir.Operation]) -> None: super().__init__() - self.__dep_size_hint_cache = {} V.graph.scheduler = self self.backends: dict[torch.device, BaseScheduling] = {} self.post_grad_graph_id = next(_post_grad_graph_counter) @@ -3505,6 +3502,17 @@ def _find_single_user_inputs( return True return False + def fusion_accumulate_large_reads( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + all_reads = (node1.read_writes.reads | node2.read_writes.reads) - ( + node1.read_writes.writes | node2.read_writes.writes + ) + return ( + sum(self.dep_size_hint(dep) for dep in all_reads) + > config.realize_acc_reads_size_threshold + ) + def are_long_distant_nodes( self, node1: BaseSchedulerNode, node2: BaseSchedulerNode ) -> bool: @@ -4010,20 +4018,7 @@ def fusable_read_and_write(self, read: Dep, write: MemoryDep) -> bool: return False def dep_size_hint(self, dep: Dep) -> int: - res = 0 - if dep not in self.__dep_size_hint_cache: - try: - if not dep.has_unbacked_symbols(): - res = dep.numbytes_hint() - except KeyError: - # In at least one test (test/inductor/test_torchbind.py) we - # create a StarDep that doesn't exist in the graph and calling - # `has_unbacked_symbols()` throws an error. - pass - self.__dep_size_hint_cache[dep] = res - else: - res = self.__dep_size_hint_cache[dep] - return res + return V.graph.get_dep_size_hint(dep) def score_fusion_memory( self, node1: BaseSchedulerNode, node2: BaseSchedulerNode From 03852ddc22350eb8b6ed6b61777639ce6080f3dc Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 16 Jul 2025 01:28:46 +0000 Subject: [PATCH 084/457] Revert "[ROCm] logsumexp on ROCm needs scaling back to natural base. (#156903)" This reverts commit 1ea9cde598ead20194dbb6c5cb26e74e36e6ad55. Reverted https://github.com/pytorch/pytorch/pull/156903 on behalf of https://github.com/atalman due to Breaks torchao and torchtitan nightly builds ([comment](https://github.com/pytorch/pytorch/pull/156903#issuecomment-3076423488)) --- .../tensor/experimental/_attention.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/torch/distributed/tensor/experimental/_attention.py b/torch/distributed/tensor/experimental/_attention.py index b3a5768f6fc8b..73b53f051421d 100644 --- a/torch/distributed/tensor/experimental/_attention.py +++ b/torch/distributed/tensor/experimental/_attention.py @@ -43,16 +43,6 @@ class _RotateMethod(Enum): aten = torch.ops.aten logger = logging.getLogger(__name__) -_is_hip: bool = hasattr(torch.version, "hip") and torch.version.hip is not None -if _is_hip: - gcn_arch_name = torch.cuda.get_device_properties("cuda").gcnArchName - _is_ck_supported = False - for arch in ["gfx942", "gfx950"]: - if arch in gcn_arch_name: - _is_ck_supported = True - _preferred_rocm_fa_library = torch.backends.cuda.preferred_rocm_fa_library - _CK_BACKEND = torch.backends.cuda._ROCmFABackends["ck"] - class _DispatchMode(Enum): MONKEY_PATCH = auto() @@ -456,14 +446,6 @@ def _templated_ring_attention( is_causal=is_causal_behavior.value, **kwargs, ) - if _is_hip: # See: https://github.com/pytorch/pytorch/issues/156012 - need_scaling = True - # Note: it is possible that CK is seleted but not compiled in the binary. - if _is_ck_supported and _preferred_rocm_fa_library() == _CK_BACKEND: - # Unsure about CK's behavior, keep logsumexp untouched - need_scaling = False - if need_scaling: - logsumexp *= 0.6931471805599453 sdpa_merger.step(out, logsumexp, partial) return *sdpa_merger.results(), *rest From 900fba4c073b121b6c9ce581ea27e25c13a354e5 Mon Sep 17 00:00:00 2001 From: "Jiang, Yanbing" Date: Wed, 16 Jul 2025 01:28:46 +0000 Subject: [PATCH 085/457] Update warning of TF32 (#158209) Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/158209 Approved by: https://github.com/jansel --- aten/src/ATen/Context.cpp | 4 +- docs/source/backends.md | 4 +- docs/source/notes/cuda.rst | 81 +++++++++++++----------- docs/source/notes/mkldnn.rst | 2 +- docs/source/notes/numerical_accuracy.rst | 4 +- 5 files changed, 51 insertions(+), 44 deletions(-) diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp index cac0e31eaad46..08a834e0a8d4a 100644 --- a/aten/src/ATen/Context.cpp +++ b/aten/src/ATen/Context.cpp @@ -76,7 +76,9 @@ void check_fp32_prec_backend_and_op( C10_ALWAYS_INLINE void warn_deprecated_fp32_precision_api(){ TORCH_WARN_ONCE( - "This API is going to be deprecated, please see " + "Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' " + "or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, " + "torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see " "https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices" ); } diff --git a/docs/source/backends.md b/docs/source/backends.md index 41869ba9b77b5..6b8cc8bd70724 100644 --- a/docs/source/backends.md +++ b/docs/source/backends.md @@ -54,7 +54,7 @@ These backends include: .. attribute:: allow_tf32 A :class:`bool` that controls whether TensorFloat-32 tensor cores may be used in matrix - multiplications on Ampere or newer GPUs. See :ref:`tf32_on_ampere`. + multiplications on Ampere or newer GPUs. allow_tf32 is going to be deprecated. See :ref:`tf32_on_ampere`. ``` ```{eval-rst} @@ -193,7 +193,7 @@ These backends include: .. attribute:: allow_tf32 A :class:`bool` that controls where TensorFloat-32 tensor cores may be used in cuDNN - convolutions on Ampere or newer GPUs. See :ref:`tf32_on_ampere`. + convolutions on Ampere or newer GPUs. allow_tf32 is going to be deprecated. See :ref:`tf32_on_ampere`. ``` ```{eval-rst} diff --git a/docs/source/notes/cuda.rst b/docs/source/notes/cuda.rst index 98e1d8141dd95..5210eb4ad1495 100644 --- a/docs/source/notes/cuda.rst +++ b/docs/source/notes/cuda.rst @@ -64,6 +64,49 @@ Below you can find a small example showcasing this:: TensorFloat-32 (TF32) on Ampere (and later) devices --------------------------------------------------- +After Pytorch 2.9, we provide a new sets of APIs to control the TF32 behavior in a more fine-grained way, and +suggest to use the new APIs for better control. +We can set float32 precision per backend and per operators. We can also override the global setting for a specific operator. + +.. code:: python + + torch.backends.fp32_precision = "ieee" + torch.backends.cuda.matmul.fp32_precision = "ieee" + torch.backends.cudnn.fp32_precision = "ieee" + torch.backends.cudnn.conv.fp32_precision = "tf32" + torch.backends.cudnn.rnn.fp32_precision = "tf32" + +The fp32_precision can be set to `ieee` or `tf32` for `cuda/cudnn`. +`ieee` fp32_precision indicate that we will use `FP32` as internal computation precision. +`tf32` fp32_precision indicate that we will allow to use `TF32` as internal computation precision. + +We can override a generic setting for a specific operator if the fp32_precision is set to `ieee`. + +.. code:: python + + torch.backends.cudnn.fp32_precision = "tf32" + torch.backends.cudnn.conv.fp32_precision = "ieee" + torch.backends.cudnn.rnn.fp32_precision = "ieee" + +We can also override a generic setting for a specific backend if the fp32_precision is set to `ieee`. + +.. code:: python + + torch.backends.fp32_precision = "tf32" + torch.backends.cudnn.fp32_precision = "ieee" + torch.backends.cudnn.conv.fp32_precision = "ieee" + torch.backends.cudnn.rnn.fp32_precision = "ieee" + +For above 2 cases, both `torch.backends.cudnn.conv.fp32_precision` and `torch.backends.cudnn.rnn.fp32_precision` +is overridden to `ieee`. + +We suggest to use the new settings for better control. And we do not support to use mix of old and new settings. + +.. warning:: + + Old settings with `allow_tf32` as follows is going to be deprecated. We suggest to use the above new settings for + better control. And we do not support to use mix of old and new settings. + Starting in PyTorch 1.7, there is a new flag called `allow_tf32`. This flag defaults to True in PyTorch 1.7 to PyTorch 1.11, and False in PyTorch 1.12 and later. This flag controls whether PyTorch is allowed to use the TensorFloat32 (TF32) tensor cores, @@ -133,44 +176,6 @@ To toggle the TF32 flags off in C++, you can do at::globalContext().setAllowTF32CuBLAS(false); at::globalContext().setAllowTF32CuDNN(false); -After Pytorch 2.7, we provide a new sets of APIs to control the TF32 behavior in a more fine-grained way. -We can set float32 precision per backend and per operators. We can also override the global setting for a specific operator. - -.. code:: python - - torch.backends.fp32_precision = "ieee" - torch.backends.cuda.matmul.fp32_precision = "ieee" - torch.backends.cudnn.fp32_precision = "ieee" - torch.backends.cudnn.conv.fp32_precision = "tf32" - torch.backends.cudnn.rnn.fp32_precision = "tf32" - -The fp32_precision can be set to `ieee` or `tf32` for `cuda/cudnn`. -`ieee` fp32_precision indicate that we will use `FP32` as internal computation precision. -`tf32` fp32_precision indicate that we will allow to use `TF32` as internal computation precision. - -We can override a generic setting for a specific operator if the fp32_precision is set to `ieee`. - -.. code:: python - - torch.backends.cudnn.fp32_precision = "tf32" - torch.backends.cudnn.conv.fp32_precision = "ieee" - torch.backends.cudnn.rnn.fp32_precision = "ieee" - -We can also override a generic setting for a specific backend if the fp32_precision is set to `ieee`. - -.. code:: python - - torch.backends.fp32_precision = "tf32" - torch.backends.cudnn.fp32_precision = "ieee" - torch.backends.cudnn.conv.fp32_precision = "ieee" - torch.backends.cudnn.rnn.fp32_precision = "ieee" - -For above 2 cases, both `torch.backends.cudnn.conv.fp32_precision` and `torch.backends.cudnn.rnn.fp32_precision` -is overridden to `ieee`. - -Old settings are still supported. But we suggest to use the new settings for better control. And we do not support -to use mix of old and new settings. - For more information about TF32, see: - `TensorFloat-32`_ diff --git a/docs/source/notes/mkldnn.rst b/docs/source/notes/mkldnn.rst index 48ee9ce84c35a..366c2f99cd6f2 100644 --- a/docs/source/notes/mkldnn.rst +++ b/docs/source/notes/mkldnn.rst @@ -26,7 +26,7 @@ Users can disable MKLDNN backend by: Bfloat16 (BF16) on MKLDNN backend --------------------------------------------------- -Starting in PyTorch 2.4, there is a set of APIs to control the internal computation precision +Starting in PyTorch 2.9, there is a set of APIs to control the internal computation precision for `float32` operators. .. code:: python diff --git a/docs/source/notes/numerical_accuracy.rst b/docs/source/notes/numerical_accuracy.rst index 2e081a08442d9..8944ecc05f277 100644 --- a/docs/source/notes/numerical_accuracy.rst +++ b/docs/source/notes/numerical_accuracy.rst @@ -93,8 +93,8 @@ On Ampere (and later) Nvidia GPUs, PyTorch can use TensorFloat32 (TF32) to speed When an operation is performed using TF32 tensor cores, only the first 10 bits of the input mantissa are read. This may reduce accuracy and produce surprising results (e.g., multiplying a matrix by the identity matrix may produce results that are different from the input). By default, TF32 tensor cores are disabled for matrix multiplications and enabled for convolutions, although most neural network workloads have the same convergence behavior when using TF32 as they have with fp32. -We recommend enabling TF32 tensor cores for matrix multiplications with ``torch.backends.cuda.matmul.allow_tf32 = True`` if your network does not need full float32 precision. -If your network needs full float32 precision for both matrix multiplications and convolutions, then TF32 tensor cores can also be disabled for convolutions with ``torch.backends.cudnn.allow_tf32 = False``. +We recommend enabling TF32 tensor cores for matrix multiplications with ``torch.backends.cuda.matmul.fp32_precision = "tf32"`` (```torch.backends.cuda.matmul.allow_tf32 = True`` is going to be deprecated) if your network does not need full float32 precision. +If your network needs full float32 precision for both matrix multiplications and convolutions, then TF32 tensor cores can also be disabled for convolutions with ``torch.backends.cudnn.conv.fp32_precision = "ieee"`` (``torch.backends.cudnn.allow_tf32 = False`` is going to be deprecated). For more information see :ref:`TensorFloat32`. From 9768d393fa62df8a508136f5b6634bf955d8365d Mon Sep 17 00:00:00 2001 From: yuchengliu1 Date: Wed, 16 Jul 2025 01:52:05 +0000 Subject: [PATCH 086/457] add sfdp pattern (#155792) add sfdp pattern for MBartForCausalLM/PLBartForCausalLM in transformers==4.44.2. Improve the inference performance of these model. Pull Request resolved: https://github.com/pytorch/pytorch/pull/155792 Approved by: https://github.com/Valentine233, https://github.com/jansel --- test/inductor/test_fused_attention.py | 44 +++++ torch/_inductor/fx_passes/fuse_attention.py | 44 ++++- .../serialized_patterns/_sfdp_pattern_24.py | 153 ++++++++++++++++++ 3 files changed, 240 insertions(+), 1 deletion(-) create mode 100644 torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_24.py diff --git a/test/inductor/test_fused_attention.py b/test/inductor/test_fused_attention.py index 4d52775ccbade..9015332f4e15d 100644 --- a/test/inductor/test_fused_attention.py +++ b/test/inductor/test_fused_attention.py @@ -1065,6 +1065,44 @@ def dot_prod_attention( check_train=False, ) + def _test_sdpa_rewriter_24(self): + def dot_prod_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: torch.Tensor, + ) -> torch.Tensor: + """Input tensors assumed to have shape (batch_size, n_head, seq_len, embed_dim)""" + bs = query.size(0) + n_head = query.size(1) + seq_len = query.size(2) + embed_dim = query.size(3) + q = query.view(bs * n_head, seq_len, embed_dim) + k = key.reshape(bs * n_head, seq_len, embed_dim) + v = value.reshape(bs * n_head, seq_len, embed_dim) + attn_weights = torch.bmm(q, k.transpose(1, 2)) + attn_weights = attn_weights.view(bs, n_head, seq_len, seq_len) + attn_mask + attn_weights = attn_weights.view(bs * n_head, seq_len, seq_len) + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) + attn_output = torch.bmm(attn_weights, v) + attn_output = attn_output.view(bs, n_head, seq_len, embed_dim) + return attn_output + + tensor_shape = (4, 2, 16, 32) + attn_mask = torch.randn((1, 1, 16, 16), dtype=torch.float, device=self.device) + args = [ + torch.randn(tensor_shape, device=self.device, dtype=torch.float), + torch.randn(tensor_shape, device=self.device, dtype=torch.float), + torch.randn(tensor_shape, device=self.device, dtype=torch.float), + attn_mask, + ] + self._check_common( + dot_prod_attention, + args1=args, + has_dropout=False, + check_train=False, + ) + if HAS_XPU or (HAS_CUDA and PLATFORM_SUPPORTS_FUSED_ATTENTION): @@ -1133,6 +1171,9 @@ class SDPAPatternRewriterGpuTests(TestSDPAPatternRewriterTemplate): test_sdpa_rewriter_23_gpu = functools.partialmethod( TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_23 ) + test_sdpa_rewriter_24_gpu = functools.partialmethod( + TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_24 + ) class SDPAPatternRewriterGpuDynamicTests(SDPAPatternRewriterGpuTests): use_static_shapes = False @@ -1199,6 +1240,9 @@ class SDPAPatternRewriterCpuTests(TestSDPAPatternRewriterTemplate): test_sdpa_rewriter_23_cpu = functools.partialmethod( TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_23 ) + test_sdpa_rewriter_24_cpu = functools.partialmethod( + TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_24 + ) class SDPAPatternRewriterCpuDynamicTests(SDPAPatternRewriterCpuTests): use_static_shapes = False diff --git a/torch/_inductor/fx_passes/fuse_attention.py b/torch/_inductor/fx_passes/fuse_attention.py index 4ed950afe9a18..3e8bd56b32140 100644 --- a/torch/_inductor/fx_passes/fuse_attention.py +++ b/torch/_inductor/fx_passes/fuse_attention.py @@ -18,7 +18,6 @@ log = logging.getLogger(__name__) aten = torch.ops.aten - _scaled_dot_product_attention = aten.scaled_dot_product_attention @@ -582,6 +581,42 @@ def _sfdp_replacement_20(query, key, value, attn_mask, dropout_p): ) +def _sfdp_pattern_24(query, key, value, attention_mask): + """ + this pattern is for MBartForCausalLM/PLBartForCausalLM. + attn_mask has a differnt dtype with QKV. + there is no scale in sdpa. + """ + bs = query.size(0) + n_head = query.size(1) + seq_len = query.size(2) + head_size = query.size(3) + q = query.view(bs * n_head, -1, head_size) + k = key.reshape(bs * n_head, -1, head_size) + v = value.reshape(bs * n_head, -1, head_size) + attn_weights = torch.bmm(q, k.transpose(1, 2)) + attn_weights = attn_weights.view(bs, n_head, seq_len, -1) + attention_mask + attn_weights = attn_weights.view(bs * n_head, seq_len, -1) + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) + if query.dtype == torch.half: + attn_weights = attn_weights.to(torch.half) + attn_output = torch.bmm(attn_weights, v) + attn_output = attn_output.view(bs, n_head, seq_len, head_size) + return attn_output + + +def _sfdp_replacement_24(query, key, value, attention_mask): + counters["inductor"]["fuse_attention"] += 1 + return _scaled_dot_product_attention( + query, + key, + value, + attn_mask=attention_mask.to(dtype=query.dtype), + is_causal=False, + scale=1, + ) + + def _sfdp_pattern_21(query, key, value, attn_mask): # for T5 with inplace add query = query.permute([0, 2, 1, 3]) @@ -1003,6 +1038,13 @@ def _get_sfdp_patterns(): {}, _sfdp_params_check, ), + ( + _sfdp_pattern_24, + _sfdp_replacement_24, + [g(), g(), g(), b_float()], + {}, + _sfdp_extra_check, + ), ] mask_fp32_patterns = ["pattern_16"] if dtype == torch.half: diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_24.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_24.py new file mode 100644 index 0000000000000..72f23373c143e --- /dev/null +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_24.py @@ -0,0 +1,153 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +view_default = CallFunction(aten.view.default, KeywordArg('query'), Ignored(), _users=2) +view_default_1 = CallFunction(aten.view.default, KeywordArg('key'), Ignored()) +permute_default = CallFunction(aten.permute.default, view_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, permute_default) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attention_mask')) +view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_3, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_3, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=4) +view_default_4 = CallFunction(aten.view.default, KeywordArg('value'), Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, div_Tensor, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +mul_Tensor = CallFunction(aten.mul.Tensor, bmm_default_2, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +view_default_7 = CallFunction(aten.view.default, fma_default, Ignored()) +view_default_8 = CallFunction(aten.view.default, view_default_7, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +permute_default_4 = CallFunction(aten.permute.default, bmm_default_4, Ignored()) +view_default_10 = CallFunction(aten.view.default, permute_default_4, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, div_Tensor, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_24_training = MultiOutputPattern([view_default_5, + view_default_9, + view_default_10, + view_default_11, + None +]) + + +view_default = CallFunction(aten.view.default, KeywordArg('query'), Ignored()) +view_default_1 = CallFunction(aten.view.default, KeywordArg('key'), Ignored()) +permute_default = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, permute_default) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attention_mask')) +view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_3, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_3, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +view_default_4 = CallFunction(aten.view.default, KeywordArg('value'), Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, div_Tensor, view_default_4) +_sfdp_pattern_24_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +view_default = CallFunction(aten.view.default, KeywordArg('query'), Ignored(), _users=2) +view_default_1 = CallFunction(aten.view.default, KeywordArg('key'), Ignored()) +permute_default = CallFunction(aten.permute.default, view_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, permute_default) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attention_mask')) +view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_3, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_3, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +view_default_4 = CallFunction(aten.view.default, KeywordArg('value'), Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, convert_element_type_default, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_1, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +view_default_7 = CallFunction(aten.view.default, fma_default, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +view_default_8 = CallFunction(aten.view.default, convert_element_type_default_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +permute_default_4 = CallFunction(aten.permute.default, bmm_default_4, Ignored()) +view_default_10 = CallFunction(aten.view.default, permute_default_4, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, convert_element_type_default, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_24_half_training = MultiOutputPattern([view_default_5, + view_default_9, + view_default_10, + view_default_11, + None +]) + + +view_default = CallFunction(aten.view.default, KeywordArg('query'), Ignored()) +view_default_1 = CallFunction(aten.view.default, KeywordArg('key'), Ignored()) +permute_default = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, permute_default) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attention_mask')) +view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_3, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_3, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +view_default_4 = CallFunction(aten.view.default, KeywordArg('value'), Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, convert_element_type_default, view_default_4) +_sfdp_pattern_24_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) From 584a0510b30b2472e54197d6b67b6f5f5e8ac807 Mon Sep 17 00:00:00 2001 From: Xu Han Date: Wed, 16 Jul 2025 01:54:31 +0000 Subject: [PATCH 087/457] [inductor] fix windows path for fresh cache. (#158324) `normalize_path_separator` for windows path. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158324 Approved by: https://github.com/jansel --- torch/_inductor/utils.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index d22d67cecff21..3d427fd7dd044 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -1099,13 +1099,17 @@ def fresh_cache( """ clear_caches() - inductor_cache_dir = tempfile.mkdtemp(dir=dir) + from torch._inductor.cpp_builder import normalize_path_separator + + inductor_cache_dir = normalize_path_separator(tempfile.mkdtemp(dir=dir)) try: with mock.patch.dict( os.environ, {"TORCHINDUCTOR_CACHE_DIR": inductor_cache_dir} ): log.debug("Using inductor cache dir %s", inductor_cache_dir) - triton_cache_dir = os.path.join(inductor_cache_dir, "triton") + triton_cache_dir = normalize_path_separator( + os.path.join(inductor_cache_dir, "triton") + ) with mock.patch.dict(os.environ, {"TRITON_CACHE_DIR": triton_cache_dir}): yield if isinstance(cache_entries, dict): From 0cb36e2d62c811fcddea4c6d28b1c65246cdd160 Mon Sep 17 00:00:00 2001 From: Songhao Jia Date: Wed, 16 Jul 2025 02:15:32 +0000 Subject: [PATCH 088/457] cache dict and string rep for better perf (#158372) Summary: NodeSouce should not be updated after created, so that it would be better if we cache its dict and string representation for better perf. Test Plan: ci Rollback Plan: Reviewed By: yushangdi Differential Revision: D78298501 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158372 Approved by: https://github.com/yushangdi --- torch/fx/traceback.py | 31 ++++++++++++++++++++----------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/torch/fx/traceback.py b/torch/fx/traceback.py index 97391d567aba8..e57e89ea8d4b5 100644 --- a/torch/fx/traceback.py +++ b/torch/fx/traceback.py @@ -80,6 +80,10 @@ def __init__( self.node_info = None self.from_node = [] + # cache the action string and dict representation for performance. + self._action_string = None + self._dict = None + @property def name(self) -> str: return self.node_info.name if self.node_info else "" @@ -96,7 +100,9 @@ def __repr__(self): return self.print_readable() def _get_action_string(self): - return "+".join([a.name.lower() for a in self.action]) + if self._action_string is None: + self._action_string = "+".join([a.name.lower() for a in self.action]) + return self._action_string def print_readable(self, indent=0): if indent > 9: @@ -112,16 +118,19 @@ def print_readable(self, indent=0): return result def to_dict(self) -> dict: - # Convert the object to a dictionary - action_string = self._get_action_string() - return { - "name": self.name, - "target": self.target, - "graph_id": self.graph_id, - "pass_name": self.pass_name, - "action": action_string, - "from_node": [node.to_dict() for node in self.from_node], - } + if self._dict is None: + # Convert the object to a dictionary + action_string = self._get_action_string() + self._dict = { + "name": self.name, + "target": self.target, + "graph_id": self.graph_id, + "pass_name": self.pass_name, + "action": action_string, + "from_node": [node.to_dict() for node in self.from_node], + } + + return self._dict def __eq__(self, other: object): if not isinstance(other, NodeSource): From 5b0df2565ebb9677f34ae8eeb8c08fafaeacb15d Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Tue, 15 Jul 2025 12:04:35 -0700 Subject: [PATCH 089/457] Pipeline _create_aot_dispatcher_function (#158173) Two main things of note: - Review this diff without whitespace changes - To ensure that context managers correctly propagate to later pipeline stages, I am using the ExitStack trick: there is an ExitStack which is in scope for the entire pipeline, and inside of the individual pipeline stages we push context managers onto this stack when we want them to survive into the next pipeline stage. This is not obviously what the best final form of the code is, but create_aot_dispatcher_function is called from multiple locations so I can't just inline the context managers into the call site. Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/158173 Approved by: https://github.com/jamesjwu, https://github.com/wconstab ghstack dependencies: #158149, #158150 --- torch/_functorch/aot_autograd.py | 383 ++++++++++++++++--------------- 1 file changed, 202 insertions(+), 181 deletions(-) diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py index 3803333948dfe..94609a7441417 100644 --- a/torch/_functorch/aot_autograd.py +++ b/torch/_functorch/aot_autograd.py @@ -572,19 +572,29 @@ def create_aot_dispatcher_function( fake_mode: FakeTensorMode, shape_env: Optional[ShapeEnv], ) -> tuple[Callable, ViewAndMutationMeta]: - with dynamo_timed("create_aot_dispatcher_function", log_pt2_compile_event=True): - return _create_aot_dispatcher_function( - flat_fn, fake_flat_args, aot_config, fake_mode, shape_env + with contextlib.ExitStack() as stack: + compiler_fn, flat_fn, dup_fake_flat_args, aot_config, fw_metadata = ( + _create_aot_dispatcher_function( + stack, flat_fn, fake_flat_args, aot_config, fake_mode, shape_env + ) + ) + compiled_fn, fw_metadata = compiler_fn( + flat_fn, + dup_fake_flat_args, + aot_config, + fw_metadata=fw_metadata, ) + return compiled_fn, fw_metadata def _create_aot_dispatcher_function( + stack, flat_fn, fake_flat_args: FakifiedFlatArgs, aot_config: AOTConfig, fake_mode: FakeTensorMode, shape_env: Optional[ShapeEnv], -) -> tuple[Callable, ViewAndMutationMeta]: +) -> tuple[Callable, Callable, list[Any], AOTConfig, ViewAndMutationMeta]: """ Traces the forward and backward graphs of the attr:`flat_fn` to generate a joint graph. The joint graph is an Fx graph with Aten ops. Please refer to @@ -607,6 +617,10 @@ def _create_aot_dispatcher_function( When aot_config.is_export is False, we return an ordinary runtime function """ + stack.enter_context( + dynamo_timed("create_aot_dispatcher_function", log_pt2_compile_event=True) + ) + # This is the main entry point. # TODO: Chillee argues that dynamo itself should pass in fake tensors to # the list of arguments when compiling; at the moment we do not do this @@ -637,210 +651,207 @@ def _create_aot_dispatcher_function( # If any saved tensor hooks are active, we **don't** want to trace them. # Instead, we'll let them run at runtime, around the custom autograd.Function # that we generate in torch.compile. - with ( - torch.autograd.set_multithreading_enabled(False), - preserve_rng_state(), - fake_mode, - python_dispatcher_mode, - PhiloxStateTracker(), - torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing(), - ): - from torch._library.fake_class_registry import ( - FakeScriptObject, - maybe_to_fake_obj, - ) + stack.enter_context(torch.autograd.set_multithreading_enabled(False)) + stack.enter_context(preserve_rng_state()) + stack.enter_context(fake_mode) + stack.enter_context(python_dispatcher_mode) + stack.enter_context(PhiloxStateTracker()) + stack.enter_context( + torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing() + ) - # Tracing may mutate the states the fake script object, - # so we need to duplicate the fake script objects so that subsequent tracing - # won't be affected. - def _dup_fake_script_obj(fake_flat_args): - return [ - maybe_to_fake_obj(detect_fake_mode(fake_flat_args), arg.real_obj) - if isinstance(arg, FakeScriptObject) - else arg - for arg in fake_flat_args - ] + from torch._library.fake_class_registry import FakeScriptObject, maybe_to_fake_obj + + # Tracing may mutate the states the fake script object, + # so we need to duplicate the fake script objects so that subsequent tracing + # won't be affected. + def _dup_fake_script_obj(fake_flat_args): + return [ + maybe_to_fake_obj(detect_fake_mode(fake_flat_args), arg.real_obj) + if isinstance(arg, FakeScriptObject) + else arg + for arg in fake_flat_args + ] + + needs_autograd = any( + x.requires_grad for x in fake_flat_args if isinstance(x, Tensor) + ) - needs_autograd = any( - x.requires_grad for x in fake_flat_args if isinstance(x, Tensor) - ) + with enable_python_dispatcher(): + # Patch set_rng_state as set_rng_state with fake tensors is + # nonsensical. This does not affect the collection of metadata. + with patch("torch.cuda.set_rng_state", lambda *args: None): + mod = root_module_when_exporting_non_strict(flat_fn) + if mod is not None: + ctx = _detect_attribute_assignment(mod) + else: + ctx = nullcontext() - with enable_python_dispatcher(): - # Patch set_rng_state as set_rng_state with fake tensors is - # nonsensical. This does not affect the collection of metadata. - with patch("torch.cuda.set_rng_state", lambda *args: None): - mod = root_module_when_exporting_non_strict(flat_fn) - if mod is not None: - ctx = _detect_attribute_assignment(mod) - else: - ctx = nullcontext() + if torch._functorch.config.fake_tensor_propagate_real_tensors: + # Running dynamo_timed causes fake tensor issues when + # propagate real tensor is switched on. + dynamo_timed_ctx = nullcontext() + else: + dynamo_timed_ctx = dynamo_timed( + "aot_collect_metadata", log_pt2_compile_event=True + ) - if torch._functorch.config.fake_tensor_propagate_real_tensors: - # Running dynamo_timed causes fake tensor issues when - # propagate real tensor is switched on. - dynamo_timed_ctx = nullcontext() - else: - dynamo_timed_ctx = dynamo_timed( - "aot_collect_metadata", log_pt2_compile_event=True - ) + with dynamo_timed_ctx, ctx: + fw_metadata = run_functionalized_fw_and_collect_metadata( + flat_fn, + static_input_indices=aot_config.static_input_indices, + keep_input_mutations=aot_config.keep_inference_input_mutations, + is_train=needs_autograd, + pre_dispatch=aot_config.pre_dispatch, + is_export=aot_config.is_export, + )(*_dup_fake_script_obj(fake_flat_args)) + + req_subclass_dispatch = requires_subclass_dispatch( + fake_flat_args, fw_metadata + ) + CompileEventLogger.try_add_pt2_compile( + "backend_compile", requires_subclass_dispatch=req_subclass_dispatch + ) + + output_and_mutation_safe = not any( + x.requires_grad + # view-type operations preserve requires_grad even in no_grad. + # Do not count aliases of inputs with requires_grad as reason to make a training graph, + # as AOTAutograd will perform view-replay to regenerate the view outputs at runtime, + # setting their grad_fn properly. + and not ( + x.output_type in (OutputType.alias_of_input, OutputType.is_input) + and fw_metadata.input_info[x.base_idx].requires_grad + ) + for x in fw_metadata.output_info + ) and not any( + x.requires_grad + and x.mutates_data + and not x.mutations_under_no_grad_or_inference_mode + and not x.mutations_hidden_from_autograd + for x in fw_metadata.input_info + ) - with dynamo_timed_ctx, ctx: + if needs_autograd and output_and_mutation_safe: + # We realized that none of the outputs require grad, + # and none of the inputs that require grad are mutated. + # so we actually have an inference graph. + needs_autograd = False + # A bit silly: right now in the subclass codepath, our ViewAndMutationMeta + # changes depending on whether we pass in is_train / keep_input_mutations, + # so we're forced to recompute the metadata. + # TODO: refactor the subclass path of run_functionalized_fw_and_collect_metadata + # so that this is unnecessary. + if req_subclass_dispatch: fw_metadata = run_functionalized_fw_and_collect_metadata( flat_fn, - static_input_indices=aot_config.static_input_indices, keep_input_mutations=aot_config.keep_inference_input_mutations, - is_train=needs_autograd, + is_train=False, pre_dispatch=aot_config.pre_dispatch, - is_export=aot_config.is_export, - )(*_dup_fake_script_obj(fake_flat_args)) - - req_subclass_dispatch = requires_subclass_dispatch( - fake_flat_args, fw_metadata - ) - CompileEventLogger.try_add_pt2_compile( - "backend_compile", requires_subclass_dispatch=req_subclass_dispatch - ) - - output_and_mutation_safe = not any( - x.requires_grad - # view-type operations preserve requires_grad even in no_grad. - # Do not count aliases of inputs with requires_grad as reason to make a training graph, - # as AOTAutograd will perform view-replay to regenerate the view outputs at runtime, - # setting their grad_fn properly. - and not ( - x.output_type - in (OutputType.alias_of_input, OutputType.is_input) - and fw_metadata.input_info[x.base_idx].requires_grad + static_input_indices=aot_config.static_input_indices, + )(*fake_flat_args) + else: + fw_metadata = ViewAndMutationMeta( + input_info=fw_metadata.input_info, + output_info=fw_metadata.output_info, + num_intermediate_bases=fw_metadata.num_intermediate_bases, + keep_input_mutations=aot_config.keep_inference_input_mutations, + traced_tangents=fw_metadata.traced_tangents, + subclass_inp_meta=fw_metadata.subclass_inp_meta, + subclass_fw_graph_out_meta=fw_metadata.subclass_fw_graph_out_meta, + subclass_tangent_meta=fw_metadata.subclass_tangent_meta, + is_train=False, + tokens=fw_metadata.tokens, + static_input_indices=fw_metadata.static_input_indices, ) - for x in fw_metadata.output_info - ) and not any( - x.requires_grad - and x.mutates_data - and not x.mutations_under_no_grad_or_inference_mode - and not x.mutations_hidden_from_autograd - for x in fw_metadata.input_info - ) - if needs_autograd and output_and_mutation_safe: - # We realized that none of the outputs require grad, - # and none of the inputs that require grad are mutated. - # so we actually have an inference graph. - needs_autograd = False - # A bit silly: right now in the subclass codepath, our ViewAndMutationMeta - # changes depending on whether we pass in is_train / keep_input_mutations, - # so we're forced to recompute the metadata. - # TODO: refactor the subclass path of run_functionalized_fw_and_collect_metadata - # so that this is unnecessary. - if req_subclass_dispatch: - fw_metadata = run_functionalized_fw_and_collect_metadata( - flat_fn, - keep_input_mutations=aot_config.keep_inference_input_mutations, - is_train=False, - pre_dispatch=aot_config.pre_dispatch, - static_input_indices=aot_config.static_input_indices, - )(*fake_flat_args) - else: - fw_metadata = ViewAndMutationMeta( - input_info=fw_metadata.input_info, - output_info=fw_metadata.output_info, - num_intermediate_bases=fw_metadata.num_intermediate_bases, - keep_input_mutations=aot_config.keep_inference_input_mutations, - traced_tangents=fw_metadata.traced_tangents, - subclass_inp_meta=fw_metadata.subclass_inp_meta, - subclass_fw_graph_out_meta=fw_metadata.subclass_fw_graph_out_meta, - subclass_tangent_meta=fw_metadata.subclass_tangent_meta, - is_train=False, - tokens=fw_metadata.tokens, - static_input_indices=fw_metadata.static_input_indices, - ) - - if fw_metadata.num_intermediate_bases > 0: - assert not req_subclass_dispatch, f"""\ + if fw_metadata.num_intermediate_bases > 0: + assert not req_subclass_dispatch, f"""\ torch.compile is currently being used with tensor subclass inputs: {",".join([str(type(x)) for x in fake_flat_args])}. We are attempting to a compile a graph with two graph outputs that alias one another, which is currently unsupported in the subclass use case. If you run into this, please file a github issue""" - if aot_config.is_export: - # aot_export: ban input metadata mutations for now to keep shared code paths simpler. - # Keeping .resize_() in the graph will require some work - # Allowing it but keeping the graph functional will require some calling convention changes. - if len([x for x in fw_metadata.input_info if x.mutates_metadata]) != 0: - raise RuntimeError( - f"""\ + if aot_config.is_export: + # aot_export: ban input metadata mutations for now to keep shared code paths simpler. + # Keeping .resize_() in the graph will require some work + # Allowing it but keeping the graph functional will require some calling convention changes. + if len([x for x in fw_metadata.input_info if x.mutates_metadata]) != 0: + raise RuntimeError( + f"""\ Found an input that received a metadata mutation, through e.g. a call to `.resize_()` or `.transpose_()`. This is currently banned in the aot_export workflow. If you need this functionality, please file a github issue. fw_metadata={str(fw_metadata)}""" - ) - # In export, banning data mutations on inputs that require grad for now. - # This should be rare, and is tricky to get right. When we trace the backward, - # we currently trace with autograd.grad instead of .backward(), which makes it difficult - # to ensure that we run autograd all the way through the input **before** it saw the mutation. - if ( - len( - [ - x - for x in fw_metadata.input_info - if x.requires_grad and x.mutates_data - ] - ) - != 0 - ): - raise RuntimeError( - f"""\ + ) + # In export, banning data mutations on inputs that require grad for now. + # This should be rare, and is tricky to get right. When we trace the backward, + # we currently trace with autograd.grad instead of .backward(), which makes it difficult + # to ensure that we run autograd all the way through the input **before** it saw the mutation. + if ( + len( + [ + x + for x in fw_metadata.input_info + if x.requires_grad and x.mutates_data + ] + ) + != 0 + ): + raise RuntimeError( + f"""\ Found a graph input that requires gradients, and received a mutation. This is currently banned in the aot_export workflow. If you need this functionality, please file a github issue. fw_metadata={str(fw_metadata)}""" - ) - if req_subclass_dispatch: - raise RuntimeError( - """\ + ) + if req_subclass_dispatch: + raise RuntimeError( + """\ aot_export is not currently supported with traceable tensor subclass. If you need this feature, please comment on """ - ) + ) - # Need to decide on a strategy for functionalized RNG: toggling via global config seems bad, - # and turning it on will require a non-trivial calling convention change for any export runtime. - if config.functionalize_rng_ops: - raise RuntimeError( - """\ + # Need to decide on a strategy for functionalized RNG: toggling via global config seems bad, + # and turning it on will require a non-trivial calling convention change for any export runtime. + if config.functionalize_rng_ops: + raise RuntimeError( + """\ Functionalized RNG is not currently supported in the aot_export workflow. Please file a github issue, or otherwise set torch._functorch.config.functionalize_rng_ops = False.""" - ) + ) - def choose_dispatcher(needs_autograd, aot_config): - """ - Pick a dispatcher based on the config rules. - """ - if aot_config.is_export: - # export uses just the "graph bits", whereas the other - # two dispatchers include some extra work around handling a runtime epilogue - CompileEventLogger.try_add_pt2_compile( - "backend_compile", dispatch_mode="export" - ) - return partial(aot_dispatch_export, needs_autograd=needs_autograd) - elif needs_autograd and not aot_config.pre_dispatch: - CompileEventLogger.try_add_pt2_compile( - "backend_compile", dispatch_mode="autograd" - ) - return aot_dispatch_autograd - else: - CompileEventLogger.try_add_pt2_compile( - "backend_compile", dispatch_mode="inference" - ) - return aot_dispatch_base + def choose_dispatcher(needs_autograd, aot_config): + """ + Pick a dispatcher based on the config rules. + """ + if aot_config.is_export: + # export uses just the "graph bits", whereas the other + # two dispatchers include some extra work around handling a runtime epilogue + CompileEventLogger.try_add_pt2_compile( + "backend_compile", dispatch_mode="export" + ) + return partial(aot_dispatch_export, needs_autograd=needs_autograd) + elif needs_autograd and not aot_config.pre_dispatch: + CompileEventLogger.try_add_pt2_compile( + "backend_compile", dispatch_mode="autograd" + ) + return aot_dispatch_autograd + else: + CompileEventLogger.try_add_pt2_compile( + "backend_compile", dispatch_mode="inference" + ) + return aot_dispatch_base - compiler_fn = choose_dispatcher(needs_autograd, aot_config) + compiler_fn = choose_dispatcher(needs_autograd, aot_config) - compiled_fn, fw_metadata = compiler_fn( - flat_fn, - _dup_fake_script_obj(fake_flat_args), - aot_config, - fw_metadata=fw_metadata, - ) - return compiled_fn, fw_metadata + return ( + compiler_fn, + flat_fn, + _dup_fake_script_obj(fake_flat_args), + aot_config, + fw_metadata, + ) def aot_function( @@ -1203,12 +1214,22 @@ def aot_module_simplified( stack.enter_context(compiled_autograd._disable()) - compiled_fn, _ = create_aot_dispatcher_function( - functional_call, - fake_flat_args, + compiler_fn, flat_fn, dup_fake_flat_args, aot_config, fw_metadata = ( + _create_aot_dispatcher_function( + stack, + functional_call, + fake_flat_args, + aot_config, + fake_mode, + shape_env, + ) + ) + + compiled_fn, _ = compiler_fn( + flat_fn, + dup_fake_flat_args, aot_config, - fake_mode, - shape_env, + fw_metadata=fw_metadata, ) break From 84dec060b79078e68dcb6be1e7f308dad05d00e2 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Tue, 15 Jul 2025 12:04:36 -0700 Subject: [PATCH 090/457] Hoist choose_dispatcher to top level, remove unnecessary returns (#158176) Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/158176 Approved by: https://github.com/jamesjwu ghstack dependencies: #158149, #158150, #158173 --- torch/_functorch/aot_autograd.py | 57 ++++++++++++++++---------------- 1 file changed, 28 insertions(+), 29 deletions(-) diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py index 94609a7441417..5fb3ed46c840d 100644 --- a/torch/_functorch/aot_autograd.py +++ b/torch/_functorch/aot_autograd.py @@ -573,10 +573,8 @@ def create_aot_dispatcher_function( shape_env: Optional[ShapeEnv], ) -> tuple[Callable, ViewAndMutationMeta]: with contextlib.ExitStack() as stack: - compiler_fn, flat_fn, dup_fake_flat_args, aot_config, fw_metadata = ( - _create_aot_dispatcher_function( - stack, flat_fn, fake_flat_args, aot_config, fake_mode, shape_env - ) + compiler_fn, dup_fake_flat_args, fw_metadata = _create_aot_dispatcher_function( + stack, flat_fn, fake_flat_args, aot_config, fake_mode, shape_env ) compiled_fn, fw_metadata = compiler_fn( flat_fn, @@ -821,39 +819,38 @@ def _dup_fake_script_obj(fake_flat_args): or otherwise set torch._functorch.config.functionalize_rng_ops = False.""" ) - def choose_dispatcher(needs_autograd, aot_config): - """ - Pick a dispatcher based on the config rules. - """ - if aot_config.is_export: - # export uses just the "graph bits", whereas the other - # two dispatchers include some extra work around handling a runtime epilogue - CompileEventLogger.try_add_pt2_compile( - "backend_compile", dispatch_mode="export" - ) - return partial(aot_dispatch_export, needs_autograd=needs_autograd) - elif needs_autograd and not aot_config.pre_dispatch: - CompileEventLogger.try_add_pt2_compile( - "backend_compile", dispatch_mode="autograd" - ) - return aot_dispatch_autograd - else: - CompileEventLogger.try_add_pt2_compile( - "backend_compile", dispatch_mode="inference" - ) - return aot_dispatch_base - compiler_fn = choose_dispatcher(needs_autograd, aot_config) return ( compiler_fn, - flat_fn, _dup_fake_script_obj(fake_flat_args), - aot_config, fw_metadata, ) +def choose_dispatcher(needs_autograd, aot_config): + """ + Pick a dispatcher based on the config rules. + """ + if aot_config.is_export: + # export uses just the "graph bits", whereas the other + # two dispatchers include some extra work around handling a runtime epilogue + CompileEventLogger.try_add_pt2_compile( + "backend_compile", dispatch_mode="export" + ) + return partial(aot_dispatch_export, needs_autograd=needs_autograd) + elif needs_autograd and not aot_config.pre_dispatch: + CompileEventLogger.try_add_pt2_compile( + "backend_compile", dispatch_mode="autograd" + ) + return aot_dispatch_autograd + else: + CompileEventLogger.try_add_pt2_compile( + "backend_compile", dispatch_mode="inference" + ) + return aot_dispatch_base + + def aot_function( fn: Callable, fw_compiler: Callable, @@ -1214,7 +1211,9 @@ def aot_module_simplified( stack.enter_context(compiled_autograd._disable()) - compiler_fn, flat_fn, dup_fake_flat_args, aot_config, fw_metadata = ( + flat_fn = functional_call + + compiler_fn, dup_fake_flat_args, fw_metadata = ( _create_aot_dispatcher_function( stack, functional_call, From 49d0332cef68be82eed713e265800f82d6cbff75 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Tue, 15 Jul 2025 12:04:36 -0700 Subject: [PATCH 091/457] Introduce stages to aot_dispatch (#158213) The starting point for this refactor is that I need access to the fully general joint graph representation in an export-like interface, but I then subsequently need a way to feed this joint graph into the rest of the compilation pipeline so I can get an actual callable that I can run once I've finished modifying it. Previously, people had added export capabilities to AOTAutograd by having an export flag that toggled what exactly the functions return and triggering aot_dispatch to go to a different "export" implementation, but I've found this difficult to understand and has lead to a bit of duplicate code for the export path. So the idea here is to reorganize the structure of the function calls in AOTAutograd. Here, it is helpful to first describe how things used to work: * Start with aot_autograd.py top level functions like aot_function, _aot_export_function and aot_module_simplified. These call: * create_aot_dispatcher_function. This does a bunch of stuff (forward metadata collection) and adds many context managers. This calls: * One of aot_dispatch_base, aot_dispatch_export or aot_dispatch_autograd, which: * Call aot_dispatch_autograd_graph or aot_dispatch_base_graph to actually do the graph capture * Do some base/export/autograd specific post-processing on the graph Notice the pattern of nested function invocations means that there is no way to easily get the graph capture result from the autograd case; furthermore, the export path is "bolted" on to force the entire chain of functions to have a different return result than normal, and no way to *resume* the rest of the post-processing to actually get a callable. Here is the new structure: * Start with aot_autograd.py top level functions like aot_function, _aot_export_function and aot_module_simplified. These now orchestrate this top level flow: * Start a context manager (stack); this stateful context block takes care of all of the nested context managers which originally necessitated the nested call structure * Call create_aot_state to do initial setup and setup all the context managers on stack. These context managers do NOT exit upon return of this. * Call aot_stage1_graph_capture to do the graph capture * Call aot_stage2_compile or aot_stage2_export depending on what postprocessing you want With this new structure, it's now possible (although not done in this PR) to return the graph after aot_stage1_graph_capture and do something with it, before running aot_stage2_compile to finish the job. Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/158213 Approved by: https://github.com/jamesjwu ghstack dependencies: #158149, #158150, #158173, #158176 --- .../jit_compile_runtime_wrappers.py | 150 +++++++++++------- .../_aot_autograd/runtime_wrappers.py | 49 +----- torch/_functorch/_aot_autograd/schemas.py | 130 +++++++++++++++ torch/_functorch/aot_autograd.py | 136 ++++++---------- 4 files changed, 269 insertions(+), 196 deletions(-) diff --git a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py index 53bfa1e3c51eb..73d6ab1c19596 100644 --- a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py @@ -18,13 +18,18 @@ import traceback from collections import defaultdict from contextlib import nullcontext -from typing import Any, Callable, Optional, TYPE_CHECKING +from typing import Any, Callable, Optional, TYPE_CHECKING, Union import torch import torch.utils._pytree as pytree import torch.utils.dlpack from torch import Tensor -from torch._dynamo.utils import detect_fake_mode, dynamo_timed, lazy_format_graph_code +from torch._dynamo.utils import ( + CompileEventLogger, + detect_fake_mode, + dynamo_timed, + lazy_format_graph_code, +) from torch._guards import CompileContext, TracingContext from torch._logging import getArtifactLogger, trace_structured from torch._subclasses import FakeTensor @@ -67,7 +72,13 @@ pre_compile, RuntimeWrapper, ) -from .schemas import AOTConfig, MutationType, ViewAndMutationMeta +from .schemas import ( + AOTConfig, + AOTGraphCapture, + AOTState, + MutationType, + ViewAndMutationMeta, +) from .subclass_utils import compute_inner_mutated_inp_indices_from_subclass_meta from .utils import ( _get_symint_hints, @@ -92,6 +103,7 @@ # Returns a Callable and a ViewAndMutationMeta. # Currently, only export needs the ViewAndMutationMeta after this function. +# TODO: Refactor this DispatchReturn = tuple[Callable, ViewAndMutationMeta] @@ -102,46 +114,68 @@ def _create_wrappers_for_dispatch(needs_autograd: bool) -> list[CompilerWrapper] return [AOTDedupeWrapper(), AOTSyntheticBaseWrapper(trace_joint=needs_autograd)] -# Export's dispatching logic is unique in a few ways: it only needs the "graph" -# bits of aot_autograd, and doesn't need to do any specific wrapping. -def aot_dispatch_export( +def aot_stage1_graph_capture( + aot_state: AOTState, flat_fn: Callable, - flat_args: list[Any], - aot_config: AOTConfig, - *, - fw_metadata: ViewAndMutationMeta, - needs_autograd: bool, -) -> DispatchReturn: - wrappers = _create_wrappers_for_dispatch(needs_autograd) - flat_fn, flat_args, fw_metadata = pre_compile( +) -> AOTGraphCapture: + aot_config = aot_state.aot_config + + wrappers = _create_wrappers_for_dispatch(aot_state.needs_autograd) + flat_fn, aot_state.flat_args, aot_state.fw_metadata = pre_compile( wrappers, flat_fn, - flat_args, + aot_state.flat_args, aot_config, - fw_metadata=fw_metadata, + fw_metadata=aot_state.fw_metadata, ) - if needs_autograd and not aot_config.pre_dispatch: - graph, _, _ = aot_dispatch_autograd_graph( - flat_fn, flat_args, aot_config, fw_metadata=fw_metadata - ) + # NB: This is currently only used for backwards, where fwd/bwd + # deterministic TLS can be different + aot_state.fw_metadata.deterministic = torch.are_deterministic_algorithms_enabled() + updated_flat_args: Union[list[Any], tuple[list[Any], list[Any]]] + if aot_state.needs_autograd and not aot_config.pre_dispatch: + # FYI: this being moved to trigger in export is new, seems fine! + with dynamo_timed("aot_trace_joint_graph", log_pt2_compile_event=True): + graph, updated_flat_args, maybe_subclass_meta = aot_dispatch_autograd_graph( + flat_fn, + aot_state.flat_args, + aot_config, + fw_metadata=aot_state.fw_metadata, + ) else: - graph, _, _ = aot_dispatch_base_graph( - flat_fn, flat_args, aot_config, fw_metadata=fw_metadata + graph, updated_flat_args, maybe_subclass_meta = aot_dispatch_base_graph( + flat_fn, aot_state.flat_args, aot_config, fw_metadata=aot_state.fw_metadata ) + return AOTGraphCapture( + wrappers=wrappers, + graph=graph, + updated_flat_args=updated_flat_args, + maybe_subclass_meta=maybe_subclass_meta, + ) + + +def aot_stage2_export( + aot_state: AOTState, aot_graph_capture: AOTGraphCapture +) -> DispatchReturn: + graph = aot_graph_capture.graph + aot_config = aot_state.aot_config + wrappers = aot_graph_capture.wrappers + + CompileEventLogger.try_add_pt2_compile("backend_compile", dispatch_mode="export") + # NB: the wrappers that run in pre_compile for export are # either a no-op, because they're not needed, or will raise a runtime error, # since they don't support export. # We still run these wrappers to make sure that they're not needed pre compile, # but we technically don't need to run them post compile at all here. - compiled_fn, fw_metadata = post_compile( - wrappers, graph, aot_config, runtime_metadata=fw_metadata + compiled_fn, aot_state.fw_metadata = post_compile( + wrappers, graph, aot_config, runtime_metadata=aot_state.fw_metadata ) # Therefore, since no wrapperes run, we don't get back a callable - we get back the raw fx graph # (either a joint or an inference-only graph) assert isinstance(compiled_fn, torch.fx.GraphModule) - return compiled_fn, fw_metadata + return compiled_fn, aot_state.fw_metadata def sanitize_aot_config(input: AOTConfig) -> AOTConfig: @@ -166,23 +200,33 @@ def sanitize_aot_config(input: AOTConfig) -> AOTConfig: ) -def aot_dispatch_base( - flat_fn, - flat_args: list[Any], - aot_config: AOTConfig, - *, - fw_metadata: ViewAndMutationMeta, +def aot_stage2_compile( + aot_state: AOTState, + aot_graph_capture: AOTGraphCapture, +) -> DispatchReturn: + if aot_state.needs_autograd and not aot_state.aot_config.pre_dispatch: + return aot_stage2_autograd(aot_state, aot_graph_capture) + else: + return aot_stage2_inference(aot_state, aot_graph_capture) + + +def aot_stage2_inference( + aot_state: AOTState, + aot_graph_capture: AOTGraphCapture, ) -> DispatchReturn: """ Handles functions that don't need autograd. Runs wrappers and compiles with fw_compiler. """ - wrappers = _create_wrappers_for_dispatch(needs_autograd=False) - flat_fn, flat_args, fw_metadata = pre_compile( - wrappers, flat_fn, flat_args, aot_config, fw_metadata=fw_metadata - ) - fw_module, updated_flat_args, maybe_subclass_meta = aot_dispatch_base_graph( # type: ignore[misc] - flat_fn, flat_args, aot_config, fw_metadata=fw_metadata - ) + + aot_config = aot_state.aot_config + fw_metadata = aot_state.fw_metadata + fw_module = aot_graph_capture.graph + wrappers = aot_graph_capture.wrappers + updated_flat_args = aot_graph_capture.updated_flat_args + maybe_subclass_meta = aot_graph_capture.maybe_subclass_meta + + CompileEventLogger.try_add_pt2_compile("backend_compile", dispatch_mode="inference") + # Save the forward_graph_str right after aot_dispatch_base_graph, # to save in the cache aot_forward_graph_str = None @@ -1247,31 +1291,23 @@ def _log_structured_logs(): bw_module.recompile() -def aot_dispatch_autograd( - flat_fn, - flat_args: list[Any], - aot_config: AOTConfig, - *, - fw_metadata: ViewAndMutationMeta, +def aot_stage2_autograd( + aot_state: AOTState, aot_graph_capture: AOTGraphCapture ) -> DispatchReturn: """ Autograd logic. Generates a joint graph, partitions it, manipulates the input with various wrappers, and returns a wrapped torch.autograd.Function with a forward and backward. """ - wrappers = _create_wrappers_for_dispatch(needs_autograd=True) - flat_fn, flat_args, fw_metadata = pre_compile( - wrappers, - flat_fn, - flat_args, - aot_config, - fw_metadata=fw_metadata, - ) - fw_metadata.deterministic = torch.are_deterministic_algorithms_enabled() - with dynamo_timed("aot_trace_joint_graph", log_pt2_compile_event=True): - fx_g, joint_inputs, maybe_subclass_meta = aot_dispatch_autograd_graph( - flat_fn, flat_args, aot_config, fw_metadata=fw_metadata - ) + wrappers = aot_graph_capture.wrappers + fx_g = aot_graph_capture.graph + flat_args = aot_state.flat_args + joint_inputs = aot_graph_capture.updated_flat_args + maybe_subclass_meta = aot_graph_capture.maybe_subclass_meta + aot_config = aot_state.aot_config + fw_metadata = aot_state.fw_metadata + + CompileEventLogger.try_add_pt2_compile("backend_compile", dispatch_mode="autograd") # Copied from aot_dispatch_autograd_graph. disable_amp = torch._C._is_any_autocast_enabled() diff --git a/torch/_functorch/_aot_autograd/runtime_wrappers.py b/torch/_functorch/_aot_autograd/runtime_wrappers.py index 77eebd5e62480..46bd0ad774793 100644 --- a/torch/_functorch/_aot_autograd/runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/runtime_wrappers.py @@ -49,6 +49,7 @@ from .logging_utils import describe_input, format_guard_bug_msg, track_graph_compiling from .schemas import ( AOTConfig, + CompilerWrapper, InputAliasInfo, MemoryFormatMeta, MutationType, @@ -80,54 +81,6 @@ zip = strict_zip -class CompilerWrapper: - """ - A wrapper around the inputs and outputs to the compiler_fn. We separate these into two parts: - - 1. The prologue, which edits the input to the compiler_fn(flat_fn, flat_args, etc) - 2. The epilogue, which edits the outputs of the compiler_fn (compiled_fn, real arguments) - - Each wrapper below should be implemented as a CompilerWrapper, so that we can facilitate - caching on the compiled output, and re-wrapping the output via epilogues. - Extra metadata that is needed to compute pre or post compile can be passed in via attributes. - """ - - def pre_compile( - self, - flat_fn, - flat_args: list[Tensor], - aot_config: AOTConfig, - *, - fw_metadata: ViewAndMutationMeta, - ) -> tuple[Callable, list[Tensor], ViewAndMutationMeta]: - """ - Process the inputs to the compiler_fn. You can pass in extra metadata via kwargs. - Args: - flat_fn: The function to compile - flat_args: Metadata from example inputs of the function to compile - aot_config: AOTConfig passed in at compile time - fw_metadata: ViewAndMutationMeta generated from flat_fn and flat_args - """ - return flat_fn, flat_args, fw_metadata - - def post_compile(self, compiled_fn, aot_config, *, runtime_metadata) -> Callable: - """ - Given an output of the compiler, wrap it with information received from prologue. - Args: - compiled_fn: Callable after calling compiler_fn - aot_config: AOTConfig after calling prologue - runtime_metadata: ViewAndMutationMeta after calling all wrappers's pre_compile steps. - Example: - - def wrapped_compiled_fn(args): - # do something with args, aot_config, fw_metadata - return compiled_fn(args) - - return wrapped_compiled_fn - """ - return compiled_fn - - # The wrapper created by this function handles all of the runtime aliasing and mutation "epilogue" logic # that needs to run after the compiled function. # diff --git a/torch/_functorch/_aot_autograd/schemas.py b/torch/_functorch/_aot_autograd/schemas.py index 9b32398233032..38185851d8faa 100644 --- a/torch/_functorch/_aot_autograd/schemas.py +++ b/torch/_functorch/_aot_autograd/schemas.py @@ -5,6 +5,7 @@ """ import collections +import contextlib import dataclasses import functools import itertools @@ -15,6 +16,7 @@ import torch import torch.utils._pytree as pytree +from torch import Tensor from torch._guards import Source from torch._ops import OpOverload from torch._subclasses import FakeTensor @@ -966,3 +968,131 @@ def __post_init__(self): "SubclassTracingInfo", ["plain_tensor_trace_fn", "plain_tensor_args", "maybe_subclass_meta"], ) + + +@dataclass +class AOTState: + """ + When we run AOTAutograd, this class encapsulates the state in the compiler which + must be preserved across stages. This is state in the traditional sense (not an + environment) because some values in this structure change as we progress through + pipelines in AOTAutograd. + """ + + # Whether or not we need to handle autograd when doing graph capture and + # compilation. Although the calling convention for non-autograd graph + # capture in AOTAutograd is simple and can be relied upon, the autograph + # capture calling convention is quite complicated and in general you are + # only expected to pass to aot_stage2_compile to process. + needs_autograd: bool + + # The FAKE flat arguments which we will do tracing with. Although you + # might naively expect this to be immutable, it's not: when we perform + # tracing, we may execute code that modifies the metadata of inputs, + # causing the args to become "invalid". It's also nontrivial to have a + # "golden" set of fake values and deepcopy them just in time when you + # might destructively mutate them (Voz and I tried very hard to do this). + # So we just periodically renew this field. Don't worry too much about + # this unless you're specifically trying to track down an input metadata + # mutation bug. + # + # (By the way, this is NEVER the joint inputs! Those only ever go in + # AOTGraphCapture) + flat_args: list[Any] + + # This contains view and mutation information about the function, which we + # detected by doing an initial trace when we created this state. + fw_metadata: ViewAndMutationMeta + + # Top-level configuration + # This is morally immutable but sometimes we are naughty and mutate it. + aot_config: AOTConfig + + # When performing AOTAutograd traces and other passes, we typically + # require a lot of active context managers; most typically these either + # (1) ensure we are faithfully replicating the original PyTorch context + # managers or (2) toggle some behaviors in PyTorch to make it more + # suitable for tracing. When you use AOTState, you're expected to have + # created an ExitStack, entered it; then while we are running AOTAutograd + # we will add things onto the stack as necessary. When you're all done + # with processing AOTAutograd, you can exit this stack. All functions + # that take AOTState expect the ExitStack to not have been exited yet. + # + # TODO: We potentially could offer a resumable context manager, where you + # can cancel it and reenable it later when you need it. + stack: contextlib.ExitStack + + +class CompilerWrapper: + """ + A wrapper around the inputs and outputs to the compiler_fn. We separate these into two parts: + + 1. The prologue, which edits the input to the compiler_fn(flat_fn, flat_args, etc) + 2. The epilogue, which edits the outputs of the compiler_fn (compiled_fn, real arguments) + + Each wrapper below should be implemented as a CompilerWrapper, so that we can facilitate + caching on the compiled output, and re-wrapping the output via epilogues. + Extra metadata that is needed to compute pre or post compile can be passed in via attributes. + """ + + def pre_compile( + self, + flat_fn, + flat_args: list[Tensor], + aot_config: AOTConfig, + *, + fw_metadata: ViewAndMutationMeta, + ) -> tuple[Callable, list[Tensor], ViewAndMutationMeta]: + """ + Process the inputs to the compiler_fn. You can pass in extra metadata via kwargs. + Args: + flat_fn: The function to compile + flat_args: Metadata from example inputs of the function to compile + aot_config: AOTConfig passed in at compile time + fw_metadata: ViewAndMutationMeta generated from flat_fn and flat_args + """ + return flat_fn, flat_args, fw_metadata + + def post_compile(self, compiled_fn, aot_config, *, runtime_metadata) -> Callable: + """ + Given an output of the compiler, wrap it with information received from prologue. + Args: + compiled_fn: Callable after calling compiler_fn + aot_config: AOTConfig after calling prologue + runtime_metadata: ViewAndMutationMeta after calling all wrappers's pre_compile steps. + Example: + + def wrapped_compiled_fn(args): + # do something with args, aot_config, fw_metadata + return compiled_fn(args) + + return wrapped_compiled_fn + """ + return compiled_fn + + +@dataclass +class AOTGraphCapture: # Produced by aot_stage1_graph_capture + # AOTAutograd typically operates by taking complicated graphs and + # desugaring them into simpler graphs that use PyTorch features. These + # wrappers establish invariants so that when we actually do tracing we can + # assume these invariants hold, leading to a simpler tracing + # implementation. However, this means that we have to keep track of how + # to enter/exit these wrappers when passing inputs into the compiled + # graph, among other things! + wrappers: list[CompilerWrapper] + + # The actual captured graph. In some circumstances (export) this graph + # has a specific calling convention that can be relied upon by external + # callers. In other situations, the calling convention is unspecified and + # only aot_stage2_compile knows how to deal with them. + graph: torch.fx.GraphModule + + # When compiling with autograd support, this is the joint_inputs, which is + # larger than the original flat_args as all tangents get inputs. The + # tuple organizes into primals and tangents. When not autograd it's just + # a plain list. + updated_flat_args: Union[list[Any], tuple[list[Any], list[Any]]] + + # Metadata about subclass inputs/outputs in the graph trace. + maybe_subclass_meta: Any diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py index 5fb3ed46c840d..c7bdbd870d175 100644 --- a/torch/_functorch/aot_autograd.py +++ b/torch/_functorch/aot_autograd.py @@ -4,7 +4,7 @@ import itertools from collections.abc import KeysView, Sequence from contextlib import contextmanager, nullcontext -from functools import partial, wraps +from functools import wraps from typing import Any, Callable, NewType, Optional, Protocol, TypeVar from unittest.mock import patch @@ -69,9 +69,9 @@ remove_dupe_metadata, ) from ._aot_autograd.jit_compile_runtime_wrappers import ( # noqa: F401 - aot_dispatch_autograd, - aot_dispatch_base, - aot_dispatch_export, + aot_stage1_graph_capture, + aot_stage2_compile, + aot_stage2_export, ) from ._aot_autograd.logging_utils import ( # noqa: F401 callback_set, @@ -93,6 +93,7 @@ ) from ._aot_autograd.schemas import ( # noqa: F401 AOTConfig, + AOTState, BackwardSignature, FQN, GraphInputName, @@ -565,34 +566,14 @@ def construct_fake_mode( return (fake_mode, shape_env) -def create_aot_dispatcher_function( +def create_aot_state( + stack: contextlib.ExitStack, flat_fn, fake_flat_args: FakifiedFlatArgs, aot_config: AOTConfig, fake_mode: FakeTensorMode, shape_env: Optional[ShapeEnv], -) -> tuple[Callable, ViewAndMutationMeta]: - with contextlib.ExitStack() as stack: - compiler_fn, dup_fake_flat_args, fw_metadata = _create_aot_dispatcher_function( - stack, flat_fn, fake_flat_args, aot_config, fake_mode, shape_env - ) - compiled_fn, fw_metadata = compiler_fn( - flat_fn, - dup_fake_flat_args, - aot_config, - fw_metadata=fw_metadata, - ) - return compiled_fn, fw_metadata - - -def _create_aot_dispatcher_function( - stack, - flat_fn, - fake_flat_args: FakifiedFlatArgs, - aot_config: AOTConfig, - fake_mode: FakeTensorMode, - shape_env: Optional[ShapeEnv], -) -> tuple[Callable, Callable, list[Any], AOTConfig, ViewAndMutationMeta]: +) -> AOTState: """ Traces the forward and backward graphs of the attr:`flat_fn` to generate a joint graph. The joint graph is an Fx graph with Aten ops. Please refer to @@ -609,12 +590,10 @@ def _create_aot_dispatcher_function( inputs in flat_args are parameters and buffers, and the rest are inputs. We use this to assume that parameters/buffer's shapes don't change. - - Note: this function is used both by aot_function and aot_export (controlled by aot_config.is_export) - When aot_config.is_export is True, we return an FX graph + metadata - When aot_config.is_export is False, we return an ordinary runtime function """ + # Old name for now to avoid messing with stats. Also, note this is pushed + # on the stack, so it extends BEYOND this function stack.enter_context( dynamo_timed("create_aot_dispatcher_function", log_pt2_compile_event=True) ) @@ -819,38 +798,16 @@ def _dup_fake_script_obj(fake_flat_args): or otherwise set torch._functorch.config.functionalize_rng_ops = False.""" ) - compiler_fn = choose_dispatcher(needs_autograd, aot_config) - - return ( - compiler_fn, - _dup_fake_script_obj(fake_flat_args), - fw_metadata, + return AOTState( + needs_autograd=needs_autograd, + flat_args=_dup_fake_script_obj(fake_flat_args), + fw_metadata=fw_metadata, + # Packaging this just for later use + aot_config=aot_config, + stack=stack, ) -def choose_dispatcher(needs_autograd, aot_config): - """ - Pick a dispatcher based on the config rules. - """ - if aot_config.is_export: - # export uses just the "graph bits", whereas the other - # two dispatchers include some extra work around handling a runtime epilogue - CompileEventLogger.try_add_pt2_compile( - "backend_compile", dispatch_mode="export" - ) - return partial(aot_dispatch_export, needs_autograd=needs_autograd) - elif needs_autograd and not aot_config.pre_dispatch: - CompileEventLogger.try_add_pt2_compile( - "backend_compile", dispatch_mode="autograd" - ) - return aot_dispatch_autograd - else: - CompileEventLogger.try_add_pt2_compile( - "backend_compile", dispatch_mode="inference" - ) - return aot_dispatch_base - - def aot_function( fn: Callable, fw_compiler: Callable, @@ -951,13 +908,12 @@ def returned_function(*args, **kwargs): fake_flat_args: FakifiedFlatArgs = process_inputs( flat_args, aot_config, fake_mode, shape_env ) - compiled_fn, _ = create_aot_dispatcher_function( - flat_fn, - fake_flat_args, - aot_config, - fake_mode, - shape_env, - ) + with contextlib.ExitStack() as stack: + aot_state = create_aot_state( + stack, flat_fn, fake_flat_args, aot_config, fake_mode, shape_env + ) + aot_graph_capture = aot_stage1_graph_capture(aot_state, flat_fn) + compiled_fn, _ = aot_stage2_compile(aot_state, aot_graph_capture) cached_res = (compiled_fn, out_spec) cached_fn, out_spec = cached_res @@ -1211,25 +1167,16 @@ def aot_module_simplified( stack.enter_context(compiled_autograd._disable()) - flat_fn = functional_call - - compiler_fn, dup_fake_flat_args, fw_metadata = ( - _create_aot_dispatcher_function( - stack, - functional_call, - fake_flat_args, - aot_config, - fake_mode, - shape_env, - ) - ) - - compiled_fn, _ = compiler_fn( - flat_fn, - dup_fake_flat_args, + aot_state = create_aot_state( + stack, + functional_call, + fake_flat_args, aot_config, - fw_metadata=fw_metadata, + fake_mode, + shape_env, ) + aot_graph_capture = aot_stage1_graph_capture(aot_state, functional_call) + compiled_fn, _ = aot_stage2_compile(aot_state, aot_graph_capture) break if isinstance(mod, torch._dynamo.utils.GmWrapper): @@ -1415,6 +1362,8 @@ def fn_to_trace(*args): dynamic_shapes=dynamic_shapes, kwargs=kwargs, ) + + # TODO: subsume this path with the aot_stage2_graph_capture path if trace_joint: @wraps(functional_call) @@ -1645,13 +1594,18 @@ def _aot_export_function( shape_env = fake_mode.shape_env fake_flat_args = process_inputs(flat_args, aot_config, fake_mode, shape_env) - fx_g, meta = create_aot_dispatcher_function( - flat_fn, - fake_flat_args, - aot_config, - fake_mode, - shape_env, - ) + with contextlib.ExitStack() as stack: + aot_state = create_aot_state( + stack, + flat_fn, + fake_flat_args, + aot_config, + fake_mode, + shape_env, + ) + aot_graph_capture = aot_stage1_graph_capture(aot_state, flat_fn) + fx_g, meta = aot_stage2_export(aot_state, aot_graph_capture) + return fx_g, meta, in_spec, out_spec.spec From 7637c9718aaf688ed515ab3ae4037f9194f0a018 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Tue, 15 Jul 2025 12:04:36 -0700 Subject: [PATCH 092/457] Move functions from torch._functorch.aot_autograd that are not frontend functions to frontend_utils (#158251) Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/158251 Approved by: https://github.com/jamesjwu ghstack dependencies: #158149, #158150, #158173, #158176, #158213 --- torch/_dynamo/device_interface.py | 4 +- .../_aot_autograd/frontend_utils.py | 284 +++++++++++++++ torch/_functorch/_aot_autograd/schemas.py | 47 ++- torch/_functorch/aot_autograd.py | 328 +----------------- 4 files changed, 345 insertions(+), 318 deletions(-) create mode 100644 torch/_functorch/_aot_autograd/frontend_utils.py diff --git a/torch/_dynamo/device_interface.py b/torch/_dynamo/device_interface.py index 2ec7c5f7259f1..9c6e4f6bf5f8b 100644 --- a/torch/_dynamo/device_interface.py +++ b/torch/_dynamo/device_interface.py @@ -247,8 +247,8 @@ def get_device_properties(device: torch.types.Device = None): synchronize = staticmethod(torch.cuda.synchronize) get_device_properties = staticmethod(torch.cuda.get_device_properties) # type: ignore[assignment] get_raw_stream = staticmethod(get_cuda_stream) # type: ignore[assignment, arg-type] - exchange_device = staticmethod(torch.cuda._exchange_device) # type: ignore[arg-type] - maybe_exchange_device = staticmethod(torch.cuda._maybe_exchange_device) # type: ignore[arg-type] + exchange_device = staticmethod(torch.cuda._exchange_device) # type: ignore[arg-type, has-type] + maybe_exchange_device = staticmethod(torch.cuda._maybe_exchange_device) # type: ignore[arg-type, has-type] memory_allocated = staticmethod(torch.cuda.memory_allocated) is_bf16_supported = staticmethod(torch.cuda.is_bf16_supported) # type: ignore[arg-type] diff --git a/torch/_functorch/_aot_autograd/frontend_utils.py b/torch/_functorch/_aot_autograd/frontend_utils.py new file mode 100644 index 0000000000000..55b84c12df829 --- /dev/null +++ b/torch/_functorch/_aot_autograd/frontend_utils.py @@ -0,0 +1,284 @@ +# mypy: ignore-errors + +from collections.abc import KeysView +from contextlib import contextmanager +from typing import Any, Optional + +import torch +import torch.utils._pytree as pytree +from torch._guards import detect_fake_mode +from torch._subclasses import FakeTensor, FakeTensorMode +from torch.fx.experimental.proxy_tensor import _pytree_subclasses_that_lose_info +from torch.fx.experimental.symbolic_shapes import ShapeEnv +from torch.utils._python_dispatch import is_traceable_wrapper_subclass + +from .. import config +from .schemas import AOTConfig, FakifiedFlatArgs + + +static_inputs_log = torch._logging.getArtifactLogger( + __name__, "cudagraph_static_inputs" +) + + +def process_inputs( + flat_args: list[Any], + aot_config: AOTConfig, + fake_mode: FakeTensorMode, + shape_env: Optional[ShapeEnv], + ignore_shape_env: bool = False, +) -> FakifiedFlatArgs: + with fake_mode: + + def convert(idx, x): + if shape_env is not None and not ignore_shape_env: + from torch._dynamo.source import ConstantSource + + if isinstance(x, int): + # We always specialize on scalar values in export. + if aot_config.is_export: + return x + source = ConstantSource(f"sym_{idx}") + return shape_env.create_symintnode( + shape_env.create_symbol(x, source), hint=x, source=source + ) + if isinstance(x, torch.ScriptObject): + return torch._library.fake_class_registry.maybe_to_fake_obj( + fake_mode, x + ) + if not isinstance(x, torch.Tensor): + return x + if isinstance(x, FakeTensor): + assert x.fake_mode is fake_mode + return x + if is_traceable_wrapper_subclass(x): + attrs, _ = x.__tensor_flatten__() + if all(isinstance(getattr(x, attr), FakeTensor) for attr in attrs): + assert all( + getattr(x, attr).fake_mode is fake_mode for attr in attrs + ) + return x + + # see note [Tensor Fakification and Symbol Caching] + symbolic_context = None + source = None + trace = True + if tracing_context := torch._guards.TracingContext.try_get(): + if x in tracing_context.tensor_to_context: + symbolic_context = tracing_context.tensor_to_context[x] + source = symbolic_context.tensor_source + # We already fakeified this tensor in Dynamo, don't + # dump the trace for it again + trace = False + if ( + idx < aot_config.num_params_buffers + and config.static_weight_shapes + and not symbolic_context + ): + # TODO: Ensure that this codepath is never exercised from + # Dynamo + return fake_mode.from_tensor(x, static_shapes=True) + + result = fake_mode.from_tensor( + x, + static_shapes=ignore_shape_env, + symbolic_context=symbolic_context, + source=source, + trace=trace, + ) + return result + + return FakifiedFlatArgs([convert(idx, x) for idx, x in enumerate(flat_args)]) + + +def construct_fake_mode( + flat_args: list[Any], aot_config: AOTConfig +) -> tuple[FakeTensorMode, Optional[ShapeEnv]]: + fake_mode = detect_fake_mode(flat_args) + if fake_mode is None: + shape_env = ShapeEnv() if aot_config.dynamic_shapes else None + fake_mode = FakeTensorMode(shape_env=shape_env) + else: + shape_env = fake_mode.shape_env + return (fake_mode, shape_env) + + +def _try_get_metadata_from_dynamo( + mod: torch.nn.Module, param_keys: KeysView[str], full_args_num: int +) -> tuple[Optional[list[torch._guards.Source]], list[int]]: + """ + Metadata is forwarded from Dynamo to AOTDispatch via special fields on GraphModule. + We first verify that `mod` does come from Dynamo, then we handle cases where + metadata might be missing. + + Returns: + aot_autograd_arg_pos_to_source: used to dedup params and their guards + static_input_indices: used to identify static inputs for cudagraphs + """ + # Note [Assumption on Dynamo Metadata] + # This function assumes a graph module from dynamo provides `dynamo_compiled_id`, + # _param_name_to_source, and every placeholder node has `_dynamo_source` attributes. + # When gm is modified (e.g., DDPOptimizer via split_module), metadata needs to + # be propagated in order to be recognized as a dynamo graph + + if not (isinstance(mod, torch.fx.GraphModule) and "dynamo_compile_id" in mod.meta): + # graph was not captured by dynamo + return None, [] + + if not hasattr(mod, "_param_name_to_source"): + # is from export + return None, [] + + # We now know this came from dynamo, and (1) we care about guards, + # so setting up aot_autograd_arg_pos_to_source for downstream dedup guards + # can now be done safely. (2) Dynamo logic protects the 1:1 sizing below. + # Additionally, we mark static indices for cudagraphs. + param_name_to_source = mod._param_name_to_source + seen_sources = set() + + aot_autograd_arg_pos_to_source = [] + static_input_indices = [] + # Collect the new inputs lifted by aotdispatch + for i, name in enumerate(param_keys): + assert name in param_name_to_source, f"{name} not found." + source = param_name_to_source[name] + assert source not in seen_sources, source + seen_sources.add(source) + aot_autograd_arg_pos_to_source.append(source) + + static_input_indices.append(i) + + # Collect the dynamo graph inputs + # TODO(mlazos): Revisit if this is still needed. With Dynamo install ID + # matched tensors back into the Fx graph, this might not be necessary. + for pos, node in enumerate(mod.graph.find_nodes(op="placeholder")): + assert hasattr(node, "_dynamo_source") + source = node._dynamo_source + # `source`` specifies the source from user code. ddp optimizer may have + # intermediate values becoming submodule placeholders which does not + # have a source + assert source is None or source not in seen_sources, source + seen_sources.add(source) + aot_autograd_arg_pos_to_source.append(source) + source_name = source.name() if source else str(source) + + # input[i] in dynamo is now: + # input[i + len(extra_params)] in AOT, + # where extra_params are the params/buffers that dynamo baked into the + # OutputGraph + actual_pos = pos + len(param_keys) + + if "tensor_dict" in node.meta and node.meta["tensor_dict"].get( + "_dynamo_static_input_type", None + ): + static_inputs_log.debug( + "Adding static input pos %s for source %s", actual_pos, source_name + ) + static_input_indices.append(actual_pos) + else: + static_inputs_log.debug( + "Non-static input pos %s for source %s", actual_pos, source_name + ) + + assert full_args_num == len(aot_autograd_arg_pos_to_source) + return aot_autograd_arg_pos_to_source, static_input_indices + + +@contextmanager +def _detect_attribute_assignment(mod: torch.nn.Module): + # Do not allow assignment of tensor attributes during export unless + # the attribute is registered as a buffer. + + NN_MODULE_STD_ATTRS = [ + "_backward_hooks", + "_backward_pre_hooks", + "_buffers", + "_forward_hooks", + "_forward_hooks_always_called", + "_forward_hooks_with_kwargs", + "_forward_pre_hooks", + "_forward_pre_hooks_with_kwargs", + "_is_full_backward_hook", + "_load_state_dict_post_hooks", + "_load_state_dict_pre_hooks", + "_modules", + "_non_persistent_buffers_set", + "_parameters", + "_state_dict_hooks", + "_state_dict_pre_hooks", + "training", + ] + NN_MODULE_LAZY_STD_ATTRS = [ + "_initialize_hook", + "_load_hook", + ] + STD_ATTRS = { + *NN_MODULE_STD_ATTRS, + *NN_MODULE_LAZY_STD_ATTRS, + } + + def _get_attributes(mod): + # return any attributes of a module that are not standard attributes + return {k: v for k, v in mod.__dict__.items() if k not in STD_ATTRS} + + # save state of attributes before enter + snapshot = pytree.tree_map( + lambda x: x, + _get_attributes(mod), + is_leaf=lambda x: type(x) in _pytree_subclasses_that_lose_info, + ) + try: + yield + finally: + # after exit, compare state of attributes with snapshot + # to detect which tensor attributes were assigned + assigned_tensor_attributes = [] + + def _collect_assigned_tensor_attributes(kp, v, _v): + if _v is not v: + attr, *rest = kp + if isinstance(v, torch.Tensor): + assigned_tensor_attributes.append( + f"self.{attr.key}{pytree.keystr(rest)}" + ) + # TODO(avik): Assigning all other types are allowed right now. + # Maybe in the future we want to limit this to primitive types? + return v + + new_attrs = _get_attributes(mod) + if len(new_attrs) != len(snapshot): + added_attrs = new_attrs.keys() - snapshot.keys() + deleted_attrs = snapshot.keys() - new_attrs.keys() + + if len(added_attrs) > 0: + raise ValueError( + f"During torch.export, following attrs were created in the model.forward: {added_attrs} " + f"Such attributes must be registered as buffers using the `register_buffer` " + f"API and must be initialized at model.__init__ " + f"(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)." + ) + + if len(deleted_attrs) > 0: + raise ValueError( + f"During torch.export, following attrs were deleted in the model.forward: {deleted_attrs} " + f"Such attributes must be registered as buffers using the `register_buffer` " + f"API and must be initialized at model.__init__ " + f"(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)." + ) + + pytree.tree_map_with_path( + _collect_assigned_tensor_attributes, snapshot, new_attrs + ) + # restore state of all attributes (including, e.g., of primitive types) + mod.__dict__.update(snapshot) + + if assigned_tensor_attributes: + if len(assigned_tensor_attributes) > 1: + noun, verb = "attributes", "were" + else: + noun, verb = "attribute", "was" + raise ValueError( + f"The tensor {noun} {', '.join(assigned_tensor_attributes)} {verb} assigned during export. " + "Such attributes must be registered as buffers using the `register_buffer` API " + "(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)." + ) diff --git a/torch/_functorch/_aot_autograd/schemas.py b/torch/_functorch/_aot_autograd/schemas.py index 38185851d8faa..78f8e506e07e1 100644 --- a/torch/_functorch/_aot_autograd/schemas.py +++ b/torch/_functorch/_aot_autograd/schemas.py @@ -12,12 +12,14 @@ from collections.abc import Iterable, Sequence from dataclasses import dataclass, field from enum import Enum -from typing import Any, Callable, NewType, Optional, Union +from typing import Any, Callable, NewType, Optional, Protocol, TypeVar, Union import torch import torch.utils._pytree as pytree from torch import Tensor from torch._guards import Source +from torch._inductor.output_code import OutputCode +from torch._inductor.utils import InputType from torch._ops import OpOverload from torch._subclasses import FakeTensor from torch._subclasses.fake_tensor import is_fake @@ -1096,3 +1098,46 @@ class AOTGraphCapture: # Produced by aot_stage1_graph_capture # Metadata about subclass inputs/outputs in the graph trace. maybe_subclass_meta: Any + + +FakifiedFlatArgs = NewType("FakifiedFlatArgs", list[Any]) + + +TOutputCode = TypeVar("TOutputCode", bound=OutputCode) + + +class AOTDispatchCompiler(Protocol): + """ + Represents a fw or bw_compiler passed to AOTAutograd. + """ + + def __call__( + self, + gm: torch.fx.GraphModule, + example_inputs: Sequence[InputType], + ) -> Any: ... + + +# TODO: bikeshed on this name +class SerializableAOTDispatchCompiler(AOTDispatchCompiler): + """ + Represents an AOTDispatchCompiler that returns an OutputCode, and is + therefore cacheable. SerializableAOTDispatchCompiler always return an OutputCode. + A _CompileFxCallable usually gets converted into an AOTDispatchCompiler after binding all of + the kwargs in _CompileFxKwargs. + """ + + def __init__( + self, + output_code_ty: type[TOutputCode], + compiler_fn: Callable[[torch.fx.GraphModule, Sequence[InputType]], TOutputCode], + ): + self.output_code_ty = output_code_ty + self.compiler_fn = compiler_fn + + def __call__( + self, + gm: torch.fx.GraphModule, + example_inputs: Sequence[InputType], + ) -> OutputCode: + return self.compiler_fn(gm, example_inputs) diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py index c7bdbd870d175..73eca634495a1 100644 --- a/torch/_functorch/aot_autograd.py +++ b/torch/_functorch/aot_autograd.py @@ -2,10 +2,9 @@ import contextlib import itertools -from collections.abc import KeysView, Sequence -from contextlib import contextmanager, nullcontext +from contextlib import nullcontext from functools import wraps -from typing import Any, Callable, NewType, Optional, Protocol, TypeVar +from typing import Any, Callable, Optional from unittest.mock import patch import torch @@ -25,15 +24,10 @@ ) from torch._guards import detect_fake_mode from torch._inductor.cudagraph_utils import BoxedDeviceIndex -from torch._inductor.output_code import OutputCode -from torch._inductor.utils import BoxedBool, InputType +from torch._inductor.utils import BoxedBool from torch._subclasses import FakeTensor, FakeTensorMode -from torch.fx.experimental.proxy_tensor import ( - _pytree_subclasses_that_lose_info, - make_fx, -) +from torch.fx.experimental.proxy_tensor import make_fx from torch.fx.experimental.symbolic_shapes import ShapeEnv -from torch.utils._python_dispatch import is_traceable_wrapper_subclass static_inputs_log = torch._logging.getArtifactLogger( @@ -49,6 +43,12 @@ from ._aot_autograd.collect_metadata_analysis import ( # noqa: F401 run_functionalized_fw_and_collect_metadata, ) +from ._aot_autograd.frontend_utils import ( + _detect_attribute_assignment, + _try_get_metadata_from_dynamo, + construct_fake_mode, + process_inputs, +) from ._aot_autograd.functional_utils import ( # noqa: F401 _check_if_mutation_can_be_in_graph, are_all_mutations_hidden_from_autograd, @@ -93,8 +93,10 @@ ) from ._aot_autograd.schemas import ( # noqa: F401 AOTConfig, + AOTDispatchCompiler, AOTState, BackwardSignature, + FakifiedFlatArgs, FQN, GraphInputName, GraphOutputName, @@ -103,6 +105,7 @@ MutationType, OutputAliasInfo, OutputType, + SerializableAOTDispatchCompiler, SubclassCreationMeta, SubclassMeta, TensorAlias, @@ -441,130 +444,6 @@ aot_autograd_decompositions = {} -FakifiedFlatArgs = NewType("FakifiedFlatArgs", list[Any]) - - -TOutputCode = TypeVar("TOutputCode", bound=OutputCode) - - -class AOTDispatchCompiler(Protocol): - """ - Represents a fw or bw_compiler passed to AOTAutograd. - """ - - def __call__( - self, - gm: torch.fx.GraphModule, - example_inputs: Sequence[InputType], - ) -> Any: ... - - -# TODO: bikeshed on this name -class SerializableAOTDispatchCompiler(AOTDispatchCompiler): - """ - Represents an AOTDispatchCompiler that returns an OutputCode, and is - therefore cacheable. SerializableAOTDispatchCompiler always return an OutputCode. - A _CompileFxCallable usually gets converted into an AOTDispatchCompiler after binding all of - the kwargs in _CompileFxKwargs. - """ - - def __init__( - self, - output_code_ty: type[TOutputCode], - compiler_fn: Callable[[torch.fx.GraphModule, Sequence[InputType]], TOutputCode], - ): - self.output_code_ty = output_code_ty - self.compiler_fn = compiler_fn - - def __call__( - self, - gm: torch.fx.GraphModule, - example_inputs: Sequence[InputType], - ) -> OutputCode: - return self.compiler_fn(gm, example_inputs) - - -def process_inputs( - flat_args: list[Any], - aot_config: AOTConfig, - fake_mode: FakeTensorMode, - shape_env: Optional[ShapeEnv], - ignore_shape_env: bool = False, -) -> FakifiedFlatArgs: - with fake_mode: - - def convert(idx, x): - if shape_env is not None and not ignore_shape_env: - from torch._dynamo.source import ConstantSource - - if isinstance(x, int): - # We always specialize on scalar values in export. - if aot_config.is_export: - return x - source = ConstantSource(f"sym_{idx}") - return shape_env.create_symintnode( - shape_env.create_symbol(x, source), hint=x, source=source - ) - if isinstance(x, torch.ScriptObject): - return torch._library.fake_class_registry.maybe_to_fake_obj( - fake_mode, x - ) - if not isinstance(x, torch.Tensor): - return x - if isinstance(x, FakeTensor): - assert x.fake_mode is fake_mode - return x - if is_traceable_wrapper_subclass(x): - attrs, _ = x.__tensor_flatten__() - if all(isinstance(getattr(x, attr), FakeTensor) for attr in attrs): - assert all( - getattr(x, attr).fake_mode is fake_mode for attr in attrs - ) - return x - - # see note [Tensor Fakification and Symbol Caching] - symbolic_context = None - source = None - trace = True - if tracing_context := torch._guards.TracingContext.try_get(): - if x in tracing_context.tensor_to_context: - symbolic_context = tracing_context.tensor_to_context[x] - source = symbolic_context.tensor_source - # We already fakeified this tensor in Dynamo, don't - # dump the trace for it again - trace = False - if ( - idx < aot_config.num_params_buffers - and config.static_weight_shapes - and not symbolic_context - ): - # TODO: Ensure that this codepath is never exercised from - # Dynamo - return fake_mode.from_tensor(x, static_shapes=True) - - result = fake_mode.from_tensor( - x, - static_shapes=ignore_shape_env, - symbolic_context=symbolic_context, - source=source, - trace=trace, - ) - return result - - return FakifiedFlatArgs([convert(idx, x) for idx, x in enumerate(flat_args)]) - - -def construct_fake_mode( - flat_args: list[Any], aot_config: AOTConfig -) -> tuple[FakeTensorMode, Optional[ShapeEnv]]: - fake_mode = detect_fake_mode(flat_args) - if fake_mode is None: - shape_env = ShapeEnv() if aot_config.dynamic_shapes else None - fake_mode = FakeTensorMode(shape_env=shape_env) - else: - shape_env = fake_mode.shape_env - return (fake_mode, shape_env) - def create_aot_state( stack: contextlib.ExitStack, @@ -975,87 +854,6 @@ def forward(self, *args, **kwargs): return AOTModule() -def _try_get_metadata_from_dynamo( - mod: torch.nn.Module, param_keys: KeysView[str], full_args_num: int -) -> tuple[Optional[list[torch._guards.Source]], list[int]]: - """ - Metadata is forwarded from Dynamo to AOTDispatch via special fields on GraphModule. - We first verify that `mod` does come from Dynamo, then we handle cases where - metadata might be missing. - - Returns: - aot_autograd_arg_pos_to_source: used to dedup params and their guards - static_input_indices: used to identify static inputs for cudagraphs - """ - # Note [Assumption on Dynamo Metadata] - # This function assumes a graph module from dynamo provides `dynamo_compiled_id`, - # _param_name_to_source, and every placeholder node has `_dynamo_source` attributes. - # When gm is modified (e.g., DDPOptimizer via split_module), metadata needs to - # be propagated in order to be recognized as a dynamo graph - - if not (isinstance(mod, torch.fx.GraphModule) and "dynamo_compile_id" in mod.meta): - # graph was not captured by dynamo - return None, [] - - if not hasattr(mod, "_param_name_to_source"): - # is from export - return None, [] - - # We now know this came from dynamo, and (1) we care about guards, - # so setting up aot_autograd_arg_pos_to_source for downstream dedup guards - # can now be done safely. (2) Dynamo logic protects the 1:1 sizing below. - # Additionally, we mark static indices for cudagraphs. - param_name_to_source = mod._param_name_to_source - seen_sources = set() - - aot_autograd_arg_pos_to_source = [] - static_input_indices = [] - # Collect the new inputs lifted by aotdispatch - for i, name in enumerate(param_keys): - assert name in param_name_to_source, f"{name} not found." - source = param_name_to_source[name] - assert source not in seen_sources, source - seen_sources.add(source) - aot_autograd_arg_pos_to_source.append(source) - - static_input_indices.append(i) - - # Collect the dynamo graph inputs - # TODO(mlazos): Revisit if this is still needed. With Dynamo install ID - # matched tensors back into the Fx graph, this might not be necessary. - for pos, node in enumerate(mod.graph.find_nodes(op="placeholder")): - assert hasattr(node, "_dynamo_source") - source = node._dynamo_source - # `source`` specifies the source from user code. ddp optimizer may have - # intermediate values becoming submodule placeholders which does not - # have a source - assert source is None or source not in seen_sources, source - seen_sources.add(source) - aot_autograd_arg_pos_to_source.append(source) - source_name = source.name() if source else str(source) - - # input[i] in dynamo is now: - # input[i + len(extra_params)] in AOT, - # where extra_params are the params/buffers that dynamo baked into the - # OutputGraph - actual_pos = pos + len(param_keys) - - if "tensor_dict" in node.meta and node.meta["tensor_dict"].get( - "_dynamo_static_input_type", None - ): - static_inputs_log.debug( - "Adding static input pos %s for source %s", actual_pos, source_name - ) - static_input_indices.append(actual_pos) - else: - static_inputs_log.debug( - "Non-static input pos %s for source %s", actual_pos, source_name - ) - - assert full_args_num == len(aot_autograd_arg_pos_to_source) - return aot_autograd_arg_pos_to_source, static_input_indices - - def aot_module_simplified( mod: nn.Module, args, @@ -1609,105 +1407,5 @@ def _aot_export_function( return fx_g, meta, in_spec, out_spec.spec -@contextmanager -def _detect_attribute_assignment(mod: torch.nn.Module): - # Do not allow assignment of tensor attributes during export unless - # the attribute is registered as a buffer. - - NN_MODULE_STD_ATTRS = [ - "_backward_hooks", - "_backward_pre_hooks", - "_buffers", - "_forward_hooks", - "_forward_hooks_always_called", - "_forward_hooks_with_kwargs", - "_forward_pre_hooks", - "_forward_pre_hooks_with_kwargs", - "_is_full_backward_hook", - "_load_state_dict_post_hooks", - "_load_state_dict_pre_hooks", - "_modules", - "_non_persistent_buffers_set", - "_parameters", - "_state_dict_hooks", - "_state_dict_pre_hooks", - "training", - ] - NN_MODULE_LAZY_STD_ATTRS = [ - "_initialize_hook", - "_load_hook", - ] - STD_ATTRS = { - *NN_MODULE_STD_ATTRS, - *NN_MODULE_LAZY_STD_ATTRS, - } - - def _get_attributes(mod): - # return any attributes of a module that are not standard attributes - return {k: v for k, v in mod.__dict__.items() if k not in STD_ATTRS} - - # save state of attributes before enter - snapshot = pytree.tree_map( - lambda x: x, - _get_attributes(mod), - is_leaf=lambda x: type(x) in _pytree_subclasses_that_lose_info, - ) - try: - yield - finally: - # after exit, compare state of attributes with snapshot - # to detect which tensor attributes were assigned - assigned_tensor_attributes = [] - - def _collect_assigned_tensor_attributes(kp, v, _v): - if _v is not v: - attr, *rest = kp - if isinstance(v, torch.Tensor): - assigned_tensor_attributes.append( - f"self.{attr.key}{pytree.keystr(rest)}" - ) - # TODO(avik): Assigning all other types are allowed right now. - # Maybe in the future we want to limit this to primitive types? - return v - - new_attrs = _get_attributes(mod) - if len(new_attrs) != len(snapshot): - added_attrs = new_attrs.keys() - snapshot.keys() - deleted_attrs = snapshot.keys() - new_attrs.keys() - - if len(added_attrs) > 0: - raise ValueError( - f"During torch.export, following attrs were created in the model.forward: {added_attrs} " - f"Such attributes must be registered as buffers using the `register_buffer` " - f"API and must be initialized at model.__init__ " - f"(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)." - ) - - if len(deleted_attrs) > 0: - raise ValueError( - f"During torch.export, following attrs were deleted in the model.forward: {deleted_attrs} " - f"Such attributes must be registered as buffers using the `register_buffer` " - f"API and must be initialized at model.__init__ " - f"(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)." - ) - - pytree.tree_map_with_path( - _collect_assigned_tensor_attributes, snapshot, new_attrs - ) - # restore state of all attributes (including, e.g., of primitive types) - mod.__dict__.update(snapshot) - - if assigned_tensor_attributes: - if len(assigned_tensor_attributes) > 1: - noun, verb = "attributes", "were" - else: - noun, verb = "attribute", "was" - raise ValueError( - f"The tensor {noun} {', '.join(assigned_tensor_attributes)} {verb} assigned during export. " - "Such attributes must be registered as buffers using the `register_buffer` API " - "(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)." - ) - - compiled_function = aot_function compiled_module = aot_module From e265b719bd67f7c0a2b9001daef442a70232dcc8 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Tue, 15 Jul 2025 12:04:37 -0700 Subject: [PATCH 093/457] Extract out prepare_aot_module_simplified for use in next PR (#158319) Also a small amount of extra code cleanup. Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/158319 Approved by: https://github.com/jingsh ghstack dependencies: #158149, #158150, #158173, #158176, #158213, #158251 --- torch/_functorch/aot_autograd.py | 169 +++++++++++++++++++------------ 1 file changed, 103 insertions(+), 66 deletions(-) diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py index 73eca634495a1..824fa1e0c25c8 100644 --- a/torch/_functorch/aot_autograd.py +++ b/torch/_functorch/aot_autograd.py @@ -854,29 +854,20 @@ def forward(self, *args, **kwargs): return AOTModule() -def aot_module_simplified( +def prepare_aot_module_simplified( mod: nn.Module, args, fw_compiler: AOTDispatchCompiler, - bw_compiler: Optional[AOTDispatchCompiler] = None, - partition_fn: Callable = default_partition, - decompositions: Optional[dict] = None, - keep_inference_input_mutations=False, - inference_compiler: Optional[AOTDispatchCompiler] = None, - cudagraphs: Optional[BoxedBool] = None, - boxed_forward_device_index: Optional[BoxedDeviceIndex] = None, - ignore_shape_env: bool = False, -) -> nn.Module: - """ - This is the simplified or low overhead version of aot_module. For frontends - like TorchDynamo, the input functions/modules to AOT are static and have - unpacked inputs/outputs. This gives us an opportunity to remove the - (1) pytree overhead to parse inputs/outputs, - (2) AOT Autograd cache, - (3) Reading of params/buffers in every forward call - - :func:`aot_module_simplified` removes these overheads. - """ + bw_compiler: AOTDispatchCompiler, + partition_fn: Callable, + decompositions: dict, + keep_inference_input_mutations, + inference_compiler: AOTDispatchCompiler, + boxed_forward_device_index: BoxedDeviceIndex, + ignore_shape_env: bool, +): + # TODO: There's something a bit suspicious here; typically simplified + # module shouldn't actually have any parameters... params = { **dict(mod.named_parameters(remove_duplicate=False)), **dict(mod.named_buffers(remove_duplicate=False)), @@ -885,14 +876,6 @@ def aot_module_simplified( params_flat = list(params_flat) params_len = len(params_flat) - if cudagraphs is None: - cudagraphs = BoxedBool(torch._inductor.config.triton.cudagraphs) - - if bw_compiler is None: - bw_compiler = fw_compiler - if inference_compiler is None: - inference_compiler = fw_compiler - full_args = [] # First, the params full_args.extend(params_flat) @@ -940,31 +923,91 @@ def aot_module_simplified( fake_flat_args = process_inputs( full_args, aot_config, fake_mode, shape_env, ignore_shape_env ) + functional_call = create_functional_call(mod, params_spec, params_len) + + return ( + functional_call, + params_flat, + fake_flat_args, + aot_config, + fake_mode, + shape_env, + ) + + +def aot_module_simplified( + mod: nn.Module, + args, + fw_compiler: AOTDispatchCompiler, + bw_compiler: Optional[AOTDispatchCompiler] = None, + partition_fn: Callable = default_partition, + decompositions: Optional[dict] = None, + keep_inference_input_mutations=False, + inference_compiler: Optional[AOTDispatchCompiler] = None, + # TODO: This doesn't seem to be used in any nontrivial way, check if it's + # actually needed + cudagraphs: Optional[BoxedBool] = None, + boxed_forward_device_index: Optional[BoxedDeviceIndex] = None, + ignore_shape_env: bool = False, +) -> nn.Module: + """ + This is the simplified or low overhead version of aot_module. For frontends + like TorchDynamo, the input functions/modules to AOT are static and have + unpacked inputs/outputs. This gives us an opportunity to remove the + (1) pytree overhead to parse inputs/outputs, + (2) AOT Autograd cache, + (3) Reading of params/buffers in every forward call + + :func:`aot_module_simplified` removes these overheads. + """ + + if cudagraphs is None: + cudagraphs = BoxedBool(torch._inductor.config.triton.cudagraphs) + if bw_compiler is None: + bw_compiler = fw_compiler + if inference_compiler is None: + inference_compiler = fw_compiler with contextlib.ExitStack() as stack: - while True: - # We only care if the forward will return an OutputCode. - if isinstance(fw_compiler, SerializableAOTDispatchCompiler): - local = should_use_local_autograd_cache() - remote = should_use_remote_autograd_cache() - if local or remote: - set_feature_use("aot_autograd_remote_cache", remote) - compiled_fn = AOTAutogradCache.try_load( - mod, - fake_flat_args, - aot_config, - cudagraphs, - boxed_forward_device_index, - local, - remote, - ) - if compiled_fn is not None: - break + ( + functional_call, + params_flat, + fake_flat_args, + aot_config, + fake_mode, + shape_env, + ) = prepare_aot_module_simplified( + mod, + args, + fw_compiler, + bw_compiler, + partition_fn, + decompositions, + keep_inference_input_mutations, + inference_compiler, + boxed_forward_device_index, + ignore_shape_env, + ) - functional_call = create_functional_call(mod, params_spec, params_len) + compiled_fn = None + + if isinstance(fw_compiler, SerializableAOTDispatchCompiler): + local = should_use_local_autograd_cache() + remote = should_use_remote_autograd_cache() + if local or remote: + set_feature_use("aot_autograd_remote_cache", remote) + compiled_fn = AOTAutogradCache.try_load( + mod, + fake_flat_args, + aot_config, + cudagraphs, + boxed_forward_device_index, + local, + remote, + ) + if compiled_fn is None: stack.enter_context(compiled_autograd._disable()) - aot_state = create_aot_state( stack, functional_call, @@ -975,36 +1018,30 @@ def aot_module_simplified( ) aot_graph_capture = aot_stage1_graph_capture(aot_state, functional_call) compiled_fn, _ = aot_stage2_compile(aot_state, aot_graph_capture) - break if isinstance(mod, torch._dynamo.utils.GmWrapper): # This function is called by the flatten_graph_inputs wrapper, which boxes # the inputs so that they can be freed before the end of this scope. # For overhead reasons, this is not the default wrapper, see comment: # https://github.com/pytorch/pytorch/pull/122535/files#r1560096481 - def boxed_forward(runtime_args: list[Any]): + def forward(runtime_args: list[Any]): flat_args = [] flat_args.extend(params_flat) flat_args.extend(runtime_args) runtime_args.clear() return compiled_fn(flat_args) - # Just for convenience - boxed_forward.zero_grad = mod.zero_grad - boxed_forward.named_parameters = mod.named_parameters - boxed_forward.named_buffers = mod.named_buffers - return boxed_forward - - # TODO: There is something deeply wrong here; compiled_fn running with - # the boxed calling convention, but aot_module_simplified somehow - # historically returned a function that was not the boxed calling - # convention. This should get fixed... - # NB: GraphModule/nn.Module rely on the non-boxed calling convention here - def forward(*runtime_args: tuple[Any]): - full_args = [] - full_args.extend(params_flat) - full_args.extend(runtime_args) - return compiled_fn(full_args) + else: + # TODO: There is something deeply wrong here; compiled_fn running with + # the boxed calling convention, but aot_module_simplified somehow + # historically returned a function that was not the boxed calling + # convention. This should get fixed... + # NB: GraphModule/nn.Module rely on the non-boxed calling convention here + def forward(*runtime_args: tuple[Any]): + full_args = [] + full_args.extend(params_flat) + full_args.extend(runtime_args) + return compiled_fn(full_args) # Just for convenience forward.zero_grad = mod.zero_grad From 0a9d450168ce58b2bb7f2cedc27a61012123564f Mon Sep 17 00:00:00 2001 From: Will Constable Date: Tue, 15 Jul 2025 13:53:25 -0700 Subject: [PATCH 094/457] [DTensor] implement histc (#158298) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158298 Approved by: https://github.com/zpcore, https://github.com/XilunWu --- test/distributed/tensor/test_math_ops.py | 36 ++++++++++++++++++++++ torch/distributed/tensor/_ops/_math_ops.py | 24 +++++++++++++++ 2 files changed, 60 insertions(+) diff --git a/test/distributed/tensor/test_math_ops.py b/test/distributed/tensor/test_math_ops.py index 79a7112e0f190..e13e0c0266b8b 100644 --- a/test/distributed/tensor/test_math_ops.py +++ b/test/distributed/tensor/test_math_ops.py @@ -811,6 +811,42 @@ def apply_rotary_emb(xq, freqs_cis): self.assertEqual(dtensor_grad, xq.grad) + @with_comms + def test_histc(self): + # TODO - nicer to use parametrize here so its easy to run one sub-test by name, + # but its too slow (10sec per process-group init) -> switch to MultiProcessContinuousTest + device_mesh = self.build_device_mesh() + comm_mode = CommDebugMode() + tensor = torch.randn(12, 8, 8, requires_grad=True) + for min_max_specified in (True, False): + for placement in [Shard(0), Shard(1), Shard(2), Replicate()]: + min_ = tensor.min().item() + max_ = tensor.max().item() + global_bins = ( + tensor.histc(min=min_, max=max_) + if min_max_specified + else tensor.histc() + ) + + dtensor = distribute_tensor(tensor, device_mesh, (placement,)) + with comm_mode: + out_dt = ( + dtensor.histc(min=min_, max=max_) + if min_max_specified + else dtensor.histc() + ) + + if placement.is_shard() and not min_max_specified: + self.assertEqual(comm_mode.get_total_counts(), 1) + self.assertEqual( + comm_mode.get_comm_counts()[funcol.all_gather_into_tensor], 1 + ) + else: + self.assertEqual(comm_mode.get_total_counts(), 0) + + out_full = out_dt.full_tensor() + self.assertEqual(global_bins, out_full) + if __name__ == "__main__": run_tests() diff --git a/torch/distributed/tensor/_ops/_math_ops.py b/torch/distributed/tensor/_ops/_math_ops.py index 49df9c63a9ead..c1bb96d9c319b 100644 --- a/torch/distributed/tensor/_ops/_math_ops.py +++ b/torch/distributed/tensor/_ops/_math_ops.py @@ -1090,3 +1090,27 @@ def topk_strategy(op_schema: OpSchema) -> OpStrategy: return expand_to_full_mesh_op_strategy( input_strategy.mesh, op_schema, single_mesh_dim_strategies, input_index=2 ) + + +@register_op_strategy( + [aten.histc.default], + # strategy choice depends on the value of 'min' and 'max' kwargs, which are position 2 and 3 + schema_info=RuntimeSchemaInfo(2), +) +def histc_strategy(op_schema: OpSchema) -> OpStrategy: + input_strategy = cast(OpStrategy, op_schema.args_schema[0]) + single_mesh_dim_strategies: list[PlacementList] = [] + single_mesh_dim_strategies.append([Replicate(), Replicate()]) + + # histc can support sharded input and partial output on any input dim, provided the min and max + # values are user-specified. If not user-specified, the true min and max of the data in each local + # tensor will be used to compute bin boundaries, which will not be the same across ranks, leading to + # an incorrect final result + if len(op_schema.args_schema) == 4: + for dim in range(input_strategy.ndim): + dim_shardings: PlacementList = [Partial(), Shard(dim)] + single_mesh_dim_strategies.append(dim_shardings) + + return expand_to_full_mesh_op_strategy( + input_strategy.mesh, op_schema, single_mesh_dim_strategies + ) From e92e3eaf4eb815ea28db9a5af9d9ee48c3f7be3f Mon Sep 17 00:00:00 2001 From: Denghui Dong Date: Wed, 16 Jul 2025 04:10:46 +0000 Subject: [PATCH 095/457] [Profiler] the doc of _ExperimentalConfig is incorrectly truncated by commas (#156586) Hi team, Please help review this trivial fix. Without this change: ``` python >>> import torch >>> print(torch._C._profiler._ExperimentalConfig.__init__.__doc__) __init__(self: torch._C._profiler._ExperimentalConfig, profiler_metrics: list[str] = [], profiler_measure_per_kernel: bool = False, verbose: bool = False, performance_events: list[str] = [], enable_cuda_sync_events: bool = False, adjust_profiler_step: bool = False, disable_external_correlation: bool = False, profile_all_threads: bool = False, capture_overload_names: bool = False) -> None capture_overload_names (bool) : whether to include ATen overload names in the profile ``` With this change: ```python >>> import torch >>> print(torch._C._profiler._ExperimentalConfig.__init__.__doc__) __init__(self: torch._C._profiler._ExperimentalConfig, profiler_metrics: list[str] = [], profiler_measure_per_kernel: bool = False, verbose: bool = False, performance_events: list[str] = [], enable_cuda_sync_events: bool = False, adjust_profiler_step: bool = False, disable_external_correlation: bool = False, profile_all_threads: bool = False, capture_overload_names: bool = False) -> None An experimental config for Kineto features. Please note thatbackward compatibility is not guaranteed. profiler_metrics : a list of CUPTI profiler metrics used to measure GPU performance events. If this list contains values Kineto runs in CUPTI profiler mode profiler_measure_per_kernel (bool) : whether to profile metrics per kernel or for the entire measurement duration. verbose (bool) : whether the trace file has `Call stack` field or not. performance_events : a list of profiler events to be used for measurement. enable_cuda_sync_events : for CUDA profiling mode, enable adding CUDA synchronization events that expose CUDA device, stream and event synchronization activities. This feature is new and currently disabled by default. adjust_profiler_step (bool) : whether to adjust the profiler step to match the parent python event duration. This feature is new and currently disabled by default. disable_external_correlation (bool) : whether to disable external correlation profile_all_threads (bool) : whether to profile all threads capture_overload_names (bool) : whether to include ATen overload names in the profile ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/156586 Approved by: https://github.com/sraikund16, https://github.com/cyyever --- torch/csrc/profiler/python/init.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torch/csrc/profiler/python/init.cpp b/torch/csrc/profiler/python/init.cpp index 92f2f39a5da23..062f87a465ccb 100644 --- a/torch/csrc/profiler/python/init.cpp +++ b/torch/csrc/profiler/python/init.cpp @@ -356,10 +356,10 @@ void initPythonBindings(PyObject* module) { " that expose CUDA device, stream and event synchronization activities. This feature is new\n" " and currently disabled by default.\n" " adjust_profiler_step (bool) : whether to adjust the profiler step to\n" - " match the parent python event duration. This feature is new and currently disabled by default.\n", - " disable_external_correlation (bool) : whether to disable external correlation\n", - " profile_all_threads (bool) : whether to profile all threads\n", - " capture_overload_names (bool) : whether to include ATen overload names in the profile\n", + " match the parent python event duration. This feature is new and currently disabled by default.\n" + " disable_external_correlation (bool) : whether to disable external correlation\n" + " profile_all_threads (bool) : whether to profile all threads\n" + " capture_overload_names (bool) : whether to include ATen overload names in the profile\n" " custom_profiler_config (string) : Used to pass some configurations to the custom profiler backend.\n", py::arg("profiler_metrics") = std::vector(), py::arg("profiler_measure_per_kernel") = false, From 61a7b09ef39907b6c4b47445f965c4c77cc6c2ed Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Tue, 15 Jul 2025 13:05:17 +0800 Subject: [PATCH 096/457] [BE][Easy] split build system `requirements.txt` to a separate file (#158111) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158111 Approved by: https://github.com/ezyang --- .ci/manywheel/build_common.sh | 3 +-- .ci/manywheel/build_libtorch.sh | 3 +-- .ci/pytorch/run_tests.sh | 5 +++-- .ci/wheel/build_wheel.sh | 3 ++- .github/requirements-gha-cache.txt | 3 ++- .github/workflows/lint.yml | 2 ++ Dockerfile | 2 +- README.md | 4 ++-- requirements-build.txt | 10 ++++++++++ requirements.txt | 10 +--------- 10 files changed, 25 insertions(+), 20 deletions(-) create mode 100644 requirements-build.txt diff --git a/.ci/manywheel/build_common.sh b/.ci/manywheel/build_common.sh index 6437cf9f0d488..49549c9f2994e 100644 --- a/.ci/manywheel/build_common.sh +++ b/.ci/manywheel/build_common.sh @@ -97,8 +97,7 @@ if [[ -z "$PYTORCH_ROOT" ]]; then exit 1 fi pushd "$PYTORCH_ROOT" -retry pip install -q "setuptools>=70.1.0" packaging -retry pip install -qU cmake ninja +retry pip install -qUr requirements-build.txt python setup.py clean retry pip install -qr requirements.txt case ${DESIRED_PYTHON} in diff --git a/.ci/manywheel/build_libtorch.sh b/.ci/manywheel/build_libtorch.sh index 30a723cb10958..4de775b1823ca 100644 --- a/.ci/manywheel/build_libtorch.sh +++ b/.ci/manywheel/build_libtorch.sh @@ -92,8 +92,7 @@ if [[ -z "$PYTORCH_ROOT" ]]; then exit 1 fi pushd "$PYTORCH_ROOT" -retry pip install -q "setuptools>=70.1.0" packaging -retry pip install -qU cmake ninja +retry pip install -qUr requirements-build.txt python setup.py clean retry pip install -qr requirements.txt retry pip install -q numpy==2.0.1 diff --git a/.ci/pytorch/run_tests.sh b/.ci/pytorch/run_tests.sh index 34ee40d7bcd0f..f5ed90deef249 100755 --- a/.ci/pytorch/run_tests.sh +++ b/.ci/pytorch/run_tests.sh @@ -74,12 +74,13 @@ else fi # Environment initialization +retry pip install -qUr requirements-build.txt if [[ "$(uname)" == Darwin ]]; then # Install the testing dependencies - retry pip install -q future hypothesis ${NUMPY_PACKAGE} ${PROTOBUF_PACKAGE} pytest setuptools six typing_extensions pyyaml + retry pip install -q future hypothesis ${NUMPY_PACKAGE} ${PROTOBUF_PACKAGE} pytest else retry pip install -qr requirements.txt || true - retry pip install -q hypothesis protobuf pytest setuptools || true + retry pip install -q hypothesis protobuf pytest || true numpy_ver=1.15 case "$(python --version 2>&1)" in *2* | *3.5* | *3.6*) diff --git a/.ci/wheel/build_wheel.sh b/.ci/wheel/build_wheel.sh index 6070e967ef821..878d6595c84c0 100755 --- a/.ci/wheel/build_wheel.sh +++ b/.ci/wheel/build_wheel.sh @@ -184,7 +184,8 @@ tmp_env_name="wheel_py$python_nodot" conda create ${EXTRA_CONDA_INSTALL_FLAGS} -yn "$tmp_env_name" python="$desired_python" ${CONDA_ENV_CREATE_FLAGS} source activate "$tmp_env_name" -pip install "numpy=${NUMPY_PINNED_VERSION}" "pyyaml${PYYAML_PINNED_VERSION}" requests ninja "setuptools${SETUPTOOLS_PINNED_VERSION}" typing_extensions +retry pip install -r "${pytorch_rootdir}/requirements-build.txt" +pip install "numpy=${NUMPY_PINNED_VERSION}" "pyyaml${PYYAML_PINNED_VERSION}" requests ninja "setuptools${SETUPTOOLS_PINNED_VERSION}" typing-extensions retry pip install -r "${pytorch_rootdir}/requirements.txt" || true retry brew install libomp diff --git a/.github/requirements-gha-cache.txt b/.github/requirements-gha-cache.txt index 5e2819c8a8362..5c691e4bf9b31 100644 --- a/.github/requirements-gha-cache.txt +++ b/.github/requirements-gha-cache.txt @@ -1,5 +1,6 @@ # This file is to cache other dependencies not specified elsewhere in: -# requirement.txt +# requirements.txt +# requirements-build.txt # docs/requirements.txt # docs/cpp/requirements.txt # functorch/docs/requirements.txt diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 0fca34048196a..66cd5f653446b 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -260,6 +260,7 @@ jobs: check-latest: false cache: pip cache-dependency-path: | + **/requirements-build.txt **/requirements.txt - name: Setup Min Python version if: matrix.test_type != 'older_python_version' @@ -270,6 +271,7 @@ jobs: check-latest: false cache: pip cache-dependency-path: | + **/requirements-build.txt **/requirements.txt - name: Install torch if: matrix.test_type == 'with_torch' diff --git a/Dockerfile b/Dockerfile index 9f23712af2b86..63b8c5bcb47aa 100644 --- a/Dockerfile +++ b/Dockerfile @@ -33,7 +33,7 @@ RUN case ${TARGETPLATFORM} in \ *) MINICONDA_ARCH=x86_64 ;; \ esac && \ curl -fsSL -v -o ~/miniconda.sh -O "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-${MINICONDA_ARCH}.sh" -COPY requirements.txt . +COPY requirements.txt requirements-build.txt . # Manually invoke bash on miniconda script per https://github.com/conda/conda/issues/10431 RUN chmod +x ~/miniconda.sh && \ bash ~/miniconda.sh -b -p /opt/conda && \ diff --git a/README.md b/README.md index e566f1356d9cc..6d995f130e70b 100644 --- a/README.md +++ b/README.md @@ -294,14 +294,14 @@ Install PyTorch ```bash export CMAKE_PREFIX_PATH="${CONDA_PREFIX:-'$(dirname $(which conda))/../'}:${CMAKE_PREFIX_PATH}" -python -m pip install -r requirements.txt +python -m pip install -r requirements-build.txt python -m pip install --no-build-isolation -v -e . ``` **On macOS** ```bash -python -m pip install -r requirements.txt +python -m pip install -r requirements-build.txt python -m pip install --no-build-isolation -v -e . ``` diff --git a/requirements-build.txt b/requirements-build.txt new file mode 100644 index 0000000000000..be19d987f73db --- /dev/null +++ b/requirements-build.txt @@ -0,0 +1,10 @@ +# Build System requirements +setuptools>=70.1.0,<80.0 # setuptools develop deprecated on 80.0 +cmake>=3.27 +ninja +numpy +packaging +pyyaml +requests +six # dependency chain: NNPACK -> PeachPy -> six +typing-extensions>=4.10.0 diff --git a/requirements.txt b/requirements.txt index 4526f303c046b..2f585def9f19f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,15 +1,7 @@ # Python dependencies required for development # Build System requirements -setuptools>=70.1.0,<80.0 # setuptools develop deprecated on 80.0 -cmake>=3.27 -ninja -numpy -packaging -pyyaml -requests -six # dependency chain: NNPACK -> PeachPy -> six -typing-extensions>=4.10.0 +--requirement requirements-build.txt # Install / Development extra requirements build[uv] # for building sdist and wheel From 5484890539823d9867c74209588abe095c9232a1 Mon Sep 17 00:00:00 2001 From: drisspg Date: Tue, 15 Jul 2025 16:22:32 -0700 Subject: [PATCH 097/457] Add better typing to avaialbe kernel options for flex attention (#158383) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158383 Approved by: https://github.com/joydddd, https://github.com/BoyuanFeng --- torch/nn/attention/flex_attention.py | 138 ++++++++++++++++++++++++++- 1 file changed, 135 insertions(+), 3 deletions(-) diff --git a/torch/nn/attention/flex_attention.py b/torch/nn/attention/flex_attention.py index ce592c1ed342f..160dba68f0ccf 100644 --- a/torch/nn/attention/flex_attention.py +++ b/torch/nn/attention/flex_attention.py @@ -9,10 +9,22 @@ import operator import warnings from enum import Enum -from typing import Any, Callable, Optional, Union +from typing import Callable, Optional, Union import torch from torch import Tensor + + +try: + from typing import TypedDict +except ImportError: + from typing_extensions import TypedDict + +try: + from typing import NotRequired +except ImportError: + from typing_extensions import NotRequired + from torch._higher_order_ops.flex_attention import flex_attention as flex_attention_hop from torch._higher_order_ops.utils import _set_compilation_env from torch._prims_common import DeviceLikeType @@ -27,6 +39,7 @@ __all__ = [ "BlockMask", "flex_attention", + "FlexKernelOptions", "create_block_mask", "create_mask", "create_nested_block_mask", @@ -39,6 +52,123 @@ _mask_mod_signature = Callable[[Tensor, Tensor, Tensor, Tensor], Tensor] +class FlexKernelOptions(TypedDict, total=False): + """Options for controlling the behavior of FlexAttention kernels. + + These options are passed to the underlying Triton kernels to control performance + and numerical behavior. Most users will not need to specify these options as the + default autotuning provides good performance. + + The options can be prefixed with 'fwd_' or 'bwd_' to apply only to forward or + backward pass respectively. For example: 'fwd_BLOCK_M' and 'bwd_BLOCK_M1'. + + Note: + We currently do not provide any backward compatibility guarantees for these options. + That being said most of these have remained pretty stable since their introduction. But + We do not consider this part of the public API just yet. We think that some documentation + Is better than secret hidden flags, but we may change these options in the future. + + Example Usage: + .. code-block:: python + + # Using dictionary (backward compatible) + kernel_opts = {"BLOCK_M": 64, "BLOCK_N": 64, "PRESCALE_QK": True} + output = flex_attention(q, k, v, kernel_options=kernel_opts) + + # Using TypedDict (recommended for type safety) + from torch.nn.attention.flex_attention import FlexKernelOptions + + kernel_opts: FlexKernelOptions = { + "BLOCK_M": 64, + "BLOCK_N": 64, + "PRESCALE_QK": True, + } + output = flex_attention(q, k, v, kernel_options=kernel_opts) + + # Forward/backward specific options + kernel_opts: FlexKernelOptions = { + "fwd_BLOCK_M": 64, + "bwd_BLOCK_M1": 32, + "PRESCALE_QK": False, + } + output = flex_attention(q, k, v, kernel_options=kernel_opts) + """ + + # Performance tuning options + num_warps: NotRequired[int] + """Number of warps to use in the CUDA kernel. Higher values may improve performance + but increase register pressure. Default is determined by autotuning.""" + + num_stages: NotRequired[int] + """Number of pipeline stages in the CUDA kernel. Higher values may improve performance + but increase shared memory usage. Default is determined by autotuning.""" + + BLOCK_M: NotRequired[int] + """Thread block size for the sequence length dimension of Q in forward pass. + Must be a power of 2. Common values: 16, 32, 64, 128. Default is determined by autotuning.""" + + BLOCK_N: NotRequired[int] + """Thread block size for the sequence length dimension of K/V in forward pass. + Must be a power of 2. Common values: 16, 32, 64, 128. Default is determined by autotuning.""" + + # Backward-specific block sizes (when prefixed with 'bwd_') + BLOCK_M1: NotRequired[int] + """Thread block size for Q dimension in backward pass. Use as 'bwd_BLOCK_M1'. + Default is determined by autotuning.""" + + BLOCK_N1: NotRequired[int] + """Thread block size for K/V dimension in backward pass. Use as 'bwd_BLOCK_N1'. + Default is determined by autotuning.""" + + BLOCK_M2: NotRequired[int] + """Thread block size for second Q dimension in backward pass. Use as 'bwd_BLOCK_M2'. + Default is determined by autotuning.""" + + BLOCK_N2: NotRequired[int] + """Thread block size for second K/V dimension in backward pass. Use as 'bwd_BLOCK_N2'. + Default is determined by autotuning.""" + + PRESCALE_QK: NotRequired[bool] + """Whether to pre-scale QK by 1/sqrt(d) and change of base. This is slightly faster but + may have more numerical error. Default: False.""" + + ROWS_GUARANTEED_SAFE: NotRequired[bool] + """If True, guarantees that at least one value in each row is not masked out. + Allows skipping safety checks for better performance. Only set this if you are certain + your mask guarantees this property. For example, causal attention is guaranteed safe + because each query has at least 1 key-value to attend to. Default: False.""" + + BLOCKS_ARE_CONTIGUOUS: NotRequired[bool] + """If True, guarantees that all blocks in the mask are contiguous. + Allows optimizing block traversal. For example, causal masks would satisfy this, + but prefix_lm + sliding window would not. Default: False.""" + + WRITE_DQ: NotRequired[bool] + """Controls whether gradient scatters are done in the DQ iteration loop of the backward pass. + Setting this to False will force this to happen in the DK loop which depending on your + specific score_mod and mask_mod might be faster. Default: True.""" + + FORCE_USE_FLEX_ATTENTION: NotRequired[bool] + """If True, forces the use of the flex attention kernel instead of potentially using + the more optimized flex-decoding kernel for short sequences. This can be a helpful + option for debugging. Default: False.""" + + USE_TMA: NotRequired[bool] + """Whether to use Tensor Memory Accelerator (TMA) on supported hardware. + This is experimental and may not work on all hardware, currently specific + to NVIDIA GPUs Hopper+. Default: False.""" + + # ROCm-specific options + kpack: NotRequired[int] + """ROCm-specific kernel packing parameter.""" + + matrix_instr_nonkdim: NotRequired[int] + """ROCm-specific matrix instruction non-K dimension.""" + + waves_per_eu: NotRequired[int] + """ROCm-specific waves per execution unit.""" + + class _ModificationType(Enum): """Enum for the type of modification function. - SCORE_MOD: score_mod function which accepts a score as the first argument @@ -1244,7 +1374,7 @@ def flex_attention( scale: Optional[float] = None, enable_gqa: bool = False, return_lse: bool = False, - kernel_options: Optional[dict[str, Any]] = None, + kernel_options: Optional[FlexKernelOptions] = None, ) -> Union[Tensor, tuple[Tensor, Tensor]]: r"""This function implements scaled dot product attention with an arbitrary attention score modification function. @@ -1280,7 +1410,9 @@ def score_mod( scale (Optional[float]): Scaling factor applied prior to softmax. If none, the default value is set to :math:`\frac{1}{\sqrt{E}}`. enable_gqa (bool): If set to True, enables Grouped Query Attention (GQA) and broadcasts key/value heads to query heads. return_lse (bool): Whether to return the logsumexp of the attention scores. Default is False. - kernel_options (Optional[Dict[str, Any]]): Options to pass into the Triton kernels. + kernel_options (Optional[FlexKernelOptions]): + Options to control the behavior of the underlying Triton kernels. + See :class:`FlexKernelOptions` for available options and usage examples. Returns: output (Tensor): Attention output; shape :math:`(B, Hq, L, Ev)`. From fedbd1a48e1e474cf9da5637aae89b5bc4c20626 Mon Sep 17 00:00:00 2001 From: saienduri Date: Wed, 16 Jul 2025 06:09:37 +0000 Subject: [PATCH 098/457] Enable ROCm 7.0 Alpha docker builds for PyTorch CI (#158390) This PR adds ROCm 7.0 alpha docker builds to start testing latest ROCm in PyTorch CI and enable new MI350x hardware. Highlights: * Stop building `pytorch-linux-jammy-rocm-n-1-py3` docker images, as they're not currently used in any CI workflows * Add `pytorch-linux-noble-rocm-alpha-py3` docker images that will use ROCm alpha (newer than latest official release) builds Pull Request resolved: https://github.com/pytorch/pytorch/pull/158390 Approved by: https://github.com/jithunnair-amd, https://github.com/jeffdaily --- .ci/docker/build.sh | 21 +++++++++++---------- .ci/docker/common/install_rocm.sh | 13 +++++++++++-- .github/workflows/docker-builds.yml | 2 +- 3 files changed, 23 insertions(+), 13 deletions(-) diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh index 075b5e80209fd..d6cba6659db7a 100755 --- a/.ci/docker/build.sh +++ b/.ci/docker/build.sh @@ -231,11 +231,15 @@ case "$tag" in VISION=yes TRITON=yes ;; - pytorch-linux-jammy-rocm-n-1-py3) - ANACONDA_PYTHON_VERSION=3.10 + pytorch-linux-jammy-rocm-n-py3 | pytorch-linux-noble-rocm-n-py3) + if [[ $tag =~ "jammy" ]]; then + ANACONDA_PYTHON_VERSION=3.10 + else + ANACONDA_PYTHON_VERSION=3.12 + fi GCC_VERSION=11 VISION=yes - ROCM_VERSION=6.3 + ROCM_VERSION=6.4 NINJA_VERSION=1.9.0 TRITON=yes KATEX=yes @@ -243,21 +247,18 @@ case "$tag" in UCC_COMMIT=${_UCC_COMMIT} INDUCTOR_BENCHMARKS=yes ;; - pytorch-linux-jammy-rocm-n-py3 | pytorch-linux-noble-rocm-n-py3) - if [[ $tag =~ "jammy" ]]; then - ANACONDA_PYTHON_VERSION=3.10 - else - ANACONDA_PYTHON_VERSION=3.12 - fi + pytorch-linux-noble-rocm-alpha-py3) + ANACONDA_PYTHON_VERSION=3.12 GCC_VERSION=11 VISION=yes - ROCM_VERSION=6.4 + ROCM_VERSION=7.0 NINJA_VERSION=1.9.0 TRITON=yes KATEX=yes UCX_COMMIT=${_UCX_COMMIT} UCC_COMMIT=${_UCC_COMMIT} INDUCTOR_BENCHMARKS=yes + PYTORCH_ROCM_ARCH="gfx90a;gfx942;gfx950" ;; pytorch-linux-jammy-xpu-2025.0-py3) ANACONDA_PYTHON_VERSION=3.9 diff --git a/.ci/docker/common/install_rocm.sh b/.ci/docker/common/install_rocm.sh index 39a3f0eaf1c42..2b2bb47ea0946 100644 --- a/.ci/docker/common/install_rocm.sh +++ b/.ci/docker/common/install_rocm.sh @@ -33,13 +33,22 @@ EOF ROCM_VERSION="${ROCM_VERSION}.1" fi + # Default url values + rocm_baseurl="http://repo.radeon.com/rocm/apt/${ROCM_VERSION}" + amdgpu_baseurl="https://repo.radeon.com/amdgpu/${ROCM_VERSION}/ubuntu" + + # Special case for ROCM_VERSION == 7.0 + if [[ $(ver "$ROCM_VERSION") -eq $(ver 7.0) ]]; then + rocm_baseurl="https://repo.radeon.com/rocm/apt/7.0_alpha2" + amdgpu_baseurl="https://repo.radeon.com/amdgpu/30.10_alpha2/ubuntu" + fi + # Add amdgpu repository UBUNTU_VERSION_NAME=`cat /etc/os-release | grep UBUNTU_CODENAME | awk -F= '{print $2}'` - echo "deb [arch=amd64] https://repo.radeon.com/amdgpu/${ROCM_VERSION}/ubuntu ${UBUNTU_VERSION_NAME} main" > /etc/apt/sources.list.d/amdgpu.list + echo "deb [arch=amd64] ${amdgpu_baseurl} ${UBUNTU_VERSION_NAME} main" > /etc/apt/sources.list.d/amdgpu.list # Add rocm repository wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - - local rocm_baseurl="http://repo.radeon.com/rocm/apt/${ROCM_VERSION}" echo "deb [arch=amd64] ${rocm_baseurl} ${UBUNTU_VERSION_NAME} main" > /etc/apt/sources.list.d/rocm.list apt-get update --allow-insecure-repositories diff --git a/.github/workflows/docker-builds.yml b/.github/workflows/docker-builds.yml index 43843751eb8fd..4678779443b98 100644 --- a/.github/workflows/docker-builds.yml +++ b/.github/workflows/docker-builds.yml @@ -62,9 +62,9 @@ jobs: pytorch-linux-jammy-py3.11-clang12, pytorch-linux-jammy-py3.12-clang12, pytorch-linux-jammy-py3.13-clang12, - pytorch-linux-jammy-rocm-n-1-py3, pytorch-linux-jammy-rocm-n-py3, pytorch-linux-noble-rocm-n-py3, + pytorch-linux-noble-rocm-alpha-py3, pytorch-linux-jammy-cuda12.8-cudnn9-py3.9-clang12, pytorch-linux-jammy-py3.9-gcc11, pytorch-linux-jammy-py3.9-gcc11-inductor-benchmarks, From 59f9b25f3cfc635053843372ea29ff4bf754da3f Mon Sep 17 00:00:00 2001 From: Kaichao You Date: Wed, 16 Jul 2025 07:12:32 +0000 Subject: [PATCH 099/457] [cuda][cupy] Improve cupy device placement when device is provided (#158320) This is an improvement over https://github.com/pytorch/pytorch/pull/132595 . That PR improves the case where `device` is not given. This PR tries to improve the case where `device` is given but the first step of auto-infer device from `cudaPointerGetAttributes` can be wrong (undesired). See https://github.com/pytorch/pytorch/issues/158316 for more details on when this can happen. I think this is a reasonable improvement, as people expect `torch.as_tensor` + cupy should be zero-copy as much as possible. However, it does change some behaviors, because previously it might incur a device-to-device copy. I will leave it to pytorch developers to see if the improvement is worthwhile. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158320 Approved by: https://github.com/ezyang --- test/distributed/test_cupy_as_tensor.py | 104 ++++++++++++++++++++++++ torch/_torch_docs.py | 3 +- torch/csrc/utils/tensor_new.cpp | 2 +- torch/csrc/utils/tensor_numpy.cpp | 12 ++- torch/csrc/utils/tensor_numpy.h | 4 +- 5 files changed, 119 insertions(+), 6 deletions(-) create mode 100644 test/distributed/test_cupy_as_tensor.py diff --git a/test/distributed/test_cupy_as_tensor.py b/test/distributed/test_cupy_as_tensor.py new file mode 100644 index 0000000000000..e5b13adf32dde --- /dev/null +++ b/test/distributed/test_cupy_as_tensor.py @@ -0,0 +1,104 @@ +# Owner(s): ["oncall: distributed"] + +# To run: +# python test/distributed/test_cupy_as_tensor.py + +import os +from dataclasses import dataclass + +import torch +from torch.multiprocessing.reductions import reduce_tensor +from torch.testing._internal.common_distributed import MultiProcContinousTest +from torch.testing._internal.common_utils import ( + requires_cuda_p2p_access, + run_tests, + skipIfRocm, +) + + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +# So that tests are written in device-agnostic way +device_type = "cuda" +device_module = torch.get_device_module(device_type) + + +@dataclass +class CupyWrapper: + data_ptr: int + size_in_bytes: int + + @property + def __cuda_array_interface__(self): + return { + "shape": (self.size_in_bytes,), + "typestr": "|u1", + "data": (self.data_ptr, False), + "version": 3, + } + + +def from_buffer( + data_ptr: int, size_in_bytes: int, device: str, dtype: torch.dtype +) -> torch.Tensor: + data = torch.as_tensor(CupyWrapper(data_ptr, size_in_bytes), device=device).view( + dtype + ) + assert data.data_ptr() == data_ptr + return data + + +@requires_cuda_p2p_access() +class CupyAsTensorTest(MultiProcContinousTest): + @classmethod + def backend_str(cls): + return "gloo" + + def _init_device(self) -> None: + # init and pin the process to the device + device_module.set_device(self.device) + torch.empty(1, device=self.device) + + @property + def device(self) -> torch.device: + return torch.device(device_type, self.rank) + + @skipIfRocm + def test_cupy_as_tensor(self) -> None: + """ + Test that torch.as_tensor works for cupy array interface + with zero-copy when the pointer is p2p-shared across processes. + """ + self._init_device() + + tensor: torch.Tensor + if self.rank == 1: + # it seems only error from rank non-zero will be caught by this test + tensor = torch.randn(2333, device=self.device) + tensor_meta = reduce_tensor(tensor) + torch.distributed.broadcast_object_list([tensor_meta], src=1) + else: + recv_list = [None] + torch.distributed.broadcast_object_list(recv_list, src=1) + tensor_meta = recv_list[0] + func, args = tensor_meta + args = list(args) + args[6] = self.rank + ipc_tensor = func(*args) + tensor = from_buffer( + ipc_tensor.data_ptr(), + ipc_tensor.numel() * ipc_tensor.element_size(), + self.device, + ipc_tensor.dtype, + ) + + torch.distributed.barrier() + if self.rank == 1: + tensor.fill_(1) + device_module.synchronize() + torch.distributed.barrier() + assert tensor.allclose(tensor, 1) + torch.distributed.barrier() + + +if __name__ == "__main__": + run_tests() diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 0766bf7742864..958b040f7f3ed 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -1006,7 +1006,8 @@ def merge_dicts(*dicts): tensor is constructed using :func:`torch.from_numpy`. If :attr:`data` is a CuPy array, the returned tensor will be located on the same device as the CuPy array unless -specifically overwritten by :attr:`device` or a default device. +specifically overwritten by :attr:`device` or a default device. The device of the CuPy array is inferred from the +pointer of the array using `cudaPointerGetAttributes` unless :attr:`device` is provided. .. seealso:: diff --git a/torch/csrc/utils/tensor_new.cpp b/torch/csrc/utils/tensor_new.cpp index 45f58cde9a659..35511300f703e 100644 --- a/torch/csrc/utils/tensor_new.cpp +++ b/torch/csrc/utils/tensor_new.cpp @@ -304,7 +304,7 @@ Tensor internal_new_from_data( TORCH_CHECK( !pin_memory, "Can't pin tensor constructed from __cuda_array_interface__"); - auto tensor = tensor_from_cuda_array_interface(data); + auto tensor = tensor_from_cuda_array_interface(data, device_opt); const auto& inferred_scalar_type = type_inference ? tensor.scalar_type() : scalar_type; diff --git a/torch/csrc/utils/tensor_numpy.cpp b/torch/csrc/utils/tensor_numpy.cpp index c8548884692fd..2d9651748c315 100644 --- a/torch/csrc/utils/tensor_numpy.cpp +++ b/torch/csrc/utils/tensor_numpy.cpp @@ -27,7 +27,9 @@ bool is_numpy_int(PyObject* obj) { bool is_numpy_scalar(PyObject* obj) { throw std::runtime_error("PyTorch was compiled without NumPy support"); } -at::Tensor tensor_from_cuda_array_interface(PyObject* obj) { +at::Tensor tensor_from_cuda_array_interface( + PyObject* obj, + std::optional device_opt) { throw std::runtime_error("PyTorch was compiled without NumPy support"); } @@ -380,7 +382,9 @@ bool is_numpy_scalar(PyObject* obj) { PyArray_IsScalar(obj, ComplexFloating)); } -at::Tensor tensor_from_cuda_array_interface(PyObject* obj) { +at::Tensor tensor_from_cuda_array_interface( + PyObject* obj, + std::optional device_opt) { if (!is_numpy_available()) { throw std::runtime_error("Numpy is not available"); } @@ -489,7 +493,9 @@ at::Tensor tensor_from_cuda_array_interface(PyObject* obj) { // ref: // https://numba.readthedocs.io/en/stable/cuda/cuda_array_interface.html#cuda-array-interface-version-3 if (data_ptr != nullptr) { - return {}; + // if device_opt is provided and not nullopt, use it, otherwise infer from + // cudaPointerGetAttributes later in from_blob + return device_opt; } else { const auto current_device = at::detail::getCUDAHooks().getCurrentDevice(); return Device( diff --git a/torch/csrc/utils/tensor_numpy.h b/torch/csrc/utils/tensor_numpy.h index a7c1d8cf5476e..5f93cbb089c21 100644 --- a/torch/csrc/utils/tensor_numpy.h +++ b/torch/csrc/utils/tensor_numpy.h @@ -22,7 +22,9 @@ TORCH_API bool is_numpy_bool(PyObject* obj); TORCH_API bool is_numpy_scalar(PyObject* obj); void warn_numpy_not_writeable(); -at::Tensor tensor_from_cuda_array_interface(PyObject* obj); +at::Tensor tensor_from_cuda_array_interface( + PyObject* obj, + std::optional device_opt = std::nullopt); void validate_numpy_for_dlpack_deleter_bug(); bool is_numpy_dlpack_deleter_bugged(); From 555f3562541992b66a550eca8e8740884b1247f8 Mon Sep 17 00:00:00 2001 From: FFFrog Date: Thu, 3 Jul 2025 19:53:08 +0800 Subject: [PATCH 100/457] [Easy] Show some clear error when torch.ops.load_library fails. (#157524) **Background**: ```Shell torch 2.5.1+cpu torchvision 0.20.1 ``` ```Python import torch import torchvision Traceback (most recent call last): File "", line 1, in File "/usr/local/anaconda3/envs/test/lib/python3.10/site-packages/torchvision/__init__.py", line 10, in from torchvision import _meta_registrations, datasets, io, models, ops, transforms, utils # usort:skip File "/usr/local/anaconda3/envs/test/lib/python3.10/site-packages/torchvision/_meta_registrations.py", line 164, in def meta_nms(dets, scores, iou_threshold): File "/usr/local/anaconda3/envs/test/lib/python3.10/site-packages/torch/library.py", line 795, in register use_lib._register_fake(op_name, func, _stacklevel=stacklevel + 1) File "/usr/local/anaconda3/envs/test/lib/python3.10/site-packages/torch/library.py", line 184, in _register_fake handle = entry.fake_impl.register(func_to_register, source) File "/usr/local/anaconda3/envs/test/lib/python3.10/site-packages/torch/_library/fake_impl.py", line 31, in register if torch._C._dispatch_has_kernel_for_dispatch_key(self.qualname, "Meta"): RuntimeError: operator torchvision::nms does not exist ``` **Cause**: ``` torchvision's .so file lacks some symbol definitions, because these symbols come from CUDA, but the current environment does not have CUDA and GPU. The above error message is very confusing. ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/157524 Approved by: https://github.com/ezyang --- torch/_ops.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torch/_ops.py b/torch/_ops.py index 600f6d9e1ada1..eeeb1dfc71130 100644 --- a/torch/_ops.py +++ b/torch/_ops.py @@ -1475,7 +1475,10 @@ def load_library(self, path): # Import the shared library into the process, thus running its # static (global) initialization code in order to register custom # operators with the JIT. - ctypes.CDLL(path) + try: + ctypes.CDLL(path) + except Exception as e: + raise RuntimeError(f"Could not load this library: {path}") from e self.loaded_libraries.add(path) From ddf502c988133835a89959bef945bf9c5f06b428 Mon Sep 17 00:00:00 2001 From: Huamin Li Date: Wed, 16 Jul 2025 07:55:04 +0000 Subject: [PATCH 101/457] [AOTI] add -lstdc++ into aoti link cmd for Meta internal (#158325) Differential Revision: D78123716 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158325 Approved by: https://github.com/desertfire --- torch/_inductor/cpp_builder.py | 36 ++++++++++++++++++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/torch/_inductor/cpp_builder.py b/torch/_inductor/cpp_builder.py index 9d0c7c76dfeac..975045c529ded 100644 --- a/torch/_inductor/cpp_builder.py +++ b/torch/_inductor/cpp_builder.py @@ -1074,6 +1074,19 @@ def _get_openmp_args( return cflags, ldflags, include_dir_paths, lib_dir_paths, libs, passthrough_args +def _get_libstdcxx_args(cpp_compiler: str) -> tuple[list[str], list[str]]: + """ + For fbcode, we should link stdc++ instead assuming the binary where dlopen is executed is built with dynamic stdc++. + """ + lib_dir_paths: list[str] = [] + libs: list[str] = [] + if config.is_fbcode(): + lib_dir_paths = [sysconfig.get_config_var("LIBDIR")] + libs.append("stdc++") + + return lib_dir_paths, libs + + def get_mmap_self_macro(use_mmap_weights: bool) -> list[str]: macros = [] if use_mmap_weights: @@ -1089,6 +1102,15 @@ def get_cpp_torch_options( use_relative_path: bool, use_mmap_weights: bool, ) -> tuple[list[str], list[str], list[str], list[str], list[str], list[str], list[str]]: + """ + This function is used to get the build args of torch related build options. + 1. Torch include_directories, libraries, libraries_directories. + 2. Python include_directories, libraries, libraries_directories. + 3. OpenMP related. + 4. Torch MACROs. + 5. MISC + 6. Return the build args + """ definitions: list[str] = [] include_dirs: list[str] = [] cflags: list[str] = [] @@ -1125,6 +1147,11 @@ def get_cpp_torch_options( omp_passthrough_args, ) = _get_openmp_args(cpp_compiler) + ( + stdcxx_lib_dir_paths, + stdcxx_libs, + ) = _get_libstdcxx_args(cpp_compiler) + fb_macro_passthrough_args = _use_fb_internal_macros() mmap_self_macros = get_mmap_self_macro(use_mmap_weights) @@ -1144,8 +1171,13 @@ def get_cpp_torch_options( ) cflags = sys_libs_cflags + omp_cflags ldflags = omp_ldflags - libraries_dirs = python_libraries_dirs + torch_libraries_dirs + omp_lib_dir_paths - libraries = torch_libraries + omp_lib + libraries_dirs = ( + python_libraries_dirs + + torch_libraries_dirs + + omp_lib_dir_paths + + stdcxx_lib_dir_paths + ) + libraries = torch_libraries + omp_lib + stdcxx_libs passthrough_args = ( sys_libs_passthrough_args + isa_ps_args_build_flags + omp_passthrough_args ) From fb9a5d248f36ddce041025c8fc5be0d8bee454b0 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Wed, 16 Jul 2025 08:11:50 +0000 Subject: [PATCH 102/457] Fix torch._numpy to match NumPy when empty ellipsis causes advanced indexing separation (#158297) Fixes #141563 In NumPy, an ellipsis always acts as a separator between advanced indices, even when the ellipsis doesn't actually match any dimensions. In PyTorch an empty ellipsis doesn't cause a separation. This leads to differing behavior between Numpy and PyTorch in this edge case. This difference in behavior leads to a bug when using torch.compile: ```python >>> import numpy as np >>> f = lambda x: x[:,(0,1),...,(0,1)].shape >>> a = np.ones((3, 4, 5)) >>> f(a) (2, 3) >>> torch.compile(f)(a) (3, 2) ``` Similarly to #157676, this PR doesn't change PyTorch's behavior, but it fixes the translation layer, ensuring torch._numpy compatibility with NumPy. I am marking this PR as fixing #141563, even though PyTorch behavior isn't modified. Notice that there are still some other bugs in PyTorch's advanced indexing, that need to be fixed (mainly regarding proper accounting of dimensions when multidimensional boolean masks are present). But those need to be fixed at the ATen operator level. Examples: - #71673 - #107699 - #158125 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158297 Approved by: https://github.com/soumith --- test/torch_np/test_indexing.py | 71 +++++++++++++++++++++++ torch/_numpy/_ndarray.py | 101 ++++++++++++++++++++++++++++++--- 2 files changed, 165 insertions(+), 7 deletions(-) diff --git a/test/torch_np/test_indexing.py b/test/torch_np/test_indexing.py index eac68246bd5a3..084cc2f73e8b0 100644 --- a/test/torch_np/test_indexing.py +++ b/test/torch_np/test_indexing.py @@ -412,6 +412,77 @@ def test_special_index_types(self): self._test_cases(cases + numpy_torch_cases, "Special index types") + def test_ellipsis(self): + """Tests containing ellipsis.""" + cases = [ + # Ellipsis + Basic indexing + { + "shape": (3, 4, 5), + "index": (slice(None), 0, ..., slice(None)), + "name": "empty ellipsis without advanced indexing", + }, + { + "shape": (3, 4, 5), + "index": (slice(None), ..., 0), + "name": "non-empty ellipsis without advanced indexing", + }, + # Ellipsis + Advanced indexing without separation + { + "shape": (3, 4, 5), + "index": (slice(None), ..., slice(None), (0, 1)), + "name": "empty ellipsis without separation", + }, + { + "shape": (3, 4, 5), + "index": (slice(None), ..., (0, 1)), + "name": "non-empty ellipsis without separation", + }, + # Ellipsis + Advanced indexing with separation + { + "shape": (3, 4, 5), + "index": (slice(None), (0, 1), ..., (0, 1)), + "name": "empty ellipsis separation", + }, + { + "shape": (1, 3, 4, 5), + "index": (slice(None), (0, 1), ..., (0, 1)), + "name": "non-empty ellipsis separation", + }, + { + "shape": (4, 3, 5), + "index": (slice(None), ((0,), (1,)), ..., (0, 1)), + "name": "empty ellipsis separation with 2-depth int sequence", + }, + { + "shape": (4, 3, 5, 6), + "index": (slice(None), ((0,), (1,)), ..., (0, 1), slice(None)), + "name": "empty ellipsis separation with 2-depth int sequence and end slice", + }, + { + "shape": (4, 3, 5, 6), + "index": (slice(None), ((0,), (1,)), ..., (0, 1), (((0, 1), (1, 2)),)), + "name": "empty ellipsis separation with 2 and 3-depth int sequence", + }, + # Ellipsis + Boolean masks in advanced indexing with separation + { + "shape": (3, 4, 5), + "index": (slice(None), True, True, True, ..., 0, 0), + "name": "empty ellipsis separation with 0-dim boolean masks", + }, + { + "shape": (4, 3, 5), + "index": (slice(None), (True, True, False), ..., (0, 1)), + "name": "empty ellipsis separation with 1-dim boolean masks", + }, + # TODO(manuelcandales) Fix issue #71673 and enable this case + # { + # "shape": (1, 2, 2, 4, 5), + # "index": (slice(None), ((True, False), (True, True)), (0, 1, 2), ..., (0,)), + # "name": "empty ellipsis separation with 2-dim boolean masks", + # }, + ] + self._test_cases(cases, "Ellipsis and advanced indexing separation") + if __name__ == "__main__": run_tests() diff --git a/torch/_numpy/_ndarray.py b/torch/_numpy/_ndarray.py index 05e82300145d3..f192a39dd0296 100644 --- a/torch/_numpy/_ndarray.py +++ b/torch/_numpy/_ndarray.py @@ -169,17 +169,22 @@ def _upcast_int_indices(index): return index +def _has_advanced_indexing(index): + """Check if there's any advanced indexing""" + return any( + isinstance(idx, (Sequence, bool)) + or (isinstance(idx, torch.Tensor) and (idx.dtype == torch.bool or idx.ndim > 0)) + for idx in index + ) + + def _numpy_compatible_indexing(index): """Convert scalar indices to lists when advanced indexing is present for NumPy compatibility.""" if not isinstance(index, tuple): index = (index,) # Check if there's any advanced indexing (sequences, booleans, or tensors) - has_advanced = any( - isinstance(idx, (Sequence, bool)) - or (isinstance(idx, torch.Tensor) and (idx.dtype == torch.bool or idx.ndim > 0)) - for idx in index - ) + has_advanced = _has_advanced_indexing(index) if not has_advanced: return index @@ -206,6 +211,84 @@ def _numpy_compatible_indexing(index): return tuple(converted) +def _get_bool_depth(s): + """Returns the depth of a boolean sequence/tensor""" + if isinstance(s, bool): + return True, 0 + if isinstance(s, torch.Tensor) and s.dtype == torch.bool: + return True, s.ndim + if not (isinstance(s, Sequence) and s and s[0] != s): + return False, 0 + is_bool, depth = _get_bool_depth(s[0]) + return is_bool, depth + 1 + + +def _numpy_empty_ellipsis_patch(index, tensor_ndim): + """ + Patch for NumPy-compatible ellipsis behavior when ellipsis doesn't match any dimensions. + + In NumPy, when an ellipsis (...) doesn't actually match any dimensions of the input array, + it still acts as a separator between advanced indices. PyTorch doesn't have this behavior. + + This function detects when we have: + 1. Advanced indexing on both sides of an ellipsis + 2. The ellipsis doesn't actually match any dimensions + """ + if not isinstance(index, tuple): + index = (index,) + + # Find ellipsis position + ellipsis_pos = None + for i, idx in enumerate(index): + if idx is Ellipsis: + ellipsis_pos = i + break + + # If no ellipsis, no patch needed + if ellipsis_pos is None: + return index, lambda x: x, lambda x: x + + # Count non-ellipsis dimensions consumed by the index + consumed_dims = 0 + for idx in index: + is_bool, depth = _get_bool_depth(idx) + if is_bool: + consumed_dims += depth + elif idx is Ellipsis or idx is None: + continue + else: + consumed_dims += 1 + + # Calculate how many dimensions the ellipsis should match + ellipsis_dims = tensor_ndim - consumed_dims + + # Check if ellipsis doesn't match any dimensions + if ellipsis_dims == 0: + # Check if we have advanced indexing on both sides of ellipsis + left_advanced = _has_advanced_indexing(index[:ellipsis_pos]) + right_advanced = _has_advanced_indexing(index[ellipsis_pos + 1 :]) + + if left_advanced and right_advanced: + # This is the case where NumPy and PyTorch differ + # We need to ensure the advanced indices are treated as separated + new_index = index[:ellipsis_pos] + (None,) + index[ellipsis_pos + 1 :] + end_ndims = 1 + sum( + 1 for idx in index[ellipsis_pos + 1 :] if isinstance(idx, slice) + ) + + def squeeze_fn(x): + return x.squeeze(-end_ndims) + + def unsqueeze_fn(x): + if isinstance(x, torch.Tensor) and x.ndim >= end_ndims: + return x.unsqueeze(-end_ndims) + return x + + return new_index, squeeze_fn, unsqueeze_fn + + return index, lambda x: x, lambda x: x + + # Used to indicate that a parameter is unspecified (as opposed to explicitly # `None`) class _Unspecified: @@ -507,19 +590,23 @@ def neg_step(i, s): index = _upcast_int_indices(index) # Apply NumPy-compatible indexing conversion index = _numpy_compatible_indexing(index) - return ndarray(tensor.__getitem__(index)) + # Apply NumPy-compatible empty ellipsis behavior + index, maybe_squeeze, _ = _numpy_empty_ellipsis_patch(index, tensor.ndim) + return maybe_squeeze(ndarray(tensor.__getitem__(index))) def __setitem__(self, index, value): index = _util.ndarrays_to_tensors(index) index = _upcast_int_indices(index) # Apply NumPy-compatible indexing conversion index = _numpy_compatible_indexing(index) + # Apply NumPy-compatible empty ellipsis behavior + index, _, maybe_unsqueeze = _numpy_empty_ellipsis_patch(index, self.tensor.ndim) if not _dtypes_impl.is_scalar(value): value = normalize_array_like(value) value = _util.cast_if_needed(value, self.tensor.dtype) - return self.tensor.__setitem__(index, value) + return self.tensor.__setitem__(index, maybe_unsqueeze(value)) take = _funcs.take put = _funcs.put From e71bb021b9553ddc2db6cb8ea7bf8643552f09fc Mon Sep 17 00:00:00 2001 From: Huy Do Date: Wed, 16 Jul 2025 08:18:18 +0000 Subject: [PATCH 103/457] Add a periodic test for older NVIDIA driver (#158300) This is needed because of the botched landing of https://github.com/pytorch/pytorch/pull/156097 which crashed on older NVIDIA drivers `525.*`. I add a periodic job to install the `525.105.17` on CI, then run: 1. A smoke to make sure that CUDA can be initialized 2. And the whole the test suite on the older driver Pull Request resolved: https://github.com/pytorch/pytorch/pull/158300 Approved by: https://github.com/ngimel --- .ci/pytorch/test.sh | 6 ++++++ .github/workflows/_linux-test.yml | 2 ++ .github/workflows/periodic.yml | 30 ++++++++++++++++++++++++++++++ test/test_cuda.py | 3 +++ 4 files changed, 41 insertions(+) diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 77004a1764850..a51a7e472c974 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -289,6 +289,12 @@ elif [[ $TEST_CONFIG == 'nogpu_AVX512' ]]; then export ATEN_CPU_CAPABILITY=avx2 fi +if [[ "${TEST_CONFIG}" == "legacy_nvidia_driver" ]]; then + # Make sure that CUDA can be initialized + (cd test && python -c "import torch; torch.rand(2, 2, device='cuda')") + export USE_LEGACY_DRIVER=1 +fi + test_python_legacy_jit() { time python test/run_test.py --include test_jit_legacy test_jit_fuser_legacy --verbose assert_git_not_dirty diff --git a/.github/workflows/_linux-test.yml b/.github/workflows/_linux-test.yml index 469367d4d6841..d19a7b51938ef 100644 --- a/.github/workflows/_linux-test.yml +++ b/.github/workflows/_linux-test.yml @@ -164,6 +164,8 @@ jobs: - name: Install nvidia driver, nvidia-docker runtime, set GPU_FLAG id: install-nvidia-driver uses: pytorch/test-infra/.github/actions/setup-nvidia@main + with: + driver-version: ${{ matrix.config == 'legacy_nvidia_driver' && '525.105.17' || '570.133.07' }} if: ${{ contains(inputs.build-environment, 'cuda') && !contains(matrix.config, 'nogpu') && steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'false' && matrix.runner != 'B200' }} - name: Setup GPU_FLAG for docker run diff --git a/.github/workflows/periodic.yml b/.github/workflows/periodic.yml index 7e70f4e21d0db..643d40e4d381b 100644 --- a/.github/workflows/periodic.yml +++ b/.github/workflows/periodic.yml @@ -82,6 +82,36 @@ jobs: test-matrix: ${{ needs.linux-jammy-cuda12_4-py3_10-gcc11-sm89-build.outputs.test-matrix }} secrets: inherit + linux-jammy-cuda12_4-py3_10-gcc11-build: + name: linux-jammy-cuda12.4-py3.10-gcc11 + uses: ./.github/workflows/_linux-build.yml + needs: get-label-type + with: + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build-environment: linux-jammy-cuda12.4-py3.10-gcc11 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.4-cudnn9-py3-gcc11 + test-matrix: | + { include: [ + { config: "legacy_nvidia_driver", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, + { config: "legacy_nvidia_driver", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, + { config: "legacy_nvidia_driver", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, + { config: "legacy_nvidia_driver", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, + { config: "legacy_nvidia_driver", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, + ]} + secrets: inherit + + linux-jammy-cuda12_4-py3_10-gcc11-test: + name: linux-jammy-cuda12.4-py3.10-gcc11 + uses: ./.github/workflows/_linux-test.yml + needs: + - linux-jammy-cuda12_4-py3_10-gcc11-build + - target-determination + with: + build-environment: linux-jammy-cuda12.4-py3.10-gcc11 + docker-image: ${{ needs.linux-jammy-cuda12_4-py3_10-gcc11-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cuda12_4-py3_10-gcc11-build.outputs.test-matrix }} + secrets: inherit + linux-jammy-cuda12_8-py3_10-gcc11-build: name: linux-jammy-cuda12.8-py3.10-gcc11 uses: ./.github/workflows/_linux-build.yml diff --git a/test/test_cuda.py b/test/test_cuda.py index aec3081014618..581c11c85ec10 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -6490,6 +6490,9 @@ def test_cuda_module_loading_env(self): self.assertEqual(val, "LAZY") +@unittest.skipIf( + os.environ.get("USE_LEGACY_DRIVER", None) == "1", "Doesn't work with older driver" +) class TestCompileKernel(TestCase): @unittest.skipIf(TEST_WITH_ROCM, "ROCM does not support nvrtc") @unittest.skipIf(not TEST_CUDA, "No CUDA") From ea74fdd24aa7d98433231f4a3d75cfd241d8720e Mon Sep 17 00:00:00 2001 From: NikhilAPatel Date: Wed, 16 Jul 2025 06:30:42 +0000 Subject: [PATCH 104/457] [Inductor][Triton] Update TMA Compatibility Requirements (#157881) Pull Request resolved: https://github.com/pytorch/pytorch/pull/157881 Approved by: https://github.com/Skylion007, https://github.com/drisspg --- torch/_inductor/utils.py | 76 +++++++++++++++++++++++++++++++++------- 1 file changed, 64 insertions(+), 12 deletions(-) diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 3d427fd7dd044..0b82dfda835a9 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -1543,35 +1543,87 @@ def use_triton_template( ) -def use_triton_tma_template(*matrices: IRNode) -> bool: +def use_triton_tma_template(*matrices: IRNode, add_guards: bool = False) -> bool: + """ + Return True iff *all* supplied tensors satisfy the CUDA-12.9 TMA constraints + that Triton relies on today. + * https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html + + A tensor is accepted when: + * 2 ≤ rank ≤ 5 + * dtype ∈ {FP16, BF16, FP8-E4M3FN} + * Every logical size ≥ 2 + * Base pointer 16-byte aligned + * All "outer" dims have 16-byte aligned strides + * The “inner” dim has stride 1 (contiguous) + * For FP8 tensors, inner dim ≥ 32 + """ from torch.utils._triton import has_triton_tma_device from .virtualized import V + def _aligned(expr_bytes: Union[int, sympy.Expr]) -> bool: + return V.graph.sizevars.statically_known_multiple_of(expr_bytes, TMA_ALIGNMENT) + def _is_tma_compatible(x: IRNode) -> bool: - if len(x.get_size()) != 2: + sizes = x.get_size() + strides = x.get_stride() + rank = len(sizes) + dtype = x.get_dtype() + itemsize = dtype.itemsize + + # 2 ≤ rank ≤ 5 + if rank < 2 or rank > 5: return False - dtype = x.get_dtype() + # dtype ∈ {FP16, BF16, FP8-E4M3FN} if dtype not in (torch.float16, torch.bfloat16, torch.float8_e4m3fn): return False - layout = x.get_layout() - transposed = layout.is_transposed() - if not (layout.is_contiguous() or transposed): + # Base pointer 16-byte aligned + if x.get_name() in V.graph.unaligned_buffers: + return False + + if add_guards: + sizes_i = V.graph.sizevars.guard_int_seq(sizes) + strides_i = V.graph.sizevars.guard_int_seq(strides) + else: + sizes_i = [V.graph.sizevars.symbolic_hint(s) for s in sizes] + strides_i = [V.graph.sizevars.symbolic_hint(st) for st in strides] + + # Every logical size ≥ 2 + if any(not V.graph.sizevars.statically_known_geq(s, 2) for s in sizes_i): return False - inner_dim = layout.size[1] - if transposed: - inner_dim = layout.size[0] + # Find the single contiguous (“inner”) dim + inner = [ + i + for i, st in enumerate(strides_i) + if V.graph.sizevars.statically_known_equals(st, 1) + ] + if len(inner) != 1: + return False + inner_idx = inner[0] - if dtype == torch.float8_e4m3fn and V.graph.sizevars.statically_known_lt( + # All "outer" dims must have 16-byte aligned strides + for i, st in enumerate(strides_i): + if i == inner_idx: + continue + if not _aligned(st * itemsize): + return False + + # Inner dim byte width must still be a multiple of 16 B + inner_dim = sizes_i[inner_idx] + if not _aligned(inner_dim * itemsize): + return False + + # FP8 special case: inner ≥ 32 + if dtype == torch.float8_e4m3fn and not V.graph.sizevars.statically_known_geq( inner_dim, 32 ): return False - inner_bytes = inner_dim * dtype.itemsize - return V.graph.sizevars.statically_known_multiple_of(inner_bytes, TMA_ALIGNMENT) + return True return ( config.triton.enable_persistent_tma_matmul From 9d184bda2f190a3ba72a4a0d95e1fde26d6bfc08 Mon Sep 17 00:00:00 2001 From: Hari Krishna Sai Kodali Date: Wed, 16 Jul 2025 09:37:00 +0000 Subject: [PATCH 105/457] add device generalization support for distributed tests (#156796) MOTIVATION To generalize Distributed test cases for non-CUDA devices CHANGES - test/distributed/checkpoint/test_fsspec.py - test/distributed/checkpoint/test_state_dict.py - test/distributed/test_multi_threaded_pg.py Replaced hard coded device names with torch.accelerator.current_accelerator - torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py support for hccl backend Pull Request resolved: https://github.com/pytorch/pytorch/pull/156796 Approved by: https://github.com/guangyey, https://github.com/ezyang --- test/distributed/checkpoint/test_fsspec.py | 21 +++-- .../distributed/checkpoint/test_state_dict.py | 94 ++++++++++--------- test/distributed/test_multi_threaded_pg.py | 4 +- .../_shard/sharded_tensor/__init__.py | 2 +- 4 files changed, 68 insertions(+), 53 deletions(-) diff --git a/test/distributed/checkpoint/test_fsspec.py b/test/distributed/checkpoint/test_fsspec.py index af061e5b95c93..9d69d6d386a7e 100644 --- a/test/distributed/checkpoint/test_fsspec.py +++ b/test/distributed/checkpoint/test_fsspec.py @@ -18,7 +18,10 @@ from torch.distributed.checkpoint.utils import CheckpointException from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType -from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu +from torch.testing._internal.common_distributed import ( + requires_accelerator_dist_backend, + skip_if_lt_x_gpu, +) from torch.testing._internal.common_utils import run_tests, TestCase from torch.testing._internal.distributed._shard.sharded_tensor import ( ShardedTensorTestBase, @@ -26,6 +29,10 @@ ) +device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu" +BACKEND = torch.distributed.get_default_backend_for_device(device_type) + + def with_temp_dir( func: Optional[Callable] = None, ) -> Optional[Callable]: @@ -75,14 +82,14 @@ class TestFSSpec(ShardedTensorTestBase): def world_size(self) -> int: return 2 - @with_comms(init_rpc=False) + @with_comms(backend=BACKEND, init_rpc=False) + @requires_accelerator_dist_backend() @skip_if_lt_x_gpu(2) - @requires_nccl() @with_temp_dir def test_fsspec(self): CHECKPOINT_DIR = self.temp_dir - model = FSDP(MyTestModule().cuda()) + model = FSDP(MyTestModule().to(device_type)) optim = torch.optim.Adam(model.parameters(), lr=0.1) model(torch.rand(8, 8, device=dist.get_rank())).sum().backward() optim.step() @@ -99,7 +106,7 @@ def test_fsspec(self): planner=dcp.DefaultSavePlanner(), ) - model_2 = FSDP(MyTestModule().cuda()) + model_2 = FSDP(MyTestModule().to(device_type)) optim_2 = torch.optim.Adam(model_2.parameters(), lr=0.1) with FSDP.summon_full_params(model): @@ -149,9 +156,9 @@ def opt_at(opt, idx): opt_at(optim, 0)["exp_avg_sq"], opt_at(optim_2, 0)["exp_avg_sq"] ) - @with_comms(init_rpc=False) + @with_comms(backend=BACKEND, init_rpc=False) + @requires_accelerator_dist_backend() @skip_if_lt_x_gpu(2) - @requires_nccl() @with_temp_dir def test_overwrite(self): t1, t2 = torch.randn(10), torch.randn(10) diff --git a/test/distributed/checkpoint/test_state_dict.py b/test/distributed/checkpoint/test_state_dict.py index 37bb6def9a94e..a42215e0ea0d6 100644 --- a/test/distributed/checkpoint/test_state_dict.py +++ b/test/distributed/checkpoint/test_state_dict.py @@ -62,6 +62,9 @@ from torch.utils._pytree import tree_all, tree_all_only +device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu" + + if not dist.is_available(): print("Distributed not available, skipping tests", file=sys.stderr) sys.exit(0) @@ -79,7 +82,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin): @property def world_size(self) -> int: - return min(4, torch.cuda.device_count()) + return min(4, torch.accelerator.device_count()) def _test_save_load( self, @@ -101,7 +104,7 @@ def _test_save_load( for d_optim in _dist_optim: d_optim.zero_grad() - batch = torch.rand(8, 100, device="cuda") + batch = torch.rand(8, 100, device=device_type) model(batch).sum().backward() dist_model(batch).sum().backward() @@ -188,9 +191,9 @@ def _test_fsdp( def init_model_optim(): if use_dtensor: - device_mesh = init_device_mesh("cuda", (self.world_size,)) + device_mesh = init_device_mesh(device_type, (self.world_size,)) - orig_model = CompositeParamModel(device=torch.device("cuda")) + orig_model = CompositeParamModel(device=torch.device(device_type)) orig_optim = optimizer_class(orig_model.parameters(), lr=1e-4, foreach=True) copy_optim = optimizer_class(orig_model.parameters(), lr=1e-4, foreach=True) if wrapping: @@ -198,7 +201,7 @@ def init_model_optim(): else: strategy = {UnitModule} if use_dtensor: - device_mesh = init_device_mesh("cuda", (self.world_size,)) + device_mesh = init_device_mesh(device_type, (self.world_size,)) dist_model = FSDP( copy.deepcopy(orig_model), auto_wrap_policy=ModuleWrapPolicy(strategy), @@ -258,7 +261,7 @@ def _test_fsdp2( foreach: bool = True, ): def init_model_optim(): - orig_model = CompositeParamModel(device=torch.device("cuda")) + orig_model = CompositeParamModel(device=torch.device(device_type)) orig_optim = optimizer_class( orig_model.parameters(), lr=1e-4, foreach=foreach ) @@ -295,7 +298,7 @@ def test_fsdp2(self) -> None: def _test_ddp(self, use_composable: bool, optimizer_class: type[Optimizer]) -> None: def init_model_optim(): - orig_model = CompositeParamModel(device=torch.device("cuda")) + orig_model = CompositeParamModel(device=torch.device(device_type)) orig_optim = optimizer_class(orig_model.parameters(), lr=1e-4) copy_optim = optimizer_class(orig_model.parameters(), lr=1e-4) if use_composable: @@ -329,7 +332,7 @@ def _test_fsdp_ddp( test_frozen: bool = False, ) -> None: def init_model_optim(): - orig_model = CompositeParamModel(device=torch.device("cuda")) + orig_model = CompositeParamModel(device=torch.device(device_type)) if test_frozen: for param in chain( orig_model.u1.parameters(), orig_model.u2.parameters() @@ -370,7 +373,7 @@ def test_fsdp_ddp(self) -> None: def _test_single_gpu(self, optimizer_class: type[Optimizer]) -> None: def init_model_optim(): - orig_model = CompositeParamModel(device=torch.device("cuda")) + orig_model = CompositeParamModel(device=torch.device(device_type)) orig_optim = optimizer_class(orig_model.parameters(), lr=1e-4) copy_optim = optimizer_class(orig_model.parameters(), lr=1e-4) model_copy = copy.deepcopy(orig_model) @@ -385,7 +388,7 @@ def test_single_gpu(self) -> None: self._test_single_gpu(torch.optim.AdamW) def _test_strict(self, parallelism: str) -> None: - model = CompositeParamModel(device=torch.device("cuda")) + model = CompositeParamModel(device=torch.device(device_type)) if parallelism == "DDP": model = DDP(model) else: @@ -422,8 +425,8 @@ def test_strict(self) -> None: def _test_cpu_offload_full_state_dict( self, optimizer_class: type[Optimizer] ) -> None: - orig_model = CompositeParamModel(device=torch.device("cuda")) - device_mesh = init_device_mesh("cuda", (self.world_size,)) + orig_model = CompositeParamModel(device=torch.device(device_type)) + device_mesh = init_device_mesh(device_type, (self.world_size,)) dist_model = FSDP( copy.deepcopy(orig_model), auto_wrap_policy=ModuleWrapPolicy({UnitModule}), @@ -499,7 +502,7 @@ def test_cpu_offload_full_state_dict(self) -> None: @skip_if_lt_x_gpu(1) def test_activation_ckpt_fqns_ddp(self) -> None: """Tests that activation checkpointing prefixes are removed from module names""" - model = CompositeParamModel(device=torch.device("cuda")) + model = CompositeParamModel(device=torch.device(device_type)) original_keys = get_model_state_dict(model).keys() apply_activation_checkpointing(model) @@ -518,7 +521,7 @@ def test_activation_ckpt_fqns_fsdp1(self) -> None: def _test_activation_ckpt_fqns_fsdp1(self, use_orig_params: bool) -> None: """Tests that activation checkpointing prefixes are removed from module names""" - model = CompositeParamModel(device=torch.device("cuda")) + model = CompositeParamModel(device=torch.device(device_type)) original_keys = get_model_state_dict(model).keys() apply_activation_checkpointing(model) @@ -529,7 +532,7 @@ def _test_activation_ckpt_fqns_fsdp1(self, use_orig_params: bool) -> None: @skip_if_lt_x_gpu(1) def test_extra_state(self) -> None: - model = CompositeParamModel(device=torch.device("cuda")) + model = CompositeParamModel(device=torch.device(device_type)) def get_extra_state(self): return "MyState" @@ -547,21 +550,21 @@ def set_extra_state(self, state): @skip_if_lt_x_gpu(1) def test_non_persistent_buffers(self) -> None: - model = CompositeParamModel(device=torch.device("cuda")) + model = CompositeParamModel(device=torch.device(device_type)) model.register_buffer( - "dont_save_me", torch.rand(100, device="cuda"), persistent=False + "dont_save_me", torch.rand(100, device=device_type), persistent=False ) target_model = copy.deepcopy(model) set_model_state_dict(target_model, get_model_state_dict(target_model)) self.assertEqual(model.state_dict(), get_model_state_dict(target_model)) def _test_broadcast_from_rank0(self, wrapper) -> None: - model = CompositeParamModel(device=torch.device("cuda")) + model = CompositeParamModel(device=torch.device(device_type)) optim = torch.optim.Adam(model.parameters()) fsdp_model = wrapper(copy.deepcopy(model)) fsdp_optim = torch.optim.Adam(fsdp_model.parameters()) - batch = torch.rand(8, 100, device="cuda") + batch = torch.rand(8, 100, device=device_type) model(batch).sum().backward() optim.step() states, optim_states = get_state_dict(model, optim) @@ -631,8 +634,8 @@ def check(equal): @with_comms @skip_if_lt_x_gpu(4) def test_broadcast_from_rank0(self) -> None: - device_mesh = init_device_mesh("cuda", (self.world_size,)) - hsdp_device_mesh = init_device_mesh("cuda", (2, self.world_size // 2)) + device_mesh = init_device_mesh(device_type, (self.world_size,)) + hsdp_device_mesh = init_device_mesh(device_type, (2, self.world_size // 2)) self.run_subtests( { "wrapper": [ @@ -654,8 +657,8 @@ def test_fsdp_root_not_initialized(self) -> None: # This test verifies that FSDP root is not initialized but we should # still be able to get the state_dict without errors because # fsdp_model.state_dict() will trigger the FSDP initialization. - device_mesh = init_device_mesh("cuda", (self.world_size,)) - model = CompositeParamModel(device=torch.device("cuda")) + device_mesh = init_device_mesh(device_type, (self.world_size,)) + model = CompositeParamModel(device=torch.device(device_type)) fsdp_model = FSDP(copy.deepcopy(model), device_mesh=device_mesh) fsdp_optim = torch.optim.Adam(fsdp_model.parameters()) get_model_state_dict(fsdp_model) @@ -668,10 +671,9 @@ def test_optim_state_dict_param_matching(self) -> None: # "initial_lr" is added to optim_state_dict, but not to the new optim # We test whether "initial_lr" appear in optim after # set_optimizer_state_dict. - device = "cuda" torch.manual_seed(0) model = nn.Sequential( - *[nn.Linear(4, 4, device=device, bias=False) for _ in range(2)] + *[nn.Linear(4, 4, device=device_type, bias=False) for _ in range(2)] ) for layer in model: fully_shard(layer) @@ -705,11 +707,11 @@ def test_optim_state_dict_param_matching(self) -> None: @with_comms @skip_if_lt_x_gpu(2) def test_flattened_osd(self) -> None: - device_mesh = init_device_mesh("cuda", (self.world_size,)) - model = CompositeParamModel(device=torch.device("cuda")) + device_mesh = init_device_mesh(device_type, (self.world_size,)) + model = CompositeParamModel(device=torch.device(device_type)) fsdp_model = fully_shard(copy.deepcopy(model), mesh=device_mesh) fsdp_optim = torch.optim.AdamW(fsdp_model.parameters()) - batch = torch.rand(8, 100, device="cuda") + batch = torch.rand(8, 100, device=device_type) fsdp_model(batch).sum().backward() fsdp_optim.step() fsdp_optim.zero_grad() @@ -730,7 +732,7 @@ def test_flattened_osd(self) -> None: self.assertEqual(fsdp_optim.state_dict(), fsdp_optim2.state_dict()) def _test_deprecate_partial(self) -> None: - model = CompositeParamModel(device=torch.device("cuda")) + model = CompositeParamModel(device=torch.device(device_type)) model_state_dict1 = get_model_state_dict(model) model_state_dict1 = copy.deepcopy(model_state_dict1) @@ -783,8 +785,8 @@ def _test_deprecate_partial(self) -> None: self.assertEqual(model.l.bias, model_state_dict1["l.bias"]) def _test_deprecate_fsdp_api(self) -> None: - device_mesh = init_device_mesh("cuda", (self.world_size,)) - model = CompositeParamModel(device=torch.device("cuda")) + device_mesh = init_device_mesh(device_type, (self.world_size,)) + model = CompositeParamModel(device=torch.device(device_type)) fsdp_model = FSDP(copy.deepcopy(model), device_mesh=device_mesh) with self.assertWarnsRegex( FutureWarning, @@ -823,8 +825,8 @@ def forward(self, input): return output def init_model_optim(): - device_mesh = init_device_mesh("cuda", (self.world_size,)) - orig_model = TiedEmbeddingModel(10000, 300).to(torch.device("cuda")) + device_mesh = init_device_mesh(device_type, (self.world_size,)) + orig_model = TiedEmbeddingModel(10000, 300).to(torch.device(device_type)) orig_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-4) copy_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-4) dist_model = FSDP(copy.deepcopy(orig_model), device_mesh=device_mesh) @@ -905,8 +907,12 @@ def test_setting_meta_device_model_broadcasting_and_memory(self) -> None: self.assertEqual(cpu_model_value, meta_model_value) # Memory allocated and reserved are lower due to the change at _distribute_tensors # from view to clone. This test would fail if with view due to higher memory cost. - memory_allocated = torch.cuda.memory_allocated(0) / 1024 / 1024 - memory_reserved = torch.cuda.memory_reserved(0) / 1024 / 1024 + memory_allocated = ( + torch.get_device_module(device_type).memory_allocated(0) / 1024 / 1024 + ) + memory_reserved = ( + torch.get_device_module(device_type).memory_reserved(0) / 1024 / 1024 + ) self.assertTrue(memory_allocated <= 384) self.assertTrue(memory_reserved <= 768) @@ -942,11 +948,11 @@ def test_multi_device_load_model_state_dict(self) -> None: meta_submodel = nn.Linear(4, 4, bias=False) with torch.device("cpu"): cpu_submodel = nn.Linear(4, 4, bias=False) - with torch.device("cuda"): - cuda_submodel = nn.Linear(4, 4, bias=False) + with torch.device(device_type): + acc_submodel = nn.Linear(4, 4, bias=False) - two_device_model_with_meta = nn.Sequential(meta_submodel, cuda_submodel) - two_device_model_without_meta = nn.Sequential(cpu_submodel, cuda_submodel) + two_device_model_with_meta = nn.Sequential(meta_submodel, acc_submodel) + two_device_model_without_meta = nn.Sequential(cpu_submodel, acc_submodel) with torch.device("cpu"): model_to_set = nn.Sequential( @@ -974,7 +980,7 @@ def test_multi_device_load_model_state_dict(self) -> None: def test_state_dict_with_hook_on_keys(self) -> None: with torch.device("meta"): metamodel = FusionEmbedding(4, 4, 4) - with torch.device("cuda"): + with torch.device(device_type): gpumodel = FusionEmbeddingWithHook(4, 4, 4) gpumodel_state_dict = get_model_state_dict(gpumodel) with self.assertRaisesRegex(RuntimeError, "Missing key"): @@ -995,8 +1001,8 @@ def __init__(self): def forward(self, x): return self.fc1(self.fc(x)) - device_mesh = init_device_mesh("cuda", (self.world_size,)) - model = TestModel().cuda() + device_mesh = init_device_mesh(device_type, (self.world_size,)) + model = TestModel().to(device_type) parallelize_module( model, device_mesh, @@ -1014,7 +1020,7 @@ def _test_multi( optim = torch.optim.AdamW(**optim_kwargs) optim.zero_grad() - model(torch.randn(64, 64).cuda()).sum().backward() + model(torch.randn(64, 64, device=device_type)).sum().backward() optim.step() optim.zero_grad() @@ -1067,7 +1073,7 @@ def setUp(self) -> None: @skip_if_lt_x_gpu(1) def test_no_dist(self) -> None: - model = CompositeParamModel(device=torch.device("cuda")) + model = CompositeParamModel(device=torch.device(device_type)) optim = torch.optim.AdamW(model.parameters(), lr=1e-4) self.assertFalse(dist.is_initialized()) diff --git a/test/distributed/test_multi_threaded_pg.py b/test/distributed/test_multi_threaded_pg.py index 196cebb1617c0..7ca6d25ad1c97 100644 --- a/test/distributed/test_multi_threaded_pg.py +++ b/test/distributed/test_multi_threaded_pg.py @@ -25,6 +25,8 @@ from torch.testing._internal.common_utils import IS_SANDCASTLE, run_tests, TestCase +device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu" + DEFAULT_WORLD_SIZE = 4 @@ -330,7 +332,7 @@ def backward(ctx, grad_output): return grad_output * result x = torch.tensor( - [dist.get_rank()], dtype=torch.float, device="cuda", requires_grad=True + [dist.get_rank()], dtype=torch.float, device=device_type, requires_grad=True ) x = MyFunc.apply(x) x.sum().backward() diff --git a/torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py b/torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py index 838c5fd01adfc..60c744ac1a84c 100644 --- a/torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py +++ b/torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py @@ -22,7 +22,7 @@ def world_size(self): return TEST_GPU_NUM def init_pg(self, backend="nccl"): - if backend not in ["nccl", "gloo", "mpi"]: + if backend not in ["nccl", "gloo", "mpi", "hccl"]: raise RuntimeError(f"Backend {backend} not supported!") dist.init_process_group( From ac706bfc7f942b8a97401486a840dd8f6452f5cb Mon Sep 17 00:00:00 2001 From: bobrenjc93 Date: Tue, 15 Jul 2025 23:49:00 -0700 Subject: [PATCH 106/457] disable multi kernel rocm (#158299) Fixes https://github.com/pytorch/pytorch/issues/158274 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158299 Approved by: https://github.com/huydhn --- test/inductor/test_multi_kernel.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/test/inductor/test_multi_kernel.py b/test/inductor/test_multi_kernel.py index 2966f3ac1d91b..f576016cf08c5 100644 --- a/test/inductor/test_multi_kernel.py +++ b/test/inductor/test_multi_kernel.py @@ -16,6 +16,7 @@ from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, + skipIfRocm, skipIfXpu, ) from torch.testing._internal.inductor_utils import ( @@ -98,6 +99,8 @@ def test_softmax(self, expect_multi_kernel=True): self.assertFalse(_contains_multi_kernel_code(wrapper_code)) @requires_triton() + # TODO: bobrenjc93 to fix multi-kernel for ROCM + @skipIfRocm @unittest.skipIf(not IS_BIG_GPU, "templates require big gpu") def test_triton_gemm(self): def fn(x, y): @@ -123,6 +126,8 @@ def fn(x, y): self.assertTrue(_contains_multi_kernel_code(wrapper_code)) @requires_triton() + # TODO: bobrenjc93 to fix multi-kernel for ROCM + @skipIfRocm @unittest.skipIf(not IS_BIG_GPU, "templates require big gpu") def test_triton_relu_fused_gemm(self): def fn(x, y): From 0a99b026d6bd0f67dc2c0a20fe3228ddc4144854 Mon Sep 17 00:00:00 2001 From: Andrey Talman Date: Wed, 16 Jul 2025 10:52:47 +0000 Subject: [PATCH 107/457] [Docker builds] Move from Miniconda to Miniforge (#158370) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This is related to: https://www.anaconda.com/legal/terms/terms-of-service Trying to fix outage with docker builds. https://github.com/pytorch/pytorch/actions/runs/16298993712/job/46033590799 Rocm and XPU builds since they use Miniforge are not affected ``` #22 ERROR: process "/bin/sh -c bash ./install_conda.sh && rm install_conda.sh install_magma_conda.sh common_utils.sh /opt/conda/requirements-ci.txt /opt/conda/requirements-docs.txt" did not complete successfully: exit code: 1 ------ > [base 14/42] RUN bash ./install_conda.sh && rm install_conda.sh install_magma_conda.sh common_utils.sh /opt/conda/requirements-ci.txt /opt/conda/requirements-docs.txt: 11.93 CondaToSNonInteractiveError: Terms of Service have not been accepted for the following channels. Please accept or remove them before proceeding: 11.93 • https://repo.anaconda.com/pkgs/main 11.93 • https://repo.anaconda.com/pkgs/r 11.93 11.93 To accept a channel's Terms of Service, run the following and replace `CHANNEL` with the channel name/URL: 11.93 ‣ conda tos accept --override-channels --channel CHANNEL ``` Hence solution is: 1. using `` conda tos accept --override-channels --channel defaults`` 2. use Miniforge instead of Miniconda. Using solution 2. Solution Tried that don't work: 1. Using ``CONDA_ALWAYS_YES = true `` 4. Using older version of miniconda ``` [Miniconda3-py310_25.5.1-0-Linux-x86_64.sh](https://repo.anaconda.com/miniconda/Miniconda3-py310_25.5.1-0-Linux-x86_64.sh) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/158370 Approved by: https://github.com/seemethere Co-authored-by: Eli Uriegas <1700823+seemethere@users.noreply.github.com> --- .ci/docker/common/install_conda.sh | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/.ci/docker/common/install_conda.sh b/.ci/docker/common/install_conda.sh index 185837b7e98a2..481de54a50f2c 100755 --- a/.ci/docker/common/install_conda.sh +++ b/.ci/docker/common/install_conda.sh @@ -4,12 +4,8 @@ set -ex # Optionally install conda if [ -n "$ANACONDA_PYTHON_VERSION" ]; then - BASE_URL="https://repo.anaconda.com/miniconda" - CONDA_FILE="Miniconda3-latest-Linux-x86_64.sh" - if [[ $(uname -m) == "aarch64" ]] || [[ "$BUILD_ENVIRONMENT" == *xpu* ]] || [[ "$BUILD_ENVIRONMENT" == *rocm* ]]; then - BASE_URL="https://github.com/conda-forge/miniforge/releases/latest/download" # @lint-ignore - CONDA_FILE="Miniforge3-Linux-$(uname -m).sh" - fi + BASE_URL="https://github.com/conda-forge/miniforge/releases/latest/download" # @lint-ignore + CONDA_FILE="Miniforge3-Linux-$(uname -m).sh" MAJOR_PYTHON_VERSION=$(echo "$ANACONDA_PYTHON_VERSION" | cut -d . -f 1) MINOR_PYTHON_VERSION=$(echo "$ANACONDA_PYTHON_VERSION" | cut -d . -f 2) @@ -21,7 +17,6 @@ if [ -n "$ANACONDA_PYTHON_VERSION" ]; then exit 1 ;; esac - mkdir -p /opt/conda chown jenkins:jenkins /opt/conda From 55d888a616be3c94d8e4073b4d1580541692997d Mon Sep 17 00:00:00 2001 From: Raymond Li Date: Wed, 16 Jul 2025 12:31:14 +0000 Subject: [PATCH 108/457] Add framework for explanations for common CUDA errors (#158395) As popularly requested in user groups. Test plan: ``` import torch a = torch.randn(10000) device = torch.device('cuda:1') a = a.to(device) ``` Before: ``` Traceback (most recent call last): File "/data/users/raymo/pytorch/test/cuda.py", line 6, in a = a.to(device) ^^^^^^^^^^^^ torch.AcceleratorError: CUDA error: invalid device ordinal CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect. For debugging consider passing CUDA_LAUNCH_BLOCKING=1 Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions. ``` After: ``` Traceback (most recent call last): File "/data/users/raymo/pytorch/test/cuda.py", line 6, in a = a.to(device) ^^^^^^^^^^^^ torch.AcceleratorError: CUDA error: invalid device ordinal GPU device may be out of range, do you have enough GPUs? CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect. For debugging consider passing CUDA_LAUNCH_BLOCKING=1 Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions. ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/158395 Approved by: https://github.com/aorenste Co-authored-by: Aaron Orenstein --- c10/cuda/CUDAException.cpp | 5 +++-- c10/cuda/CUDAMiscFunctions.cpp | 24 +++++++++++++++++------- c10/cuda/CUDAMiscFunctions.h | 3 ++- 3 files changed, 22 insertions(+), 10 deletions(-) diff --git a/c10/cuda/CUDAException.cpp b/c10/cuda/CUDAException.cpp index 05f00e43a2a7c..40cacff550976 100644 --- a/c10/cuda/CUDAException.cpp +++ b/c10/cuda/CUDAException.cpp @@ -28,8 +28,9 @@ void c10_cuda_check_implementation( std::string check_message; #ifndef STRIP_ERROR_MESSAGES check_message.append("CUDA error: "); - check_message.append(cudaGetErrorString(cuda_error)); - check_message.append(c10::cuda::get_cuda_check_suffix()); + const char* error_string = cudaGetErrorString(cuda_error); + check_message.append(error_string); + check_message.append(c10::cuda::get_cuda_check_suffix(error_string)); check_message.append("\n"); if (include_device_assertions) { check_message.append(c10_retrieve_device_side_assertion_info()); diff --git a/c10/cuda/CUDAMiscFunctions.cpp b/c10/cuda/CUDAMiscFunctions.cpp index cc6519728f1ea..5de9996d2eb76 100644 --- a/c10/cuda/CUDAMiscFunctions.cpp +++ b/c10/cuda/CUDAMiscFunctions.cpp @@ -1,21 +1,31 @@ #include #include +#include +#include +#include namespace c10::cuda { // NOLINTNEXTLINE(bugprone-exception-escape,-warnings-as-errors) -const char* get_cuda_check_suffix() noexcept { +std::string get_cuda_check_suffix(const char* error_string) noexcept { + std::string suffix; + + // Explain common CUDA errors + if (strstr(error_string, "invalid device ordinal")) { + suffix.append("\nGPU device may be out of range, do you have enough GPUs?"); + } + static auto device_blocking_flag = c10::utils::check_env("CUDA_LAUNCH_BLOCKING"); static bool blocking_enabled = (device_blocking_flag.has_value() && device_blocking_flag.value()); - if (blocking_enabled) { - return ""; - } else { - return "\nCUDA kernel errors might be asynchronously reported at some" - " other API call, so the stacktrace below might be incorrect." - "\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1"; + if (!blocking_enabled) { + suffix.append( + "\nCUDA kernel errors might be asynchronously reported at some" + " other API call, so the stacktrace below might be incorrect." + "\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1"); } + return suffix; } std::mutex* getFreeMutex() { static std::mutex cuda_free_mutex; diff --git a/c10/cuda/CUDAMiscFunctions.h b/c10/cuda/CUDAMiscFunctions.h index dc3fced770ba8..c79a22bea231d 100644 --- a/c10/cuda/CUDAMiscFunctions.h +++ b/c10/cuda/CUDAMiscFunctions.h @@ -5,8 +5,9 @@ #include #include +#include namespace c10::cuda { -C10_CUDA_API const char* get_cuda_check_suffix() noexcept; +C10_CUDA_API std::string get_cuda_check_suffix(const char*) noexcept; C10_CUDA_API std::mutex* getFreeMutex(); } // namespace c10::cuda From 51a708ffc679b13f99e4c7cf19bc00082a3266a6 Mon Sep 17 00:00:00 2001 From: dolpm <34420038+dolpm@users.noreply.github.com> Date: Wed, 16 Jul 2025 12:36:51 +0000 Subject: [PATCH 109/457] [nativert] libtorch kernel registry (#157150) Summary: att Test Plan: ci Rollback Plan: Differential Revision: D77451703 Pull Request resolved: https://github.com/pytorch/pytorch/pull/157150 Approved by: https://github.com/georgiaphillips, https://github.com/henryoier --- build_variables.bzl | 4 + c10/core/impl/SizesAndStrides.h | 4 + torch/nativert/kernels/KernelRegistry.cpp | 1380 +++++++++++++++++ torch/nativert/kernels/KernelRegistry.h | 122 ++ torch/nativert/kernels/NativeKernels.cpp | 113 ++ torch/nativert/kernels/PrimKernelRegistry.cpp | 6 +- 6 files changed, 1626 insertions(+), 3 deletions(-) create mode 100644 torch/nativert/kernels/KernelRegistry.cpp create mode 100644 torch/nativert/kernels/KernelRegistry.h create mode 100644 torch/nativert/kernels/NativeKernels.cpp diff --git a/build_variables.bzl b/build_variables.bzl index 99290d5318cdc..d90f3cfafa3e6 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -625,6 +625,10 @@ libtorch_nativert_sources = [ "torch/nativert/executor/memory/AliasAnalyzer.cpp", "torch/nativert/executor/memory/LayoutPlanner.cpp", "torch/nativert/executor/memory/LayoutManager.cpp", + "torch/nativert/kernels/KernelRegistry.cpp", + "torch/nativert/kernels/NativeKernels.cpp", + "torch/nativert/kernels/GeneratedStaticDispatchKernels.cpp", + "torch/nativert/kernels/GeneratedNativeStaticDispatchKernels.cpp", ] torch_mobile_tracer_sources = [ diff --git a/c10/core/impl/SizesAndStrides.h b/c10/core/impl/SizesAndStrides.h index b8a4de1c2d890..6cc87e1d6be3e 100644 --- a/c10/core/impl/SizesAndStrides.h +++ b/c10/core/impl/SizesAndStrides.h @@ -64,6 +64,10 @@ class C10_API SizesAndStrides { storageBytes(size_))); } + bool operator!=(const SizesAndStrides& other) const { + return !(*this == other); + } + SizesAndStrides& operator=(const SizesAndStrides& rhs) { if (this == &rhs) { return *this; diff --git a/torch/nativert/kernels/KernelRegistry.cpp b/torch/nativert/kernels/KernelRegistry.cpp new file mode 100644 index 0000000000000..2632b7886804c --- /dev/null +++ b/torch/nativert/kernels/KernelRegistry.cpp @@ -0,0 +1,1380 @@ +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include + +namespace at::native { + +static void repeat_out( + at::Tensor& result, + const Tensor& self, + IntArrayRef repeats) { + TORCH_CHECK( + repeats.size() >= static_cast(self.dim()), + "Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor"); + + // Add new leading dimensions to the tensor if the + // number of target dimensions is larger than the + // number of source dimensions. + int64_t num_new_dimensions = repeats.size() - self.dim(); + DimVector padded_size(num_new_dimensions, 1); + padded_size.insert( + padded_size.end(), self.sizes().begin(), self.sizes().end()); + DimVector target_size(repeats.size()); + bool zero_tensor = false; + for (const auto idx : c10::irange(repeats.size())) { + if (repeats[idx] == 0) { + zero_tensor = true; + } + target_size[idx] = padded_size[idx] * repeats[idx]; + } + + // return an empty tensor if one of the repeat dimensions is zero + at::native::resize_(result, target_size, std::nullopt); + if (zero_tensor) { + return; + } + + Tensor xtensor = at::compositeexplicitautograd::expand(self, padded_size); + Tensor urtensor = at::native::alias(result); + for (const auto i : c10::irange(xtensor.dim())) { + // can't unfold with step 0, so make sure step is at least 1 + // (it doesn't matter what it is in that case, because the size is 0). + urtensor = urtensor.unfold( + i, xtensor.size(i), std::max(xtensor.size(i), 1)); + } + + at::native::copy_(urtensor, xtensor.expand_as(urtensor)); +} + +static Tensor& c2_argmin_out( + Tensor& output, + const Tensor& input, + const int64_t dim, + const bool keepdim) { + const auto ndim = input.dim(); + int64_t dim_ = maybe_wrap_dim(dim, ndim); + TORCH_CHECK(dim_ >= 0 && dim_ < ndim); + + const auto in_dims = input.sizes(); + + c10::SmallVector out_dims; + out_dims.reserve(ndim); + int prev_size = 1; + int next_size = 1; + for (int i = 0; i < dim_; ++i) { + out_dims.push_back(in_dims[i]); + prev_size *= in_dims[i]; + } + if (keepdim) { + out_dims.push_back(1); + } + for (auto i = dim_ + 1; i < ndim; ++i) { + out_dims.push_back(in_dims[i]); + next_size *= in_dims[i]; + } + at::native::resize_(output, out_dims, std::nullopt); + + const auto n = in_dims[dim_]; + + if (next_size == 1) { + AT_DISPATCH_ALL_TYPES_AND2( + kHalf, kBFloat16, input.scalar_type(), "argmin_input", [&]() { + const auto in_ptr = input.const_data_ptr(); + const auto out_ptr = output.mutable_data_ptr(); + // input is a [prev_size, n] tensor. + // output is a [prev_size,] tensor. + // Thus, access is contiguous/coalesced. + for (int i = 0; i < prev_size; ++i) { + auto v = std::min_element( + in_ptr + i * n, + in_ptr + (i + 1) * n, + [](scalar_t a, scalar_t b) { + // if a is nan, then a is *less* than b with LessOrNan + // semantics + if (at::_isnan(a)) { + return true; + } + // if a is not nan and b is nan, then a is not less than b + // with LessOrNan semantics otherwise, act normally. If `b` is + // NaN then a < b will always return false, so this is + // equivalent to the first snippet. + return a < b; + }); + out_ptr[i] = std::distance(in_ptr + i * n, v); + } + }); + } else { + AT_DISPATCH_ALL_TYPES_AND2( + kHalf, kBFloat16, input.scalar_type(), "argmin_input", [&]() { + const auto less_or_nan = native::detail::LessOrNan{}; + + const auto in_ptr = input.const_data_ptr(); + const auto out_ptr = output.mutable_data_ptr(); + + std::memset(out_ptr, 0, prev_size * next_size * sizeof(int64_t)); + + for (int i = 0; i < prev_size; ++i) { + const scalar_t* cur_in_ptr = in_ptr + i * n * next_size + next_size; + for (int k = 1; k < n; ++k) { + for (int j = 0; j < next_size; ++j) { + int64_t* cur_out_ptr = out_ptr + i * next_size + j; + if (less_or_nan( + *cur_in_ptr, + in_ptr + [i * n * next_size + *cur_out_ptr * next_size + j], + *cur_out_ptr, + k)) { + *cur_out_ptr = k; + } + ++cur_in_ptr; + } + } + } + }); + } + return output; +} + +static Tensor& linear_out( + Tensor& output, + const Tensor& input, + const Tensor& weight, + const std::optional& bias_opt) { + TORCH_CHECK(!input.is_mkldnn()); + + auto bias = bias_opt.has_value() + ? c10::MaybeOwned::borrowed(*bias_opt) + : c10::MaybeOwned::owned(std::in_place); + + if (input.dim() == 2 && bias->defined()) { + // Fused op is marginally faster. + return at::cpu::addmm_out(output, *bias, input, weight.t()); + } + at::native::matmul_out(input, weight.t(), output); + if (bias->defined()) { + at::cpu::add_(output, *bias); + } + return output; +} + +static at::Tensor& mul_out( + at::Tensor& output, + const at::Tensor& self, + const at::Scalar& other) { + const auto& t_output = output.scalar_type(); + TORCH_CHECK(at::native::result_type(self, other) == t_output); + + auto self_sizes = self.sizes(); + at::native::resize_(output, self_sizes, std::nullopt); + + AT_DISPATCH_ALL_TYPES_AND2( + kHalf, kBFloat16, t_output, "mul_Scalar_out", [&]() { + using output_t = scalar_t; + output_t* output_ptr = output.mutable_data_ptr(); + + const int64_t num_elements = self.numel(); + const void* self_ptr = self.data_ptr(); + + at::parallel_for(0, num_elements, 1, [&](int64_t start, int64_t end) { + for (int64_t i = start; i < end; ++i) { + AT_DISPATCH_ALL_TYPES_AND2( + kHalf, kBFloat16, other.type(), "mul_Scalar_other", [&]() { + using other_t = scalar_t; + + output_t other_casted = static_cast( + reinterpret_cast(other.data_ptr())[0]); + + AT_DISPATCH_ALL_TYPES_AND2( + kHalf, + kBFloat16, + self.scalar_type(), + "mul_Scalar_self", + [&]() { + using self_t = scalar_t; + + output_ptr[i] = + other_casted * + static_cast( + reinterpret_cast(self_ptr)[i]); + }); + }); + } + }); + }); + + return output; +} + +} // namespace at::native + +namespace torch::nativert { + +C10_DEFINE_REGISTRY( + StaticallyDispatchedCPUKernelRegistry, + OpKernel, + const Node*, + c10::Device); + +namespace { + +// device & pin_memory matter only when CUDA is enabled. +static bool hasTensorWithOptions( + const c10::IValue& ivalue, + std::optional dtype, + std::optional layout) { + if (!ivalue.isTensor()) { + return false; + } + const auto& tensor = ivalue.toTensor(); + if (dtype == tensor.dtype().toScalarType() && + layout == tensor.options().layout_opt()) { + return true; + } + VLOG(1) << "tensor exists, but tensor options were different"; + return false; +} + +static bool hasTensorWithOptions( + const c10::IValue& ivalue, + std::optional dtype, + std::optional layout, + std::optional memory_format) { + return hasTensorWithOptions(ivalue, dtype, layout) && + (memory_format == ivalue.toTensor().options().memory_format_opt()); +} + +c10::MaybeOwned borrow_from_optional_tensor_ivalue( + const c10::IValue& iv) { + if (iv.isNone()) { + return c10::MaybeOwned::owned(std::in_place); + } + return c10::MaybeOwned::borrowed(iv.toTensor()); +} + +} // namespace + +REGISTER_CPU_KERNEL("torch.ops.aten.remainder.Tensor", aten_remainder_Tensor, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::remainder(self, KernelInput(1).toTensor()); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::remainder_out(out, self, KernelInput(1).toTensor()); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.remainder.Scalar", aten_remainder_Scalar, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::remainder(self, KernelInput(1).toScalar()); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::remainder_out(self, KernelInput(1).toScalar(), out); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.matmul.default", aten_matmul, { + const auto& in0_t = KernelInput(0).toTensor(); + const auto& in1_t = KernelInput(1).toTensor(); + + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::matmul(in0_t, in1_t); + return; + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + at::native::matmul_out(in0_t, in1_t, out_t); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.bmm.default", aten_bmm, { + const auto& in0_t = KernelInput(0).toTensor(); + const auto& in1_t = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = create_empty_from(in0_t); + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + at::cpu::bmm_out(out_t, in0_t, in1_t); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.abs.default", aten_abs, { + const auto& in0_t = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::abs(in0_t); + return; + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + at::native::abs_out(in0_t, out_t); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.mul.Tensor", aten_mul, { + const auto& in0_t = KernelInput(0).toTensor(); + const auto& in1_t = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::mul(in0_t, in1_t); + return; + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + at::cpu::mul_out(out_t, in0_t, in1_t); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.mul.Scalar", aten_mul_Scalar, { + const auto& in0_t = KernelInput(0).toTensor(); + const auto& in1_t = KernelInput(1).toScalar(); + auto dtype = at::native::result_type(in0_t, in1_t); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = create_empty_from(in0_t, dtype); + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + KernelOutput(0) = at::native::mul_out(out_t, in0_t, in1_t); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.nan_to_num.default", aten_nan_to_num, { + const auto& in0_t = KernelInput(0).toTensor(); + const auto in1_d = KernelInput(1).toOptional(); + const auto in2_d = KernelInput(2).toOptional(); + const auto in3_d = KernelInput(3).toOptional(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::nan_to_num(in0_t, in1_d, in2_d, in3_d); + return; + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + at::native::nan_to_num_out(in0_t, in1_d, in2_d, in3_d, out_t); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.leaky_relu.default", aten_leaky_relu, { + const auto& in0_t = KernelInput(0).toTensor(); + const auto in1_s = KernelInput(1).toScalar(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::leaky_relu(in0_t, in1_s); + return; + } + auto& out_t = KernelOutput(0).toTensor(); + at::cpu::leaky_relu_out(out_t, in0_t, in1_s); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.relu.default", aten_relu, { + const auto& in0_t = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = create_empty_from(in0_t); + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + at::cpu::threshold_out(out_t, in0_t, 0, 0); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.clone.default", aten_clone, { + const auto& src = KernelInput(0).toTensor(); + const auto& optional_memory_format = + KernelInput(1).toOptional(); + auto memory_format = + optional_memory_format.value_or(c10::MemoryFormat::Preserve); + /* + disable out_variant of clone for case with stride = 0 and + memory formats other than preserve. Perform dynamic allocation + instead of memory reuse for simpler implementation. We could, + in principle, figure out copy of strides. + */ + if ((at::has_internal_overlap(src.unsafeGetTensorImpl()) == + at::MemOverlap::Yes) || + (memory_format != c10::MemoryFormat::Preserve)) { + KernelOutput(0) = at::native::clone(src, memory_format); + return; + } + if (KernelOutput(0).isNone()) { + if (src.is_non_overlapping_and_dense()) { + // Copy all strides + KernelOutput(0) = + at::empty_strided(src.sizes(), src.strides(), src.options()); + } else { + memory_format = src.suggest_memory_format(); + KernelOutput(0) = create_empty_from(src, memory_format); + } + } + auto& out_t = KernelOutput(0).toTensor(); + at::native::resize_impl_cpu_( + out_t.unsafeGetTensorImpl(), src.sizes(), src.strides()); + at::native::copy_(out_t, src, false); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.index.Tensor", aten_index, { + const auto& in0_t = KernelInput(0).toTensor(); + const auto in1_l = + at::native::toListOfOptionalTensors(KernelInput(1).toListRef()); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::index(in0_t, in1_l); + return; + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + at::cpu::index_out(out_t, in0_t, in1_l); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.index_select.default", aten_index_select, { + const auto& self = KernelInput(0).toTensor(); + const auto dim = KernelInput(1).toInt(); + const auto& index = KernelInput(2).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::index_select_cpu_(self, dim, index); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::index_select_out_cpu_(self, dim, index, out); +}); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.pow.Tensor_Tensor", + aten_pow_Tensor_Tensor, + { + if (KernelOutput(0).isNone()) { + const auto& in0_t = KernelInput(0).toTensor(); + auto dtype = at::native::result_type(in0_t, KernelInput(1).toTensor()); + KernelOutput(0) = create_empty_from(in0_t, dtype); + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + at::cpu::pow_out( + out_t, KernelInput(0).toTensor(), KernelInput(1).toTensor()); + }); + +REGISTER_CPU_KERNEL("torch.ops.aten.pow.Scalar", aten_pow_Scalar, { + if (KernelOutput(0).isNone()) { + const auto& in1_t = KernelInput(1).toTensor(); + auto dtype = at::native::result_type(KernelInput(0).toScalar(), in1_t); + KernelOutput(0) = at::native::empty_like( + in1_t, + dtype, + in1_t.options().layout_opt(), + in1_t.options().device_opt(), + in1_t.options().pinned_memory_opt(), + at::MemoryFormat::Preserve); + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + at::cpu::pow_out(out_t, KernelInput(0).toScalar(), KernelInput(1).toTensor()); +}); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.pow.Tensor_Scalar", + aten_pow_Tensor_Scalar, + { + if (KernelOutput(0).isNone()) { + const auto& in0_t = KernelInput(0).toTensor(); + auto dtype = at::native::result_type(in0_t, KernelInput(1).toScalar()); + KernelOutput(0) = at::native::empty_like( + in0_t, + dtype, + in0_t.options().layout_opt(), + in0_t.options().device_opt(), + in0_t.options().pinned_memory_opt(), + at::MemoryFormat::Preserve); + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + at::cpu::pow_out( + out_t, KernelInput(0).toTensor(), KernelInput(1).toScalar()); + }); + +REGISTER_CPU_KERNEL("torch.ops.aten.sum.default", aten_sum_default, { + // if (n->inputs().size() != 2 && n->inputs().size() != 4) { + // return nullptr; + // } + const at::Tensor& self = KernelInput(0).toTensor(); + auto dtype = KernelInput(1).toOptional(); + std::vector dim = {}; + bool keepdim = false; + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::sum(self, dim, keepdim, dtype); + } else { + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::sum_out(out, self, dim, keepdim, dtype); + } +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.sum.dim_IntList", aten_sum_dim_IntList, { + // if (n->inputs().size() != 2 && n->inputs().size() != 4) { + // return nullptr; + // } + const at::Tensor& self = KernelInput(0).toTensor(); + auto dim = KernelInput(1).toDimVector(); + auto keepdim = KernelInput(2).toBool(); + auto dtype = KernelInput(3).toOptional(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::sum(self, dim, keepdim, dtype); + } else { + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::sum_out(out, self, dim, keepdim, dtype); + } +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.mean.dim", aten_mean_dim, { + const auto& self = KernelInput(0).toTensor(); + const auto dim = KernelInput(1).toDimVector(); + const bool keepdim = KernelInput(2).toBool(); + const auto dtype = KernelInput(3).toOptional(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = + create_empty_from(self, dtype.value_or(self.dtype().toScalarType())); + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::mean_out(out, self, dim, keepdim, dtype); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.mean.default", aten_mean_default, { + const auto& self = KernelInput(0).toTensor(); + const auto dtype = KernelInput(1).toOptional(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = + create_empty_from(self, dtype.value_or(self.dtype().toScalarType())); + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::mean_out(out, self, /*dim=*/{}, /*keepdim=*/false, dtype); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.max.other", aten_max_other, { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::max(self, other); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::max_out(self, other, out); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.max.default", aten_max_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = create_empty_from(self); + } + auto& value = KernelOutput(0).toTensor(); + fastResizeToZero(value); + at::cpu::amax_out(value, self); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.sign.Tensor", aten_sign_Tensor, { + const auto& in0_t = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::sign(in0_t); + return; + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + at::cpu::sign_out(out_t, in0_t); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.log.default", aten_log, { + const auto& in0_t = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::log(in0_t); + return; + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + at::cpu::log_out(out_t, in0_t); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.sub.Tensor", aten_sub_Tensor, { + const auto& in0_t = KernelInput(0).toTensor(); + const auto& in1_t = KernelInput(1).toTensor(); + const auto alpha = KernelInput(2).toScalar(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::sub(in0_t, in1_t, alpha); + return; + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + at::cpu::sub_out(out_t, in0_t, in1_t, alpha); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.sub.Scalar", aten_sub, { + const auto& in0_t = KernelInput(0).toTensor(); + const auto& in1_t = + at::native::wrapped_scalar_tensor(KernelInput(1).toScalar()); + const auto alpha = KernelInput(2).toScalar(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::sub(in0_t, in1_t, alpha); + return; + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + at::cpu::sub_out(out_t, in0_t, in1_t, alpha); +}); + +// TODO: support clamp_min.Tensor(Tensor self, Tensor min) -> Tensor +// Missing Test Coverage +REGISTER_CPU_KERNEL( + "torch.ops.aten.clamp_min.default", + aten_clamp_min_default, + { + const auto& in0_t = KernelInput(0).toTensor(); + const auto in1_s = KernelInput(1).toScalar(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::clamp_min(in0_t, in1_s); + return; + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + at::cpu::clamp_min_out(out_t, in0_t, in1_s); + }); + +REGISTER_CPU_KERNEL("torch.ops.aten.argmin.default", aten_argmin, { + const auto& in0_t = KernelInput(0).toTensor(); + const auto dim = KernelInput(1).toOptional(); + const auto keepdim = KernelInput(2).toBool(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::argmin(in0_t, dim, keepdim); + return; + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + if (in0_t.is_contiguous() && dim.has_value()) { + at::native::c2_argmin_out(out_t, in0_t, dim.value(), keepdim); + return; + } + at::cpu::argmin_out(out_t, in0_t, dim, keepdim); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.softmax.int", aten_softmax_int, { + const auto& in_t = KernelInput(0).toTensor(); + const auto& dim = KernelInput(1).toInt(); + const auto& dtype = KernelInput(2).toOptional(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::softmax(in_t, dim, dtype); + return; + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + auto half_to_float = in_t.scalar_type() == at::ScalarType::Half && + dtype == at::ScalarType::Float; + at::cpu::_softmax_out(out_t, in_t, dim, half_to_float); +}); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.norm.ScalarOpt_dtype", + aten_norm_ScalarOpt_dtype, + { + const auto& in0_t = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = create_empty_from(in0_t); + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + const auto in1_s = KernelInput(1).toOptional(); + at::cpu::norm_outf( + in0_t, + in1_s, + c10::IntArrayRef{}, + false, + KernelInput(2).toScalarType(), + out_t); + }); + +REGISTER_CPU_KERNEL("torch.ops.aten.full.default", aten_full, { + const auto& size = KernelInput(0).toDimVector(); + const auto fill_value = KernelInput(1).toScalar(); + const auto dtype = KernelInput(2).toOptional(); + const auto layout = KernelInput(3).toOptional(); + if (!hasTensorWithOptions(KernelOutput(0), dtype, layout)) { + const auto device = KernelInput(4).toOptional(); + const auto pin_memory = KernelInput(5).toOptional(); + KernelOutput(0) = + at::native::full(size, fill_value, dtype, layout, device, pin_memory); + return; + } + KernelOutput(0) = + at::native::full_out(size, fill_value, KernelOutput(0).toTensor()); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.ones.default", aten_ones, { + const auto size = KernelInput(0).toDimVector(); + if (KernelOutput(0).isNone()) { + const auto dtype = KernelInput(1).toOptional(); + const auto layout = KernelInput(2).toOptional(); + const auto device = KernelInput(3).toOptional(); + const auto pin_memory = KernelInput(4).toOptional(); + KernelOutput(0) = at::native::ones(size, dtype, layout, device, pin_memory); + return; + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + at::native::ones_out(size, out_t); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.ones_like.default", aten_ones_like, { + const auto& self = KernelInput(0).toTensor(); + const auto dtype = KernelInput(1).toOptional(); + const auto layout = KernelInput(2).toOptional(); + const auto device = KernelInput(3).toOptional(); + const auto pin_memory = KernelInput(4).toOptional(); + const auto memory_format = KernelInput(5).toOptional(); + if (!hasTensorWithOptions(KernelOutput(0), dtype, layout, memory_format)) { + KernelOutput(0) = at::native::ones_like( + self, dtype, layout, device, pin_memory, memory_format); + return; + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + at::native::ones_out(self.sizes(), out_t); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.zeros.default", aten_zeros, { + const auto size = KernelInput(0).toDimVector(); + const auto dtype = KernelInput(1).toOptional(); + const auto layout = KernelInput(2).toOptional(); + if (!hasTensorWithOptions(KernelOutput(0), dtype, layout)) { + KernelOutput(0) = at::compositeexplicitautograd::zeros( + size, dtype, layout, std::nullopt, std::nullopt); + return; + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + at::compositeexplicitautograd::zeros_out(out_t, size); +}); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.linalg_norm.default", + aten_linalg_norm_default, + { + const auto& self = KernelInput(0).toTensor(); + const auto dim = KernelInput(2).toDimVector(); + const auto keepdim = KernelInput(3).toBool(); + const auto dtype = KernelInput(4).toOptional(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::linalg_norm( + self, KernelInput(1).toOptional(), dim, keepdim, dtype); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::linalg_norm_out( + self, + KernelInput(1).toOptional(), + dim, + keepdim, + dtype, + out); + }); + +REGISTER_CPU_KERNEL("torch.ops.aten.linalg_norm.ord_str", aten_linalg_norm, { + const auto& self = KernelInput(0).toTensor(); + const auto dim = KernelInput(2).toDimVector(); + const auto keepdim = KernelInput(3).toBool(); + const auto dtype = KernelInput(4).toOptional(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::linalg_norm( + self, KernelInput(1).toStringView(), dim, keepdim, dtype); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::linalg_norm_out( + self, KernelInput(1).toStringRef(), dim, keepdim, dtype, out); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.cat.default", aten_cat, { + const auto inputs = KernelInput(0).toTensorVector(); + TORCH_CHECK(!inputs.empty(), "concat expects non-empty tensor list"); + const auto dim = KernelInput(1).toInt(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::cat(inputs, dim); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::cat_outf(inputs, dim, out); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.cumsum.default", aten_cumsum, { + const auto& self = KernelInput(0).toTensor(); + const auto dim = KernelInput(1).toInt(); + const auto dtype = KernelInput(2).toOptional(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::cumsum(self, dim, dtype); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::cumsum_out(out, self, dim, dtype); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.nonzero.default", aten_nonzero, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::nonzero_cpu(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::nonzero_out_cpu(self, out); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.addmm.default", aten_addmm, { + const auto& in0_t = KernelInput(0).toTensor(); + const auto& in1_t = KernelInput(1).toTensor(); + const auto& in2_t = KernelInput(2).toTensor(); + const auto in3_s = KernelInput(3).toScalar(); + const auto in4_s = KernelInput(4).toScalar(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::addmm(in0_t, in1_t, in2_t, in3_s, in4_s); + return; + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + at::cpu::addmm_out(out_t, in0_t, in1_t, in2_t, in3_s, in4_s); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.narrow_copy.default", aten_narrow_copy, { + const auto& self = KernelInput(0).toTensor(); // self + const auto dim = KernelInput(1).toInt(); // dim + int64_t start = 0; + if (KernelInput(2).isScalar()) { + start = KernelInput(2).toInt(); + } else { + auto& t = KernelInput(2).toTensor(); + start = t.item(); + } + auto length = KernelInput(3).toInt(); // length + + if (KernelOutput(0).isNone()) { + KernelOutput(0) = + at::native::narrow_copy_dense_cpu(self, dim, start, length); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::narrow_copy_dense_cpu_out(self, dim, start, length, out); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.repeat.default", aten_repeat, { + const auto& self = KernelInput(0).toTensor(); + const auto repeats = KernelInput(1).toDimVector(); + + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::repeat(self, repeats); + return; + } + at::Tensor& out = KernelOutput(0).toTensor(); + at::native::repeat_out(out, self, repeats); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.max.dim", aten_max_dim, { + const auto& self = KernelInput(0).toTensor(); + auto dim = KernelInput(1).toInt(); + const auto keepdim = KernelInput(2).toBool(); + + if (KernelOutput(0).isNone()) { + KernelOutput(0) = create_empty_from(self); + } + + if (KernelOutput(1).isNone()) { + KernelOutput(1) = create_empty_from(self, at::kLong); + } + + auto& values = KernelOutput(0).toTensor(); + auto& indices = KernelOutput(1).toTensor(); + fastResizeToZero(values); + fastResizeToZero(indices); + at::cpu::max_out(values, indices, self, dim, keepdim); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.layer_norm.default", aten_layer_norm, { + // ignore KernelInput(5): `bool cudnn_enable=True` + const auto& input_t = KernelInput(0).toTensor(); + const auto normalized_shape = KernelInput(1).toDimVector(); + float eps = KernelInput(4).toDouble(); + + c10::MaybeOwned weight_maybe_owned = + borrow_from_optional_tensor_ivalue(KernelInput(2)); + const at::Tensor& weight = *weight_maybe_owned; + c10::MaybeOwned bias_maybe_owned = + borrow_from_optional_tensor_ivalue(KernelInput(3)); + const at::Tensor& bias = *bias_maybe_owned; + + auto M_N = at::native::_check_layer_norm_inputs( + input_t, normalized_shape, weight, bias); + auto M = M_N.first; + auto N = M_N.second; + auto X = input_t.expect_contiguous(); + auto gamma = weight.expect_contiguous(); + auto beta = bias.expect_contiguous(); + + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::empty_like( + *X, + std::nullopt /* dtype */, + std::nullopt /* layout */, + std::nullopt /* device */, + std::nullopt /* pin_memory */, + at::MemoryFormat::Contiguous); + } else { + at::native::resize_(KernelOutput(0).toTensor(), X->sizes(), std::nullopt); + } + at::Tensor& out = KernelOutput(0).toTensor(); + at::native::layer_norm_cpu_out(out, *X, *gamma, *beta, eps, M, N); +}); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.norm.ScalarOpt_dim_dtype", + aten_norm_ScalarOpt_dim_dtype, + { + const auto& in0_t = KernelInput(0).toTensor(); + + if (KernelOutput(0).isNone()) { + KernelOutput(0) = create_empty_from(in0_t); + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + + const auto in1_s = KernelInput(1).toOptional(); + at::cpu::norm_outf( + in0_t, + in1_s, + KernelInput(2).toDimVector(), // dim + KernelInput(3).toBool(), // keepdim + KernelInput(4).toScalarType(), // dtype + out_t); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.norm.ScalarOpt_dim", + aten_norm_ScalarOpt_dim, + { + const auto& in0_t = KernelInput(0).toTensor(); + + if (KernelOutput(0).isNone()) { + KernelOutput(0) = create_empty_from(in0_t); + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + + const auto in1_s = KernelInput(1).toOptional(); + at::cpu::norm_outf( + in0_t, + in1_s, + KernelInput(2).toDimVector(), // dim + KernelInput(3).toBool(), // keepdim + out_t); + }); + +REGISTER_CPU_KERNEL("torch.ops.aten.full_like.default", aten_full_like, { + const auto in1_s = KernelInput(1).toScalar(); + const auto& in0_t = KernelInput(0).toTensor(); + const auto dtype = KernelInput(2).toOptional(); + const auto layout = KernelInput(3).toOptional(); + if (!hasTensorWithOptions(KernelOutput(0), dtype, layout)) { + const auto device = KernelInput(4).toOptional(); + const auto pin_memory = KernelInput(5).toOptional(); + const auto memory_format = KernelInput(6).toOptional(); + + KernelOutput(0) = at::native::empty_like( + in0_t, dtype, layout, device, pin_memory, memory_format); + } + auto& out_t = KernelOutput(0).toTensor(); + at::native::resize_(out_t, in0_t.sizes(), std::nullopt); + at::native::fill_out(out_t, in1_s); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.linear.default", aten_linear, { + const auto& in0_t = KernelInput(0).toTensor(); + const auto& in1_t = KernelInput(1).toTensor(); + auto in2_t = KernelInput(2).toOptional(); + + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::linear(in0_t, in1_t, in2_t); + return; + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + at::native::linear_out(out_t, in0_t, in1_t, in2_t); +}); + +REGISTER_CPU_KERNEL("torch.ops.aten.where.self", aten_where, { + const auto& cond = KernelInput(0).toTensor(); + const auto& self = KernelInput(1).toTensor(); + const auto& other = KernelInput(2).toTensor(); + + if (KernelOutput(0).isNone()) { + KernelOutput(0) = create_empty_from(self); + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::where_self_out(cond, self, other, out); +}); + +REGISTER_CPU_KERNEL( + "torch.ops.quantized.embedding_bag_byte_rowwise_offsets.default", + quantized_embedding_bag_byte_rowwise_offsets, + { + const auto& weight = KernelInput(0).toTensor(); + const auto& indices = KernelInput(1).toTensor(); + const auto offsets = KernelInput(2).toOptional(); + const auto pruned_weights = KernelInput(5).toBool(); + const auto per_sample_weights = KernelInput(6).toOptional(); + const auto compressed_indices_mapping = + KernelInput(7).toOptional(); + const auto include_last_offset = KernelInput(8).toBool(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = create_empty_from(weight, at::kFloat); + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + at::native::embedding_bag_byte_rowwise_offsets_out( + out_t, + weight, + indices, + offsets, + false, // unused scale_grad_by_freq + 0, // unused mode + pruned_weights, + per_sample_weights, + compressed_indices_mapping, + include_last_offset); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.quantized.embedding_bag_4bit_rowwise_offsets.default", + quantized_embedding_bag_4bit_rowwise_offsets, + { + const auto& weight = KernelInput(0).toTensor(); + const auto& indices = KernelInput(1).toTensor(); + const auto offsets = KernelInput(2).toOptional(); + const auto pruned_weights = KernelInput(5).toBool(); + const auto per_sample_weights = KernelInput(6).toOptional(); + const auto compressed_indices_mapping = + KernelInput(7).toOptional(); + const auto include_last_offset = KernelInput(8).toBool(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = create_empty_from(weight, at::kFloat); + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + at::native::embedding_bag_4bit_rowwise_offsets_out( + out_t, + weight, + indices, + offsets, + false, // unused scale_grad_by_freq + 0, // unused mode + pruned_weights, + per_sample_weights, + compressed_indices_mapping, + include_last_offset); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.quantized.linear_dynamic_fp16.default", + quantized_linear_dynamic_fp16, + { + const auto& in_0 = KernelInput(0).toTensor(); + + if (auto& out_0 = KernelOutput(0); out_0.isNone()) { + out_0 = create_empty_from(in_0, at::kFloat); + } + + auto& out_0 = KernelOutput(0).toTensor(); + fastResizeToZero(out_0); + + KernelInput(1).toCustomClass()->apply_dynamic_out( + in_0, out_0, /* reduce_range= */ false); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.quantized.linear_relu_dynamic_fp16.default", + quantized_linear_relu_dynamic_fp16, + { + const auto& in_0 = KernelInput(0).toTensor(); + + if (auto& out_0 = KernelOutput(0); out_0.isNone()) { + out_0 = create_empty_from(in_0, at::kFloat); + } + + auto& out_0 = KernelOutput(0).toTensor(); + fastResizeToZero(out_0); + + KernelInput(1) + .toCustomClass() + ->apply_dynamic_out(in_0, out_0, /* reduce_range= */ false) + .relu_(); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.quantized.linear.default", + quantized_linear_default, + { + const auto& in_0 = KernelInput(0).toTensor(); + const auto w_prepack = + KernelInput(1).toCustomClass(); + const auto output_scale = KernelInput(2).toDouble(); + const auto output_zero_point = KernelInput(3).toInt(); + if (auto& out_t = KernelOutput(0); out_t.isNone()) { + out_t = at::native::empty_affine_quantized( + {0}, + c10::kQUInt8, + std::nullopt, + c10::kCPU, + false, + output_scale, + output_zero_point, + std::nullopt); + } + auto& out_tensor = KernelOutput(0).toTensor(); + fastResizeToZero(out_tensor); + w_prepack->apply_out(in_0, output_scale, output_zero_point, out_tensor); + }); + +REGISTER_CPU_KERNEL("torch.ops.aten.logit.default", aten_logit, { + const auto& in0_t = KernelInput(0).toTensor(); + const auto& in1_d = KernelInput(1).toOptional(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = create_empty_from(in0_t); + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + at::native::logit_out(in0_t, in1_d, out_t); +}); + +REGISTER_CPU_KERNEL( + "torch.ops.aten.slice_scatter.default", + aten_slice_scatter, + { + const auto& self = KernelInput(0).toTensor(); + const auto& src = KernelInput(1).toTensor(); + const int64_t dim = KernelInput(2).toInt(); + const auto& start = KernelInput(3).toOptional(); + const auto& end = KernelInput(4).toOptional(); + int64_t step = KernelInput(5).toInt(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = create_empty_from(self); + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::slice_scatter_out(out, self, src, dim, start, end, step); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.quantized.embedding_bag_byte_unpack.default", + quantized_embedding_bag_byte_unpack_default, + { + const auto& weight = KernelInput(0).toTensor(); + if (auto& out = KernelOutput(0); out.isNone()) { + out = at::empty( + {}, + weight.options().dtype(at::kFloat), + weight.suggest_memory_format()); + } + auto& out_tensor = KernelOutput(0).toTensor(); + fastResizeToZero(out_tensor); + at::native::qembeddingbag_byte_unpack_out(out_tensor, weight); + }); + +REGISTER_CPU_KERNEL( + "torch.ops.quantized.embedding_bag_byte_prepack.default", + embedding_bag_byte_prepack_default, + { + const auto& weight = KernelInput(0).toTensor(); + if (auto& out_t = KernelOutput(0); out_t.isNone()) { + KernelOutput(0) = at::native::qembeddingbag_byte_prepack(weight); + return; + } + auto& out_tensor = KernelOutput(0).toTensor(); + fastResizeToZero(out_tensor); + at::native::qembeddingbag_byte_prepack_out(out_tensor, weight); + }); + +REGISTER_CPU_KERNEL("torch.ops.aten.stack.default", aten_stack, { + const auto& inputs = KernelInput(0).toTensorVector(); + const auto dim = KernelInput(1).toInt(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::_stack_cpu(inputs, dim); + return; + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + at::native::_stack_out_cpu(inputs, dim, out_t); +}); + +class OpKernel_aten__to_copy : public C10Kernel { + public: + explicit OpKernel_aten__to_copy(const Node* node, c10::Device device) + : C10Kernel( + node, + device, + torch::nativert::OpKernelKind::kStaticDispatchKernel, + torch::nativert::AliasingSpec{ + {/* input_idx = */ 0, /* output_idx = */ 0}}) { + dtype_ = attribute(1).toOptional(); + layout_ = attribute(2).toOptional(); + device_ = attribute(3).toOptional(); + pin_memory_ = attribute(4).toOptional(); + non_blocking_ = attribute(5).toBool(); + memory_format_ = attribute(6).toOptional(); + + has_memory_format_ = memory_format_.has_value(); + + if (memory_format_.has_value()) { + TORCH_CHECK( + memory_format_.value() != c10::MemoryFormat::ChannelsLast && + memory_format_.value() != c10::MemoryFormat::ChannelsLast3d, + "Static Kernel for aten._to_copy doesn't correctly handle the ChannelsLast(3d) memory format. If you are running into this error, please report to nativert oncall."); + } + + if (device_.has_value()) { + TORCH_CHECK( + device_.value().is_cpu(), + "Static kernel for aten._to_copy only supports CPU device, but got ", + device_.value()); + } + } + + void computeInternal(ExecutionFrame& executionFrame) const override final { + const auto& self = KernelInput(0).toTensor(); + auto& out = KernelOutput(0); + + // skip if the _to_copy is a no-op + if (dtype_.has_value() && self.dtype() == dtype_.value() && + !has_memory_format_ && !device_.has_value() && !layout_.has_value()) { + if (out.isNone()) { + out = at::native::alias(self); + return; + } + + auto* in_t = self.unsafeGetTensorImpl(); + auto* out_t = out.toTensor().unsafeGetTensorImpl(); + + // it's possible that the input storage has been updated + if (!out_t->storage().is_alias_of(in_t->storage())) { + out_t->set_storage_keep_dtype(in_t->storage()); + } + + // in case in was re-sized/strided from the prev. impl + // we need to make sure the metadata is consistent between + // in_t and out_t + + if (in_t->storage_offset() != out_t->storage_offset()) { + out_t->set_storage_offset(in_t->storage_offset()); + } + + if (in_t->sizes_and_strides() != out_t->sizes_and_strides()) { + out_t->set_sizes_and_strides(self.sizes(), self.strides()); + } + + return; + } + + std::optional memory_format = + c10::MemoryFormat::Preserve; + if (has_memory_format_) { + memory_format = memory_format_.value_or(c10::MemoryFormat::Preserve); + } + + bool copy_strides = false; + if (memory_format == c10::MemoryFormat::Preserve) { + if (self.is_non_overlapping_and_dense()) { + memory_format = std::nullopt; + copy_strides = true; + } else { + memory_format = self.suggest_memory_format(); + } + } + + bool need_to_allocate_output = true; + if (out.isTensor()) { + const auto& existing_output = out.toTensor(); + if ((has_memory_format_ && + !existing_output.is_contiguous( + memory_format.value_or(c10::MemoryFormat::Contiguous)))) { + need_to_allocate_output = true; + } else { + need_to_allocate_output = false; + } + } + + // See Note [Explicit nullopt MemoryFormat argument] + // Can't use size {0} if memory_format is ChannelLast + if (need_to_allocate_output) { + out = at::detail::empty_cpu( + self.sizes(), + dtype_.value_or(self.scalar_type()), + layout_, + device_, + std::nullopt, + memory_format); + } else { + if (has_memory_format_) { + memory_format = memory_format_.value_or(c10::MemoryFormat::Preserve); + } else { + memory_format = c10::MemoryFormat::Preserve; + } + } + + copy_strides = copy_strides || + (memory_format == c10::MemoryFormat::Preserve && + self.is_non_overlapping_and_dense()); + + auto& out_t = out.toTensor(); + fastResizeToZero(out_t); + at::native::to_copy_out( + out_t, self, non_blocking_, copy_strides, memory_format); + } + + private: + std::optional dtype_; + std::optional layout_; + std::optional device_; + std::optional pin_memory_; + bool non_blocking_ = false; + std::optional memory_format_; + bool has_memory_format_; +}; + +C10_REGISTER_TYPED_CLASS( + StaticallyDispatchedCPUKernelRegistry, + "torch.ops.aten._to_copy.default", + OpKernel_aten__to_copy) + +} // namespace torch::nativert diff --git a/torch/nativert/kernels/KernelRegistry.h b/torch/nativert/kernels/KernelRegistry.h new file mode 100644 index 0000000000000..03293871fef29 --- /dev/null +++ b/torch/nativert/kernels/KernelRegistry.h @@ -0,0 +1,122 @@ +#pragma once + +#include +#include +#include + +namespace torch::nativert { + +TORCH_DECLARE_REGISTRY( + StaticallyDispatchedCPUKernelRegistry, + OpKernel, + const Node*, + c10::Device); + +#define REGISTER_CPU_KERNEL(name, id, ...) \ + class OpKernel_##id : public C10Kernel { \ + public: \ + OpKernel_##id(const Node* node, c10::Device device) \ + : C10Kernel( \ + node, \ + device, \ + torch::nativert::OpKernelKind::kStaticDispatchKernel) {} \ + void computeInternal(torch::nativert::ExecutionFrame& executionFrame) \ + const override final { \ + __VA_ARGS__; \ + } \ + }; \ + C10_REGISTER_TYPED_CLASS( \ + StaticallyDispatchedCPUKernelRegistry, name, OpKernel_##id) + +#define ALIASING_SPEC(...) __VA_ARGS__ + +#define REGISTER_ALIASING_CPU_KERNEL(name, id, aliasing_spec, ...) \ + class OpKernel_##id : public C10Kernel { \ + public: \ + OpKernel_##id(const Node* node, c10::Device device) \ + : C10Kernel( \ + node, \ + device, \ + torch::nativert::OpKernelKind::kNativeStaticDispatchKernel, \ + aliasing_spec) {} \ + void computeInternal(torch::nativert::ExecutionFrame& executionFrame) \ + const override final { \ + __VA_ARGS__; \ + } \ + }; \ + C10_REGISTER_TYPED_CLASS( \ + StaticallyDispatchedCPUKernelRegistry, name, OpKernel_##id) + +#define REGISTER_NATIVE_CPU_KERNEL(name, id, ...) \ + class OpKernel_##id : public C10Kernel { \ + public: \ + OpKernel_##id(const Node* node, c10::Device device) \ + : C10Kernel( \ + node, \ + device, \ + torch::nativert::OpKernelKind::kNativeStaticDispatchKernel) {} \ + void computeInternal(torch::nativert::ExecutionFrame& executionFrame) \ + const override final { \ + __VA_ARGS__; \ + } \ + }; \ + C10_REGISTER_TYPED_CLASS( \ + StaticallyDispatchedCPUKernelRegistry, name, OpKernel_##id) + +inline at::Tensor create_empty_from(const at::Tensor& t) { + return at::detail::empty_cpu( + {0}, + c10::typeMetaToScalarType(t.dtype()), + t.layout(), + t.device(), + std::nullopt, + std::nullopt); +} + +inline at::Tensor create_empty_from( + const at::Tensor& t, + c10::ScalarType dtype) { + return at::detail::empty_cpu( + {0}, dtype, t.layout(), t.device(), std::nullopt, std::nullopt); +} + +inline at::Tensor create_empty_from(const at::Tensor& t, c10::Device device) { + return at::detail::empty_cpu( + {0}, + c10::typeMetaToScalarType(t.dtype()), + t.layout(), + device, + std::nullopt, + std::nullopt); +} +inline at::Tensor create_empty_from(const at::Tensor& t, c10::Layout layout) { + return at::detail::empty_cpu( + {0}, + c10::typeMetaToScalarType(t.dtype()), + layout, + t.device(), + std::nullopt, + std::nullopt); +} + +inline at::Tensor create_empty_from( + const at::Tensor& t, + c10::MemoryFormat memory_format) { + return at::detail::empty_cpu( + {0}, + c10::typeMetaToScalarType(t.dtype()), + t.layout(), + t.device(), + std::nullopt, + memory_format); +} + +inline at::Tensor create_empty_from( + const at::Tensor& t, + c10::ScalarType dtype, + c10::MemoryFormat memory_format) { + return at::detail::empty_cpu( + {0}, dtype, t.layout(), t.device(), std::nullopt, memory_format); +} + +} // namespace torch::nativert diff --git a/torch/nativert/kernels/NativeKernels.cpp b/torch/nativert/kernels/NativeKernels.cpp new file mode 100644 index 0000000000000..1f847863070ac --- /dev/null +++ b/torch/nativert/kernels/NativeKernels.cpp @@ -0,0 +1,113 @@ +#include + +#include +#include +#include + +namespace torch::nativert { + +REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.slice.Tensor", aten_slice_Tensor, { + const auto& self = KernelInput(0).toTensor(); + const auto& dim = KernelInput(1).toInt(); + const auto& start = KernelInput(2).toOptional(); + const auto& end = KernelInput(3).toOptional(); + const auto& step = KernelInput(4).toInt(); + KernelOutput(0) = at::native::slice(self, dim, start, end, step); +}); + +REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.sym_size.int", aten_sym_size_int, { + const auto& self = KernelInput(0).toTensor(); + const auto dim = KernelInput(1).toInt(); + auto& out = KernelOutput(0); + TORCH_CHECK(dim >= 0 && dim < self.dim(), "Invalid dimension"); + out = self.sym_size(dim); +}); + +REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.reshape.default", aten_reshape, { + const auto& self = KernelInput(0).toTensor(); + const auto& shape = KernelInput(1).toIntVector(); + KernelOutput(0) = at::native::reshape(self, shape); +}); + +REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.view.default", aten_view, { + const auto& self = KernelInput(0).toTensor(); + const auto& size = KernelInput(1).toIntVector(); + KernelOutput(0) = at::native::view(self, size); +}); + +REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.permute.default", aten_permute, { + const auto& self = KernelInput(0).toTensor(); + const auto& dims = KernelInput(1).toDimVector(); + KernelOutput(0) = at::native::permute(self, dims); +}); + +REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.select.int", aten_select, { + const auto& self = KernelInput(0).toTensor(); + const auto dim = KernelInput(1).toInt(); + const auto index = KernelInput(2).toInt(); + KernelOutput(0) = at::native::select(self, dim, index); +}); + +REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.split.Tensor", aten_split_Tensor, { + const auto& self = KernelInput(0).toTensor(); + const auto split_size = KernelInput(1).toInt(); + const auto dim = KernelInput(2).toInt(); + KernelOutput(0) = at::native::split(self, split_size, dim); +}); + +REGISTER_NATIVE_CPU_KERNEL( + "torch.ops.aten.split_with_sizes.default", + aten_split_with_sizes, + { + const auto& self = KernelInput(0).toTensor(); + const auto& split_sizes = KernelInput(1).toIntList(); + const auto dim = KernelInput(2).toInt(); + KernelOutput(0) = + at::native::split_with_sizes(self, split_sizes.vec(), dim); + }); + +REGISTER_NATIVE_CPU_KERNEL( + "torch.ops.aten.tensor_split.sections", + aten_tensor_split_sections, + { + const auto& self = KernelInput(0).toTensor(); + const auto sections = KernelInput(1).toInt(); + const auto dim = KernelInput(2).toInt(); + KernelOutput(0) = + at::native::tensor_split_sections_symint(self, sections, dim); + }); + +REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.item.default", aten_item, { + const auto& self = KernelInput(0).toTensor(); + KernelOutput(0) = at::native::item(self); +}); + +REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.narrow.default", aten_narrow, { + const auto& self = KernelInput(0).toTensor(); + const auto dim = KernelInput(1).toInt(); + int64_t start = 0; + if (KernelInput(2).isScalar()) { + start = KernelInput(2).toInt(); + } else { + auto& t = KernelInput(2).toTensor(); + start = t.item(); + } + const auto length = KernelInput(3).toInt(); + TORCH_CHECK(self.dim() > 0, "narrow() cannot be applied to a 0-dim tensor."); + auto cur_size = self.sizes()[dim]; + if (start != cur_size && start < 0) { + start = at::maybe_wrap_dim(start, cur_size); + } + TORCH_CHECK( + length >= 0 && start <= cur_size - length, + "start (", + start, + ") + length (", + length, + ") exceeds dimension size (", + cur_size, + ")."); + KernelOutput(0) = at::native::slice(self, dim, start, start + length, 1); +}); + +} // namespace torch::nativert diff --git a/torch/nativert/kernels/PrimKernelRegistry.cpp b/torch/nativert/kernels/PrimKernelRegistry.cpp index e6f69634a71b8..80421bae77597 100644 --- a/torch/nativert/kernels/PrimKernelRegistry.cpp +++ b/torch/nativert/kernels/PrimKernelRegistry.cpp @@ -57,7 +57,7 @@ class OpKernel_prim_listpack : public OpKernel { C10_REGISTER_TYPED_CLASS( PrimKernelRegistry, "prim.ListPack", - OpKernel_prim_listpack); + OpKernel_prim_listpack) REGISTER_PRIM_KERNEL("prim.ListUnpack", prim_listunpack, { RECORD_USER_SCOPE("nativert::OpKernel_prim_listunpack"); @@ -114,7 +114,7 @@ class OpKernel_variadic_concat : public OpKernel { C10_REGISTER_TYPED_CLASS( PrimKernelRegistry, "prim.VarConcat", - OpKernel_variadic_concat); + OpKernel_variadic_concat) namespace { @@ -158,6 +158,6 @@ class OpKernel_variadic_stack : public OpKernel { C10_REGISTER_TYPED_CLASS( PrimKernelRegistry, "prim.VarStack", - OpKernel_variadic_stack); + OpKernel_variadic_stack) } // namespace torch::nativert From bc65253369933160a2da3fc786d027a572faf6b7 Mon Sep 17 00:00:00 2001 From: Luca Wehrstedt Date: Wed, 16 Jul 2025 10:00:57 +0000 Subject: [PATCH 110/457] Support DeepSeek-style blockwise scaling scaled-mm for fp8 on Hopper+ (#158037) cuBLAS added support for them in CUDA 12.9. It's rather easy to call into them, the hardest thing is allowing the lhs and rhs operands to have different scaling types, as that changes the whole callstack. The scaling format is still detected from the sizes of the scale tensors. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158037 Approved by: https://github.com/eqy, https://github.com/drisspg --- aten/src/ATen/ceil_div.h | 12 +- aten/src/ATen/cuda/CUDABlas.cpp | 116 ++++++---- aten/src/ATen/cuda/CUDABlas.h | 14 +- aten/src/ATen/cuda/tunable/GemmCommon.h | 8 +- aten/src/ATen/cuda/tunable/GemmHipblaslt.h | 63 ++++-- aten/src/ATen/cuda/tunable/TunableGemm.h | 5 +- aten/src/ATen/native/cuda/Blas.cpp | 239 +++++++++++---------- test/test_matmul_cuda.py | 101 +++++++-- 8 files changed, 356 insertions(+), 202 deletions(-) diff --git a/aten/src/ATen/ceil_div.h b/aten/src/ATen/ceil_div.h index 37d67b232a22c..9e69873b1bd9d 100644 --- a/aten/src/ATen/ceil_div.h +++ b/aten/src/ATen/ceil_div.h @@ -7,8 +7,12 @@ namespace at { /** Computes ceil(a / b) */ -template >> -C10_ALWAYS_INLINE C10_HOST_DEVICE T ceil_div(T a, T b) { +template < + typename T, + typename U, + typename = std::enable_if_t< + std::conjunction_v, std::is_integral>>> +C10_ALWAYS_INLINE C10_HOST_DEVICE std::common_type_t ceil_div(T a, U b) { return (a + b - 1) / b; } @@ -16,8 +20,8 @@ C10_ALWAYS_INLINE C10_HOST_DEVICE T ceil_div(T a, T b) { Computes ceil(a / b) * b; i.e., rounds up `a` to the next highest multiple of b */ -template -C10_ALWAYS_INLINE C10_HOST_DEVICE T round_up(T a, T b) { +template +C10_ALWAYS_INLINE C10_HOST_DEVICE std::common_type_t round_up(T a, U b) { return ceil_div(a, b) * b; } diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp index d009520d05ab8..acb1d5ed8b0da 100644 --- a/aten/src/ATen/cuda/CUDABlas.cpp +++ b/aten/src/ATen/cuda/CUDABlas.cpp @@ -1843,6 +1843,69 @@ template bool gemm_and_bias( int64_t result_ld, GEMMAndBiasActivationEpilogue activation); +int get_scale_mode(ScalingType scaling_type, ScalarType scale_dtype, bool use_fast_accum) { + switch (scaling_type) { + case ScalingType::BlockWise1x32: + TORCH_CHECK(scale_dtype == kFloat8_e8m0fnu); +#if CUDA_VERSION >= 12080 + return CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0; +#else + TORCH_CHECK(false, "scaled_gemm with `torch.float8_e8m0fnu` scales of 1x32 blocks is only supported for CUDA 12.8 and above"); +#endif // if CUDA_VERSION >= 12080 + + case ScalingType::BlockWise1x16: + TORCH_CHECK(scale_dtype == kFloat8_e4m3fn); +#if CUDA_VERSION >= 12080 + return CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3; +#else + TORCH_CHECK(false, "scaled_gemm with `torch.float8_e4m3fn` scales of 1x16 blocks is only supported for CUDA 12.8 and above"); +#endif // if CUDA_VERSION >= 12080 + + case ScalingType::RowWise: + TORCH_CHECK(scale_dtype == kFloat); +#if CUDA_VERSION >= 12090 || (defined(USE_ROCM) && defined(HIPBLASLT_OUTER_VEC)) + return CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F; +#elif defined(USE_ROCM) && defined(HIPBLASLT_VEC_EXT) + // Return the default, since in old hipblaslt this is activated via + // the SCALE_POINTER_VEC_EXT attributed. + return 0; +#else + TORCH_CHECK(false, "scaled_gemm with rowwise scaling is only supported for CUDA 12.9 and above"); +#endif // if CUDA_VERSION >= 12090 + + case ScalingType::BlockWise1x128: + TORCH_CHECK(scale_dtype == kFloat); + TORCH_CHECK(!use_fast_accum, "scaled_gemm doesn't support fast accum with 1x128 blockwise scaling") +#if CUDA_VERSION >= 12090 + return CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F; +#else + TORCH_CHECK(false, "scaled_gemm with 1x128 blockwise scaling is only supported for CUDA 12.9 and above"); +#endif // if CUDA_VERSION >= 12090 + + case ScalingType::BlockWise128x128: + TORCH_CHECK(scale_dtype == kFloat); + TORCH_CHECK(!use_fast_accum, "scaled_gemm doesn't support fast accum with 128x128 blockwise scaling") +#if CUDA_VERSION >= 12090 + return CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F; +#else + TORCH_CHECK(false, "scaled_gemm with 128x128 blockwise scaling is only supported for CUDA 12.9 and above"); +#endif // if CUDA_VERSION >= 12090 + +case ScalingType::TensorWise: + TORCH_CHECK(scale_dtype == kFloat); +#if CUDA_VERSION >= 12080 + return CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F; +#else + // The macro isn't defined, thus we inline its value. + return 0; +#endif // if CUDA_VERSION >= 12080 + + default: + TORCH_CHECK(false); + return -1; + } +} + void scaled_gemm( char transa, char transb, @@ -1854,19 +1917,20 @@ void scaled_gemm( int64_t mat1_ld, ScalarType mat1_dtype, ScalarType mat1_scale_dtype, + ScalingType mat1_scaling_type, const void* mat2_ptr, const void* mat2_scale_ptr, int64_t mat2_ld, ScalarType mat2_dtype, ScalarType mat2_scale_dtype, + ScalingType mat2_scaling_type, const void* bias_ptr, ScalarType bias_dtype, void* result_ptr, const void *result_scale_ptr, int64_t result_ld, ScalarType result_dtype, - bool use_fast_accum, - bool use_rowwise) { + bool use_fast_accum) { // Note: see `cublasCommonArgs` for various non-intuitive manupulations // of input arguments to this function. #if CUDA_VERSION >= 11080 || defined(USE_ROCM) @@ -1879,19 +1943,15 @@ void scaled_gemm( computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, _cublasOpFromChar(transb)); cublasLtMatmulDescAttributes_t matmulDescA = CUBLASLT_MATMUL_DESC_A_SCALE_POINTER; cublasLtMatmulDescAttributes_t matmulDescB = CUBLASLT_MATMUL_DESC_B_SCALE_POINTER; -#if defined(USE_ROCM) -#if defined(HIPBLASLT_OUTER_VEC) - // this case is handled later as hipified CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F -#elif defined(HIPBLASLT_VEC_EXT) - if (use_rowwise) { + // hipblaslt supported row-wise before cublas, and did so their own way (via + // the SCALE_POINTERSs), but then migrated to match how cublas does it (via + // the SCALE_MODEs). Here we check for this early custom mode. +#if defined(USE_ROCM) && !defined(HIPBLASLT_OUTER_VEC) && defined(HIPBLASLT_VEC_EXT) + if (mat1_scaling_type == ScalingType::RowWise && mat2_scaling_type == ScalingType::RowWise) { matmulDescA = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT; matmulDescB = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT; } -#else - // rowwise isn't supported using older hipblaslt - TORCH_INTERNAL_ASSERT(use_rowwise == false, "rowwise scaled_gemm not supported with older hipblaslt"); -#endif -#endif // defined(USE_ROCM) +#endif // if defined(USE_ROCM) && !defined(HIPBLASLT_OUTER_VEC) && defined(HIPBLASLT_VEC_EXT) computeDesc.setAttribute(matmulDescA, mat1_scale_ptr); computeDesc.setAttribute(matmulDescB, mat2_scale_ptr); if (result_scale_ptr != nullptr) { @@ -1931,30 +1991,14 @@ void scaled_gemm( computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, ScalarTypeToCudaDataType(bias_dtype)); } - if (mat1_scale_dtype == kFloat8_e8m0fnu && mat2_scale_dtype == kFloat8_e8m0fnu) { -#if CUDA_VERSION >= 12080 - computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0); - computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0); -#else - TORCH_CHECK(false, "scaled_gemm with `torch.float8_e8m0fnu` scales is only supported for CUDA 12.8 and above"); -#endif // if CUDA_VERSION >= 12080 - } else if (mat1_scale_dtype == kFloat8_e4m3fn && mat2_scale_dtype == kFloat8_e4m3fn) { -#if CUDA_VERSION >= 12080 - computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3); - computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3); -#else - TORCH_CHECK(false, "scaled_gemm with `torch.float8_e4m3fn` scales is only supported for CUDA 12.8 and above"); -#endif // if CUDA_VERSION >= 12080 - } else if (mat1_scale_dtype == kFloat && mat2_scale_dtype == kFloat && use_rowwise) { -#if CUDA_VERSION >= 12090 || (defined(USE_ROCM) && defined(HIPBLASLT_OUTER_VEC)) - computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F); - computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F); -#elif defined(USE_ROCM) && defined(HIPBLASLT_VEC_EXT) - // no-op here for older hipblaslt ext enums, to avoid TORCH_CHECK below -#else - TORCH_CHECK(false, "scaled_gemm with `torch.float` outer vector scaling is only supported for CUDA 12.9 and above"); -#endif // if CUDA_VERSION >= 12090 - } + // The SCALE_MODE attrs only exist in cuBLAS 12.8+ or in recent hipblaslt, + // but we must invoke get_scale_mode anyways to trigger the version checks. + int a_scale_mode = get_scale_mode(mat1_scaling_type, mat1_scale_dtype, use_fast_accum); + int b_scale_mode = get_scale_mode(mat2_scaling_type, mat2_scale_dtype, use_fast_accum); +#if CUDA_VERSION >= 12080 || (defined(USE_ROCM) && defined(HIPBLASLT_OUTER_VEC)) + computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_MODE, a_scale_mode); + computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, b_scale_mode); +#endif CuBlasLtMatmulPreference preference; auto ltworkspace = CublasLtWorkspace(); diff --git a/aten/src/ATen/cuda/CUDABlas.h b/aten/src/ATen/cuda/CUDABlas.h index b1dac2162dc42..5021917fe0950 100644 --- a/aten/src/ATen/cuda/CUDABlas.h +++ b/aten/src/ATen/cuda/CUDABlas.h @@ -136,6 +136,15 @@ void int8_gemm( int32_t* result_ptr, int64_t result_ld); +enum class ScalingType : std::uint8_t { + TensorWise, // fp32 scales + RowWise, // fp32 scales + BlockWise1x16, // fp8_e4m3fn scales + BlockWise1x32, // fp8_e8m0fnu scales + BlockWise1x128, // fp32 scales + BlockWise128x128, // fp32 scales +}; + void scaled_gemm( char transa, char transb, @@ -147,19 +156,20 @@ void scaled_gemm( int64_t mat1_ld, ScalarType mat1_dtype, ScalarType mat1_scale_dtype, + ScalingType mat1_scaling_type, const void* mat2_ptr, const void* mat2_scale_ptr, int64_t mat2_ld, ScalarType mat2_dtype, ScalarType mat2_scale_dtype, + ScalingType mat2_scaling_type, const void* bias_ptr, ScalarType bias_dtype, void* result_ptr, const void* result_scale_ptr, int64_t result_ld, ScalarType result_dtype, - bool use_fast_accum, - bool use_rowwise); + bool use_fast_accum); #define CUDABLAS_BGEMM_ARGTYPES(Dtype) CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(Dtype, Dtype) diff --git a/aten/src/ATen/cuda/tunable/GemmCommon.h b/aten/src/ATen/cuda/tunable/GemmCommon.h index 6f896f1a22bfc..6d19907aba4ad 100644 --- a/aten/src/ATen/cuda/tunable/GemmCommon.h +++ b/aten/src/ATen/cuda/tunable/GemmCommon.h @@ -29,6 +29,8 @@ namespace at::cuda::tunable { +using at::cuda::blas::ScalingType; + enum class BlasOp { N = 0, T = 1 @@ -598,7 +600,8 @@ struct ScaledGemmParams : OpParams { // // In TunableOp, we must distinguish in param signature these two cases: with and without a bias vector. return fmt::sprintf("%c%c_%ld_%ld_%ld_ld_%ld_%ld_%ld_rw_%d_bias_%s", - transa, transb, m, n, k, lda, ldb, ldc, use_rowwise, + transa, transb, m, n, k, lda, ldb, ldc, + a_scaling_type == ScalingType::RowWise && b_scaling_type == ScalingType::RowWise, bias_ptr == nullptr ? "None" : at::toString(bias_dtype)); } @@ -673,11 +676,13 @@ struct ScaledGemmParams : OpParams { int64_t lda{}; ScalarType a_dtype{}; ScalarType a_scale_dtype{}; + ScalingType a_scaling_type{}; const void* b{}; const void* b_scale_ptr{}; int64_t ldb{}; ScalarType b_dtype{}; ScalarType b_scale_dtype{}; + ScalingType b_scaling_type{}; const void* bias_ptr{}; ScalarType bias_dtype{}; void* c{}; @@ -686,7 +691,6 @@ struct ScaledGemmParams : OpParams { ScalarType c_dtype{}; void* amax_ptr{}; bool use_fast_accum{}; - bool use_rowwise{}; private: bool duplicate_inputs_{false}; }; diff --git a/aten/src/ATen/cuda/tunable/GemmHipblaslt.h b/aten/src/ATen/cuda/tunable/GemmHipblaslt.h index 32fb7c2774fff..809ba51009f0a 100644 --- a/aten/src/ATen/cuda/tunable/GemmHipblaslt.h +++ b/aten/src/ATen/cuda/tunable/GemmHipblaslt.h @@ -206,23 +206,43 @@ float GetBetaFromParams(const ScaledGemmParams* params) { } template -bool GetUseRowwiseFromParams(const GemmParams* params) { - return false; +ScalingType GetAScalingTypeFromParams(const GemmParams* params) { + return ScalingType::TensorWise; } template -bool GetUseRowwiseFromParams(const GemmAndBiasParams* params) { - return false; +ScalingType GetBScalingTypeFromParams(const GemmParams* params) { + return ScalingType::TensorWise; } template -bool GetUseRowwiseFromParams(const GemmStridedBatchedParams* params) { - return false; +ScalingType GetAScalingTypeFromParams(const GemmAndBiasParams* params) { + return ScalingType::TensorWise; } template -bool GetUseRowwiseFromParams(const ScaledGemmParams* params) { - return params->use_rowwise; +ScalingType GetBScalingTypeFromParams(const GemmAndBiasParams* params) { + return ScalingType::TensorWise; +} + +template +ScalingType GetAScalingTypeFromParams(const GemmStridedBatchedParams* params) { + return ScalingType::TensorWise; +} + +template +ScalingType GetBScalingTypeFromParams(const GemmStridedBatchedParams* params) { + return ScalingType::TensorWise; +} + +template +ScalingType GetAScalingTypeFromParams(const ScaledGemmParams* params) { + return params->a_scaling_type; +} + +template +ScalingType GetBScalingTypeFromParams(const ScaledGemmParams* params) { + return params->b_scaling_type; } template @@ -489,23 +509,24 @@ class HipblasltGemmOp : public Callable { const void* mat2_scale_ptr = GetBScalePointerFromParams(params); const void* result_scale_ptr = GetDScalePointerFromParams(params); if (mat1_scale_ptr && mat2_scale_ptr) { -#ifdef HIPBLASLT_VEC_EXT - if (GetUseRowwiseFromParams(params)) { - matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT, mat1_scale_ptr); - matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT, mat2_scale_ptr); - } - else + hipblasLtMatmulDescAttributes_t a_scale_ptr_desc = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER; + hipblasLtMatmulDescAttributes_t b_scale_ptr_desc = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER; + if (GetAScalingTypeFromParams(params) == ScalingType::RowWise) { +#if defined(HIPBLASLT_OUTER_VEC) + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_MODE, HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F); +#elif defined(HIPBLASLT_VEC_EXT) + a_scale_ptr_desc = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT; #endif - { - matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER, mat1_scale_ptr); - matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER, mat2_scale_ptr); } -#ifdef HIPBLASLT_OUTER_VEC - if (GetUseRowwiseFromParams(params)) { - matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_MODE, HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F); + if (GetBScalingTypeFromParams(params) == ScalingType::RowWise) { +#if defined(HIPBLASLT_OUTER_VEC) matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_MODE, HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F); - } +#elif defined(HIPBLASLT_VEC_EXT) + b_scale_ptr_desc = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT; #endif + } + matmul.setAttribute(a_scale_ptr_desc, mat1_scale_ptr); + matmul.setAttribute(b_scale_ptr_desc, mat2_scale_ptr); } if (result_scale_ptr) { matmul.setAttribute(HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr); diff --git a/aten/src/ATen/cuda/tunable/TunableGemm.h b/aten/src/ATen/cuda/tunable/TunableGemm.h index d7e2835b1b109..d941c230630c4 100644 --- a/aten/src/ATen/cuda/tunable/TunableGemm.h +++ b/aten/src/ATen/cuda/tunable/TunableGemm.h @@ -96,19 +96,20 @@ class DefaultScaledGemmOp : public Callable> { params->lda, params->a_dtype, params->a_scale_dtype, + params->a_scaling_type, params->b, params->b_scale_ptr, params->ldb, params->b_dtype, params->b_scale_dtype, + params->b_scaling_type, params->bias_ptr, params->bias_dtype, params->c, params->c_scale_ptr, params->ldc, params->c_dtype, - params->use_fast_accum, - params->use_rowwise); + params->use_fast_accum); return OK; } }; diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp index c46e1cc633119..377be5d40aab8 100644 --- a/aten/src/ATen/native/cuda/Blas.cpp +++ b/aten/src/ATen/native/cuda/Blas.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #ifndef AT_PER_OPERATOR_HEADERS #include @@ -99,6 +100,7 @@ c10::MaybeOwned inline prepare_matrix_for_cublas(const Tensor& tensor, b } } +using at::cuda::blas::ScalingType; /** * @brief Prepares matrices for CUBLAS operation @@ -140,7 +142,9 @@ struct cublasCommonArgs { Tensor& c, const std::optional& scale_a = std::nullopt, const std::optional& scale_b = std::nullopt, - const std::optional& scale_result = std::nullopt) { + const std::optional& scale_result = std::nullopt, + const std::optional& scaling_choice_a = std::nullopt, + const std::optional& scaling_choice_b = std::nullopt) { bool transpose_result = false, transpose_a = false, transpose_b = false; result = prepare_matrix_for_cublas(c, transpose_result); mata = prepare_matrix_for_cublas(transpose_result ? mat2 : mat1, transpose_a, transpose_result); @@ -152,8 +156,10 @@ struct cublasCommonArgs { // as B.T @ A.T, check transpose_result to determine if we flip the scales scale_mata_ptr = transpose_result ? scale_b->data_ptr() : scale_a->data_ptr(); scale_mata_dtype = transpose_result ? scale_b->scalar_type() : scale_a->scalar_type(); + scaling_mata_type = transpose_result ? scaling_choice_b : scaling_choice_a; scale_matb_ptr = transpose_result ? scale_a->data_ptr() : scale_b->data_ptr(); scale_matb_dtype = transpose_result ? scale_a->scalar_type() : scale_b->scalar_type(); + scaling_matb_type = transpose_result ? scaling_choice_a : scaling_choice_b; } if (scale_result) { @@ -199,7 +205,9 @@ struct cublasCommonArgs { void* scale_matb_ptr = nullptr; void* scale_result_ptr = nullptr; std::optional scale_mata_dtype; + std::optional scaling_mata_type; std::optional scale_matb_dtype; + std::optional scaling_matb_type; std::optional scale_result_dtype; }; } // namespace @@ -1075,133 +1083,114 @@ static bool _scaled_mm_is_fnuz() { namespace{ -enum class ScalingType : std::uint8_t { - TensorWise, - RowWise, - BlockWise, - Error -}; /* * Scaling Type Determination: * --------------------------- * Conditions and corresponding Scaling Types: * - * - If scale tensors are both `Float8_e8m0fnu` or `Float8_e4m3fn`: + * - If scale tensor is `Float8_e8m0fnu` or `Float8_e4m3fn`: * - Returns BlockWise (with additional size checks). * - * - If scale_a.numel() == 1 && scale_b.numel() == 1: + * - Else if scale.numel() == 1: * - Returns TensorWise. * - * - Else if scale_a.dim() == 2 && scale_a.size(0) == dim_m && scale_b.size(0) == dim_n: + * - Else if scale.dim() == 2 && scale.size(0) == outer_dim && scale.size(1) == 1: * - Returns RowWise. * + * - Else if scale.dim() == 2 && scale.size(0) == outer_dim && scale.size(1) == inner_dim / 128: + * - Returns BlockWise 1x128. + * + * - Else if scale.dim() == 2 && scale.size(0) == outer_dim / 128 && scale.size(1) == inner_dim / 128: + * - Returns BlockWise 128x128. + * * - Otherwise: * - Returns Error. */ -// Validates the scale tensors to scaled_mm -// And returns the type of scaling/which kernel to use -ScalingType get_scaling_type( - const at::Tensor& scale_a, - const at::Tensor& scale_b, - int64_t dim_m, - int64_t dim_k, - int64_t dim_n) { - // Check for BlockWise scaling (FP8_E8M0 and FP8_E4M3 types) - if ((scale_a.scalar_type() == scale_b.scalar_type()) && - ((scale_a.scalar_type() == at::kFloat8_e8m0fnu) || (scale_a.scalar_type() == at::kFloat8_e4m3fn))) { - const bool is_nvfp4 = scale_a.scalar_type() == at::kFloat8_e4m3fn; - - // cuBLAS's mxfp8 gemm: block_size is 1 scale per 32 elements - // cuBLAS's nvfp4 gemm: block_size is 1 scale per 16 unpacked elements. - const auto BLOCK_SIZE_K = is_nvfp4 ? 16 : 32; - - constexpr int64_t BLOCK_SIZE_MN = 128; - - // adjust for fp4x2 packing if necessary - const auto dim_k_unpacked = is_nvfp4 ? dim_k * 2 : dim_k; - - auto ceil_div = [](auto a, auto b) { return (a + b - 1) / b; }; - auto num_k_blocks = ceil_div(dim_k_unpacked, BLOCK_SIZE_K); - auto padded_num_k_blocks = ceil_div(num_k_blocks, 4) * 4; - - // TODO: We might want to enforce some structure on the shapes of the scale - // tensors - - // Check expected sizes for block-wise scaling - auto expected_a_size = - BLOCK_SIZE_MN * ceil_div(dim_m, BLOCK_SIZE_MN) * padded_num_k_blocks; - auto expected_b_size = - BLOCK_SIZE_MN * ceil_div(dim_n, BLOCK_SIZE_MN) * padded_num_k_blocks; - - TORCH_CHECK(scale_a.numel() == expected_a_size, - "For BlockWise scaling: Expected scale_a size to be ", - expected_a_size, " but got ", scale_a.numel()); - TORCH_CHECK(scale_b.numel() == expected_b_size, - "For BlockWise scaling: Expected scale_b size to be ", - expected_b_size, " but got ", scale_b.numel()); - - TORCH_CHECK( - scale_a.is_contiguous() && scale_b.is_contiguous(), - "For BlockWise scaling: Both scale_a and scale_b must be contiguous"); - - return ScalingType::BlockWise; - } - // Both Per-Tensor and Row-wise scaling expect fp32 tensors - TORCH_CHECK( - scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat, - "Both scale_a and scale_b must be float (fp32) tensors."); +using at::cuda::blas::ScalingType; - // Check the singluar scale case for per-tensor scaling - if (scale_a.numel() == 1 && scale_b.numel() == 1) { - return ScalingType::TensorWise; - } +bool is_tensorwise_scaling(const at::Tensor& t, const at::Tensor& scale) { + return isFloat8Type(t.scalar_type()) && scale.scalar_type() == kFloat && scale.numel() == 1; +} - // For non-TensorWise scaling, enforce 2D input tensors - TORCH_CHECK( - scale_a.dim() == 2 && scale_b.dim() == 2, - "For non-TensorWise scaling, scale tensors must be 2-dimensional, " - "but got scale_a.dim()=", - scale_a.dim(), - " and scale_b.dim()=", - scale_b.dim()); - - // Check for RowWise scaling - if (scale_a.size(0) == dim_m && scale_a.size(1) == 1 && - scale_b.size(0) == 1 && scale_b.size(1) == dim_n) { -#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || \ - (defined(USE_ROCM) && (defined(HIPBLASLT_VEC_EXT) || defined(HIPBLASLT_OUTER_VEC))) - TORCH_CHECK( - scale_a.is_contiguous() && scale_b.is_contiguous(), - "Both scale_a and scale_b must be contiguous for RowWise scaling."); - return ScalingType::RowWise; -#else - TORCH_CHECK(false, "Per-row scaling is not supported for this platform!"); - return ScalingType::Error; -#endif +bool is_rowwise_scaling(const at::Tensor& t, const at::Tensor& scale) { + return (isFloat8Type(t.scalar_type()) && scale.scalar_type() == kFloat && scale.dim() == 2 + && scale.size(0) == t.size(0) && scale.size(1) == 1 + && scale.is_contiguous()); +} + +// 1x16 blocks for packed nvfp4 data and fp8_e4m3fn scales +bool is_blockwise_1x16_scaling(const at::Tensor& t, const at::Tensor& scale) { + // Multiply t.size(1) by 2 to adjust for fp4x2 packing + // TODO: We might want to enforce some structure on the shapes of the scale + // tensors + return (t.scalar_type() == ScalarType::Float4_e2m1fn_x2 && scale.scalar_type() == at::kFloat8_e4m3fn + && scale.numel() == round_up(t.size(0), 128) * round_up(ceil_div(t.size(1) * 2, 16), 4) + && scale.is_contiguous()); +} + +// 1x16 blocks for microscaled fp8 data and fp8_e8m0fnu scales +bool is_blockwise_1x32_scaling(const at::Tensor& t, const at::Tensor& scale) { + // TODO: We might want to enforce some structure on the shapes of the scale + // tensors + return (isFloat8Type(t.scalar_type()) && scale.scalar_type() == at::kFloat8_e8m0fnu + && scale.numel() == round_up(t.size(0), 128) * round_up(ceil_div(t.size(1), 32), 4) + && scale.is_contiguous()); +} + +bool is_blockwise_1x128_scaling(const at::Tensor& t, const at::Tensor& scale) { + return (isFloat8Type(t.scalar_type()) && scale.scalar_type() == kFloat && scale.dim() == 2 + && scale.size(0) == t.size(0) && scale.size(1) == ceil_div(t.size(1), 128) + && scale.stride(0) == 1 && scale.stride(1) == t.size(0)); +} + +bool is_blockwise_128x128_scaling(const at::Tensor& t, const at::Tensor& scale) { + return (isFloat8Type(t.scalar_type()) && scale.scalar_type() == kFloat && scale.dim() == 2 + && scale.size(0) == ceil_div(t.size(0), 128) && scale.size(1) == ceil_div(t.size(1), 128) + && scale.stride(0) == round_up(ceil_div(t.size(1), 128), 4) && scale.stride(1) == 1); +} + +bool is_desired_scaling(const at::Tensor& t, const at::Tensor& scale, ScalingType desired_scaling) { + switch (desired_scaling) { + case ScalingType::TensorWise: + return is_tensorwise_scaling(t, scale); + case ScalingType::RowWise: + return is_rowwise_scaling(t, scale); + case ScalingType::BlockWise1x16: + return is_blockwise_1x16_scaling(t, scale); + case ScalingType::BlockWise1x32: + return is_blockwise_1x32_scaling(t, scale); + case ScalingType::BlockWise1x128: + return is_blockwise_1x128_scaling(t, scale); + case ScalingType::BlockWise128x128: + return is_blockwise_128x128_scaling(t, scale); + default: + TORCH_CHECK(false); + return false; } +} - // If we reach here, the input doesn't match any valid scaling type +std::pair get_joint_scaling( + std::initializer_list> options, + const at::Tensor& a, const at::Tensor& b, + const at::Tensor& scale_a, const at::Tensor& scale_b) { + for (auto [lhs, rhs] : options) { + if (is_desired_scaling(a, scale_a, lhs) && is_desired_scaling(b.t(), scale_b.t(), rhs)) { + return {lhs, rhs}; + } + } TORCH_CHECK( - false, - "Invalid scaling configuration. For TensorWise scaling, both scales should be scalar. " - "For RowWise scaling, scale_a should be (", - dim_m, - ", 1) and scale_b should be (1, ", - dim_n, - "). " - "Got scale_a.size()=(", - scale_a.size(0), - ", ", - scale_a.size(1), - ") and ", - "scale_b.size()=(", - scale_b.size(0), - ", ", - scale_b.size(1), - ")"); - - return ScalingType::Error; + false, + "Invalid scaling configuration.\n" + "- For TensorWise scaling, a and b should be float8, scales should be float and singletons.\n" + "- For RowWise scaling, a and b should be float8, scales should be float, scale_a should be (", a.size(0), ", 1) and scale_b should be (1, ", b.size(1), "), and both should be contiguous.\n" + "- For BlockWise 1x128 scaling, a and b should be float8, scales should be float, scale_a should be (", a.size(0), ", ", ceil_div(a.size(1), 128), ") and scale_b should be (", ceil_div(b.size(0), 128), ", ", b.size(1), "), and both should be outer-dim-major.\n" + "- For BlockWise 128x128 scaling, a and b should be float8, scales should be float, scale_a should be (", ceil_div(a.size(0), 128), ", ", ceil_div(a.size(1), 128), ") and scale_b should be (", ceil_div(b.size(0), 128), ", ", ceil_div(b.size(1), 128), "), and both should be near-inner-dim-major (with 16-byte aligned strides).\n" + "- For Blockwise 1x32 scaling, a and b should be float8, scales should be float8_e8m0fnu, scale_a should have ", round_up(a.size(0), 128) * round_up(ceil_div(a.size(1), 32), 4), " elements and scale_b should have ", round_up(b.size(1), 128) * round_up(ceil_div(b.size(0), 32), 4), " elements, and both should be contiguous.\n" + "- For Blockwise 1x16 scaling, a and b should be float4 (packed 2x), scales should be float8_e4m3fn, scale_a should have ", round_up(a.size(0), 128) * round_up(ceil_div(a.size(1) * 2, 16), 4), " elements and scale_b should have ", round_up(b.size(1), 128) * round_up(ceil_div(b.size(0) * 2, 16), 4), " elements, and both should be contiguous.\n" + "Got a.dtype()=", a.scalar_type(), ", scale_a.dtype()=", scale_a.scalar_type(), ", scale_a.size()=", scale_a.sizes(), ", scale_a.stride()=", scale_a.strides(), ", ", + "b.dtype()=", b.scalar_type(), ", scale_b.dtype()=", scale_b.scalar_type(), ", scale_b.size()=", scale_b.sizes(), " and scale_b.stride()=", scale_b.strides() + ); } } // namespace @@ -1243,9 +1232,21 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (", mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")"); - // Check what type of scaling we are doing based on inputs - ScalingType scaling_choice = get_scaling_type(scale_a, scale_b, mat1.size(0), mat1.size(1), mat2.size(1)); - TORCH_INTERNAL_ASSERT(scaling_choice != ScalingType::Error, "Scaling type not supported"); + // Check what type of scaling we are doing based on inputs. This list is sorted + // by decreasing priority. We prefer "simpler" schemes as they are supported + // more broadly (more GPU archs, more CUDA versions) and because they are more + // efficient. This tends to matter only for small matmuls (e.g., 1x1x128). + auto [scaling_choice_a, scaling_choice_b] = get_joint_scaling( + { + std::make_pair(ScalingType::TensorWise, ScalingType::TensorWise), + std::make_pair(ScalingType::RowWise, ScalingType::RowWise), + std::make_pair(ScalingType::BlockWise128x128, ScalingType::BlockWise1x128), + std::make_pair(ScalingType::BlockWise1x128, ScalingType::BlockWise128x128), + std::make_pair(ScalingType::BlockWise1x128, ScalingType::BlockWise1x128), + std::make_pair(ScalingType::BlockWise1x32, ScalingType::BlockWise1x32), + std::make_pair(ScalingType::BlockWise1x16, ScalingType::BlockWise1x16) + }, + mat1, mat2, scale_a, scale_b); TORCH_CHECK(!scale_result || (scale_result->numel() == 1 && scale_result->scalar_type() == kFloat), "scale_result must be a float scalar"); @@ -1316,7 +1317,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, #ifndef USE_ROCM // We are doing row-wise scaling auto dprops = at::cuda::getCurrentDeviceProperties(); - if (scaling_choice == ScalingType::RowWise + if (scaling_choice_a == ScalingType::RowWise && scaling_choice_b == ScalingType::RowWise && (dprops->major < 9 || CUBLAS_VERSION < 120900 || cublasLtGetVersion() < 120900)) { TORCH_CHECK(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling."); at::cuda::detail::f8f8bf16_rowwise( @@ -1330,7 +1331,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, return out; } #else - if (scaling_choice == ScalingType::RowWise) { + if (scaling_choice_a == ScalingType::RowWise && scaling_choice_b == ScalingType::RowWise) { // For ROCm, match behavior of f8f8bf16_rowwise type checking, for unit test purposes. Tensor b = mat2; if (_scaled_mm_is_fnuz()) { @@ -1345,7 +1346,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, } #endif - cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, scale_result); + cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, scale_result, scaling_choice_a, scaling_choice_b); const auto out_dtype_ = args.result->scalar_type(); TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt"); @@ -1422,10 +1423,14 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, params.a_scale_ptr = args.scale_mata_ptr; params.lda = args.lda; params.a_dtype = args.mata->scalar_type(); + params.a_scale_dtype = args.scale_mata_dtype.value(); + params.a_scaling_type = args.scaling_mata_type.value(); params.b = args.matb->data_ptr(); params.b_scale_ptr = args.scale_matb_ptr; params.ldb = args.ldb; params.b_dtype = args.matb->scalar_type(); + params.b_scale_dtype = args.scale_matb_dtype.value(); + params.b_scaling_type = args.scaling_matb_type.value(); params.bias_ptr = bias ? bias->data_ptr(): nullptr; params.bias_dtype = bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_; params.c = args.result->data_ptr(); @@ -1433,7 +1438,6 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, params.ldc = args.result_ld; params.c_dtype = out_dtype_; params.use_fast_accum = use_fast_accum; - params.use_rowwise = scaling_choice == ScalingType::RowWise; if (transa_ && transb_) { TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::T) } @@ -1467,19 +1471,20 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, args.lda, args.mata->scalar_type(), args.scale_mata_dtype.value(), + args.scaling_mata_type.value(), args.matb->data_ptr(), args.scale_matb_ptr, args.ldb, args.matb->scalar_type(), args.scale_matb_dtype.value(), + args.scaling_matb_type.value(), bias ? bias->data_ptr(): nullptr, bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_, args.result->data_ptr(), args.scale_result_ptr, args.result_ld, out_dtype_, - use_fast_accum, - scaling_choice == ScalingType::RowWise); + use_fast_accum); } return out; diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index 31f36681bc3a4..30526c2a84826 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -785,7 +785,7 @@ def amax_to_scale( if float8_dtype == e4m3_type: res = E4M3_MAX_POS / torch.clamp(amax, min=EPS) elif float8_dtype == e5m2_type: - res = E4M3_MAX_POS / torch.clamp(amax, min=EPS) + res = E5M2_MAX_POS / torch.clamp(amax, min=EPS) else: raise ValueError(f"Unsupported float8_dtype: {float8_dtype}") @@ -806,6 +806,20 @@ def tensor_to_scale(x: torch.Tensor, float8_dtype: torch.dtype, dim=None): return amax_to_scale(amax, float8_dtype, x.dtype) +def tensor_to_scale_block( + x: torch.Tensor, + float8_dtype: torch.dtype, + block_outer: int, + block_inner: int, +) -> tuple[torch.Tensor, torch.Tensor]: + x = x.unflatten(1, (-1, block_inner)).unflatten(0, (-1, block_outer)) + amax = x.abs().amax(dim=[1, 3], keepdim=True).float() + scale = torch.finfo(float8_dtype).max / amax + x = x.mul(scale).to(float8_dtype) + x = x.flatten(2, 3).flatten(0, 1) + scale = scale.flatten(2, 3).flatten(0, 1) + return x, scale + def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor: # naive implementation: dq -> op -> q x_fp32 = x.to(torch.float) / x_scale @@ -814,6 +828,17 @@ def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor: return out_fp32.to(out_dtype) +def mm_float8_emulated_block(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor: + x = x.unflatten(1, (x_scale.shape[1], -1)).unflatten(0, (x_scale.shape[0], -1)) + y = y.unflatten(1, (y_scale.shape[1], -1)).unflatten(0, (y_scale.shape[0], -1)) + x_fp32 = x.to(torch.float) / x_scale[:, None, :, None] + y_fp32 = y.to(torch.float) / y_scale[:, None, :, None] + x_fp32 = x_fp32.flatten(2, 3).flatten(0, 1) + y_fp32 = y_fp32.flatten(2, 3).flatten(0, 1) + out_fp32 = torch.mm(x_fp32, y_fp32) + + return out_fp32.to(out_dtype) + def addmm_float8_unwrapped( a_data: torch.Tensor, a_scale: torch.Tensor, @@ -1237,11 +1262,7 @@ def test_float8_error_messages(self, device) -> None: y_fp8 = y.to(e4m3_type).t() with self.assertRaisesRegex( - RuntimeError, - re.escape( - "For RowWise scaling, scale_a should be (1024, 1) and scale_b " - "should be (1, 2048). Got scale_a.size()=(1, 1) and scale_b.size()=(1, 2)" - ), + RuntimeError, re.escape("Invalid scaling configuration") ): torch._scaled_mm( x_fp8, @@ -1252,11 +1273,7 @@ def test_float8_error_messages(self, device) -> None: ) with self.assertRaisesRegex( - RuntimeError, - re.escape( - " For RowWise scaling, scale_a should be (1024, 1) and scale_b " - "should be (1, 2048). Got scale_a.size()=(1024, 1) and scale_b.size()=(1, 2049)" - ), + RuntimeError, re.escape("Invalid scaling configuration") ): torch._scaled_mm( x_fp8, @@ -1266,22 +1283,18 @@ def test_float8_error_messages(self, device) -> None: out_dtype=torch.bfloat16, ) with self.assertRaisesRegex( - RuntimeError, - re.escape("For non-TensorWise scaling, scale tensors must be 2-dimensional"), + RuntimeError, re.escape("Invalid scaling configuration") ): torch._scaled_mm( x_fp8, y_fp8, scale_a=torch.ones((M), device="cuda"), - scale_b=torch.ones((N, N), device="cuda"), + scale_b=torch.ones((N, N, 1), device="cuda"), out_dtype=torch.bfloat16, ) with self.assertRaisesRegex( - RuntimeError, - re.escape( - "Both scale_a and scale_b must be contiguous for RowWise scaling." - ), + RuntimeError, re.escape("Invalid scaling configuration") ): torch._scaled_mm( x_fp8, @@ -1346,6 +1359,58 @@ def test_scaled_mm_vs_emulated_row_wise(self, base_dtype): torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol) + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg) + @unittest.skipIf(not SM90OrLater, "cuBLAS blockwise scaling requires sm90+") + @unittest.skipIf( + _get_torch_cuda_version() < (12, 9), + "cuBLAS blockwise scaling added in CUDA 12.9", + ) + @parametrize("output_dtype", [torch.bfloat16, torch.float32]) + @parametrize("lhs_block,rhs_block", [(1, 1), (128, 1), (1, 128)]) + def test_scaled_mm_vs_emulated_block_wise(self, output_dtype, lhs_block, rhs_block): + torch.manual_seed(42) + + x = torch.randn(256, 512, device="cuda", dtype=output_dtype).pow(3) + y = torch.randn(768, 512, device="cuda", dtype=output_dtype).pow(3) + + x_fp8, x_scales = tensor_to_scale_block(x, e4m3_type, lhs_block, 128) + y_fp8, y_scales = tensor_to_scale_block(y, e4m3_type, rhs_block, 128) + + # 1x128 blocks need scales to be outer-dim-major + if lhs_block == 1: + x_scales = x_scales.t().contiguous().t() + if rhs_block == 1: + y_scales = y_scales.t().contiguous().t() + + # Calculate actual F8 mm + out_scaled_mm = mm_float8( + x_fp8, y_fp8.t(), a_scale=x_scales, b_scale=y_scales.t(), output_dtype=output_dtype + ) + + # Calculate emulated F8 mm + out_emulated = mm_float8_emulated_block( + x_fp8, x_scales, y_fp8.t(), y_scales.t(), output_dtype + ) + + cosine_sim = torch.nn.functional.cosine_similarity( + out_scaled_mm.flatten().float(), out_emulated.flatten().float(), dim=0 + ) + self.assertGreaterEqual(float(cosine_sim), 0.999) + + if output_dtype in {torch.bfloat16, torch.float16}: + atol, rtol = 6e-1, 7e-2 + else: + atol, rtol = 7e-1, 2e-3 + + self.assertEqual(out_scaled_mm, out_emulated, atol=atol, rtol=rtol) + + # One last check against the full-precision reference, to ensure we + # didn't mess up the scaling itself and made the test trivial. + cosine_sim = torch.nn.functional.cosine_similarity( + out_scaled_mm.flatten().float(), (x @ y.t()).flatten().float(), dim=0 + ) + self.assertGreaterEqual(float(cosine_sim), 0.999) + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @parametrize("which_dim_zero", [0, 1, 2]) @parametrize("use_torch_compile", [False, True]) From 2043f6911e795a5dcdbf326bc0cea5f33e7771d7 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Mon, 14 Jul 2025 08:05:13 -0700 Subject: [PATCH 111/457] [BE] Rename libnvshmem_extension to libtorch_nvshmem (#158234) `libnvshmem_extension.so` creates an illusion that it is a shared library from NVSHMEM. But indeed it is built from torch source code, for symmetric tensor infrastructure and operations, though leveraging NVSHMEM APIs. Thus this PR renames `libnvshmem_extension.so` to `libtorch_nvshmem.so`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158234 Approved by: https://github.com/albanD --- caffe2/CMakeLists.txt | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 7d0a98fbd33be..ed7a80f4b4b0c 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -1027,24 +1027,24 @@ elseif(USE_CUDA) # Linking with nvshmem requires the source binary to be built with -rdc # which is not viable for libtorch_cuda. So we isolate the linking of - # nvshmem in nvshmem_extension. - add_library(nvshmem_extension SHARED + # nvshmem in torch_nvshmem. + add_library(torch_nvshmem SHARED "${TORCH_SRC_DIR}/csrc/distributed/c10d/cuda/utils.cpp" "${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/nvshmem_extension.cu" "${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/NVSHMEMSymmetricMemory.cu" "${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.cpp" ) - set_target_properties(nvshmem_extension PROPERTIES CUDA_SEPARABLE_COMPILATION ON) - target_compile_options(nvshmem_extension PRIVATE $<$:-rdc=true>) - target_compile_options(nvshmem_extension PRIVATE "-U__CUDA_NO_HALF_OPERATORS__") - target_link_libraries(nvshmem_extension PRIVATE + set_target_properties(torch_nvshmem PROPERTIES CUDA_SEPARABLE_COMPILATION ON) + target_compile_options(torch_nvshmem PRIVATE $<$:-rdc=true>) + target_compile_options(torch_nvshmem PRIVATE "-U__CUDA_NO_HALF_OPERATORS__") + target_link_libraries(torch_nvshmem PRIVATE ${NVSHMEM_HOST_LIB} ${NVSHMEM_DEVICE_LIB} ) target_compile_definitions(torch_cuda PUBLIC USE_NVSHMEM) - target_compile_definitions(nvshmem_extension PUBLIC USE_NVSHMEM) - target_link_libraries(torch_cuda PRIVATE nvshmem_extension) - install(TARGETS nvshmem_extension EXPORT Caffe2Targets DESTINATION lib) + target_compile_definitions(torch_nvshmem PUBLIC USE_NVSHMEM) + target_link_libraries(torch_cuda PRIVATE torch_nvshmem) + install(TARGETS torch_nvshmem EXPORT Caffe2Targets DESTINATION lib) else() message(STATUS "NVSHMEM not found, not building with NVSHMEM support.") endif() From 5763ec5f8d11df5eea962bedc74563394c0e273f Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Mon, 14 Jul 2025 08:05:18 -0700 Subject: [PATCH 112/457] [BE] Replace lib with TORCH_INSTALL_LIB_DIR (#158235) Their values are actually the same. Just staying in line with other `INSTALL` commands. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158235 Approved by: https://github.com/Skylion007 ghstack dependencies: #158234 --- caffe2/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index ed7a80f4b4b0c..1edcb36e94f9c 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -1044,7 +1044,7 @@ elseif(USE_CUDA) target_compile_definitions(torch_cuda PUBLIC USE_NVSHMEM) target_compile_definitions(torch_nvshmem PUBLIC USE_NVSHMEM) target_link_libraries(torch_cuda PRIVATE torch_nvshmem) - install(TARGETS torch_nvshmem EXPORT Caffe2Targets DESTINATION lib) + install(TARGETS torch_nvshmem EXPORT Caffe2Targets DESTINATION "${TORCH_INSTALL_LIB_DIR}") else() message(STATUS "NVSHMEM not found, not building with NVSHMEM support.") endif() From 0b19d463d963a0b2ee5558d2c0bb79b2cbff6e64 Mon Sep 17 00:00:00 2001 From: David Berard Date: Wed, 16 Jul 2025 06:44:30 -0700 Subject: [PATCH 113/457] forward fix lint (#158448) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158448 Approved by: https://github.com/adamomainz --- torch/_inductor/fx_passes/fuse_attention.py | 2 +- torch/fx/traceback.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/torch/_inductor/fx_passes/fuse_attention.py b/torch/_inductor/fx_passes/fuse_attention.py index 3e8bd56b32140..5f449eb496642 100644 --- a/torch/_inductor/fx_passes/fuse_attention.py +++ b/torch/_inductor/fx_passes/fuse_attention.py @@ -584,7 +584,7 @@ def _sfdp_replacement_20(query, key, value, attn_mask, dropout_p): def _sfdp_pattern_24(query, key, value, attention_mask): """ this pattern is for MBartForCausalLM/PLBartForCausalLM. - attn_mask has a differnt dtype with QKV. + attn_mask has a different dtype with QKV. there is no scale in sdpa. """ bs = query.size(0) diff --git a/torch/fx/traceback.py b/torch/fx/traceback.py index e57e89ea8d4b5..836b41d661859 100644 --- a/torch/fx/traceback.py +++ b/torch/fx/traceback.py @@ -51,6 +51,8 @@ def __init__(self, name: str, target: str, graph_id: int): action: list["NodeSourceAction"] from_node: list["NodeSource"] node_info: Optional["NodeInfo"] + _dict: Optional[dict[str, Any]] + _action_string: Optional[str] def __init__( self, From 9513b9d03fa8950ba5d2b59cc0b1a1aab3a41c06 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 16 Jul 2025 15:04:10 +0000 Subject: [PATCH 114/457] Revert "Support DeepSeek-style blockwise scaling scaled-mm for fp8 on Hopper+ (#158037)" This reverts commit bc65253369933160a2da3fc786d027a572faf6b7. Reverted https://github.com/pytorch/pytorch/pull/158037 on behalf of https://github.com/lw due to OSX failures are real ([comment](https://github.com/pytorch/pytorch/pull/158037#issuecomment-3079042171)) --- aten/src/ATen/ceil_div.h | 12 +- aten/src/ATen/cuda/CUDABlas.cpp | 116 ++++------ aten/src/ATen/cuda/CUDABlas.h | 14 +- aten/src/ATen/cuda/tunable/GemmCommon.h | 8 +- aten/src/ATen/cuda/tunable/GemmHipblaslt.h | 63 ++---- aten/src/ATen/cuda/tunable/TunableGemm.h | 5 +- aten/src/ATen/native/cuda/Blas.cpp | 239 ++++++++++----------- test/test_matmul_cuda.py | 101 ++------- 8 files changed, 202 insertions(+), 356 deletions(-) diff --git a/aten/src/ATen/ceil_div.h b/aten/src/ATen/ceil_div.h index 9e69873b1bd9d..37d67b232a22c 100644 --- a/aten/src/ATen/ceil_div.h +++ b/aten/src/ATen/ceil_div.h @@ -7,12 +7,8 @@ namespace at { /** Computes ceil(a / b) */ -template < - typename T, - typename U, - typename = std::enable_if_t< - std::conjunction_v, std::is_integral>>> -C10_ALWAYS_INLINE C10_HOST_DEVICE std::common_type_t ceil_div(T a, U b) { +template >> +C10_ALWAYS_INLINE C10_HOST_DEVICE T ceil_div(T a, T b) { return (a + b - 1) / b; } @@ -20,8 +16,8 @@ C10_ALWAYS_INLINE C10_HOST_DEVICE std::common_type_t ceil_div(T a, U b) { Computes ceil(a / b) * b; i.e., rounds up `a` to the next highest multiple of b */ -template -C10_ALWAYS_INLINE C10_HOST_DEVICE std::common_type_t round_up(T a, U b) { +template +C10_ALWAYS_INLINE C10_HOST_DEVICE T round_up(T a, T b) { return ceil_div(a, b) * b; } diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp index acb1d5ed8b0da..d009520d05ab8 100644 --- a/aten/src/ATen/cuda/CUDABlas.cpp +++ b/aten/src/ATen/cuda/CUDABlas.cpp @@ -1843,69 +1843,6 @@ template bool gemm_and_bias( int64_t result_ld, GEMMAndBiasActivationEpilogue activation); -int get_scale_mode(ScalingType scaling_type, ScalarType scale_dtype, bool use_fast_accum) { - switch (scaling_type) { - case ScalingType::BlockWise1x32: - TORCH_CHECK(scale_dtype == kFloat8_e8m0fnu); -#if CUDA_VERSION >= 12080 - return CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0; -#else - TORCH_CHECK(false, "scaled_gemm with `torch.float8_e8m0fnu` scales of 1x32 blocks is only supported for CUDA 12.8 and above"); -#endif // if CUDA_VERSION >= 12080 - - case ScalingType::BlockWise1x16: - TORCH_CHECK(scale_dtype == kFloat8_e4m3fn); -#if CUDA_VERSION >= 12080 - return CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3; -#else - TORCH_CHECK(false, "scaled_gemm with `torch.float8_e4m3fn` scales of 1x16 blocks is only supported for CUDA 12.8 and above"); -#endif // if CUDA_VERSION >= 12080 - - case ScalingType::RowWise: - TORCH_CHECK(scale_dtype == kFloat); -#if CUDA_VERSION >= 12090 || (defined(USE_ROCM) && defined(HIPBLASLT_OUTER_VEC)) - return CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F; -#elif defined(USE_ROCM) && defined(HIPBLASLT_VEC_EXT) - // Return the default, since in old hipblaslt this is activated via - // the SCALE_POINTER_VEC_EXT attributed. - return 0; -#else - TORCH_CHECK(false, "scaled_gemm with rowwise scaling is only supported for CUDA 12.9 and above"); -#endif // if CUDA_VERSION >= 12090 - - case ScalingType::BlockWise1x128: - TORCH_CHECK(scale_dtype == kFloat); - TORCH_CHECK(!use_fast_accum, "scaled_gemm doesn't support fast accum with 1x128 blockwise scaling") -#if CUDA_VERSION >= 12090 - return CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F; -#else - TORCH_CHECK(false, "scaled_gemm with 1x128 blockwise scaling is only supported for CUDA 12.9 and above"); -#endif // if CUDA_VERSION >= 12090 - - case ScalingType::BlockWise128x128: - TORCH_CHECK(scale_dtype == kFloat); - TORCH_CHECK(!use_fast_accum, "scaled_gemm doesn't support fast accum with 128x128 blockwise scaling") -#if CUDA_VERSION >= 12090 - return CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F; -#else - TORCH_CHECK(false, "scaled_gemm with 128x128 blockwise scaling is only supported for CUDA 12.9 and above"); -#endif // if CUDA_VERSION >= 12090 - -case ScalingType::TensorWise: - TORCH_CHECK(scale_dtype == kFloat); -#if CUDA_VERSION >= 12080 - return CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F; -#else - // The macro isn't defined, thus we inline its value. - return 0; -#endif // if CUDA_VERSION >= 12080 - - default: - TORCH_CHECK(false); - return -1; - } -} - void scaled_gemm( char transa, char transb, @@ -1917,20 +1854,19 @@ void scaled_gemm( int64_t mat1_ld, ScalarType mat1_dtype, ScalarType mat1_scale_dtype, - ScalingType mat1_scaling_type, const void* mat2_ptr, const void* mat2_scale_ptr, int64_t mat2_ld, ScalarType mat2_dtype, ScalarType mat2_scale_dtype, - ScalingType mat2_scaling_type, const void* bias_ptr, ScalarType bias_dtype, void* result_ptr, const void *result_scale_ptr, int64_t result_ld, ScalarType result_dtype, - bool use_fast_accum) { + bool use_fast_accum, + bool use_rowwise) { // Note: see `cublasCommonArgs` for various non-intuitive manupulations // of input arguments to this function. #if CUDA_VERSION >= 11080 || defined(USE_ROCM) @@ -1943,15 +1879,19 @@ void scaled_gemm( computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, _cublasOpFromChar(transb)); cublasLtMatmulDescAttributes_t matmulDescA = CUBLASLT_MATMUL_DESC_A_SCALE_POINTER; cublasLtMatmulDescAttributes_t matmulDescB = CUBLASLT_MATMUL_DESC_B_SCALE_POINTER; - // hipblaslt supported row-wise before cublas, and did so their own way (via - // the SCALE_POINTERSs), but then migrated to match how cublas does it (via - // the SCALE_MODEs). Here we check for this early custom mode. -#if defined(USE_ROCM) && !defined(HIPBLASLT_OUTER_VEC) && defined(HIPBLASLT_VEC_EXT) - if (mat1_scaling_type == ScalingType::RowWise && mat2_scaling_type == ScalingType::RowWise) { +#if defined(USE_ROCM) +#if defined(HIPBLASLT_OUTER_VEC) + // this case is handled later as hipified CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F +#elif defined(HIPBLASLT_VEC_EXT) + if (use_rowwise) { matmulDescA = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT; matmulDescB = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT; } -#endif // if defined(USE_ROCM) && !defined(HIPBLASLT_OUTER_VEC) && defined(HIPBLASLT_VEC_EXT) +#else + // rowwise isn't supported using older hipblaslt + TORCH_INTERNAL_ASSERT(use_rowwise == false, "rowwise scaled_gemm not supported with older hipblaslt"); +#endif +#endif // defined(USE_ROCM) computeDesc.setAttribute(matmulDescA, mat1_scale_ptr); computeDesc.setAttribute(matmulDescB, mat2_scale_ptr); if (result_scale_ptr != nullptr) { @@ -1991,14 +1931,30 @@ void scaled_gemm( computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, ScalarTypeToCudaDataType(bias_dtype)); } - // The SCALE_MODE attrs only exist in cuBLAS 12.8+ or in recent hipblaslt, - // but we must invoke get_scale_mode anyways to trigger the version checks. - int a_scale_mode = get_scale_mode(mat1_scaling_type, mat1_scale_dtype, use_fast_accum); - int b_scale_mode = get_scale_mode(mat2_scaling_type, mat2_scale_dtype, use_fast_accum); -#if CUDA_VERSION >= 12080 || (defined(USE_ROCM) && defined(HIPBLASLT_OUTER_VEC)) - computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_MODE, a_scale_mode); - computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, b_scale_mode); -#endif + if (mat1_scale_dtype == kFloat8_e8m0fnu && mat2_scale_dtype == kFloat8_e8m0fnu) { +#if CUDA_VERSION >= 12080 + computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0); + computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0); +#else + TORCH_CHECK(false, "scaled_gemm with `torch.float8_e8m0fnu` scales is only supported for CUDA 12.8 and above"); +#endif // if CUDA_VERSION >= 12080 + } else if (mat1_scale_dtype == kFloat8_e4m3fn && mat2_scale_dtype == kFloat8_e4m3fn) { +#if CUDA_VERSION >= 12080 + computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3); + computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3); +#else + TORCH_CHECK(false, "scaled_gemm with `torch.float8_e4m3fn` scales is only supported for CUDA 12.8 and above"); +#endif // if CUDA_VERSION >= 12080 + } else if (mat1_scale_dtype == kFloat && mat2_scale_dtype == kFloat && use_rowwise) { +#if CUDA_VERSION >= 12090 || (defined(USE_ROCM) && defined(HIPBLASLT_OUTER_VEC)) + computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F); + computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F); +#elif defined(USE_ROCM) && defined(HIPBLASLT_VEC_EXT) + // no-op here for older hipblaslt ext enums, to avoid TORCH_CHECK below +#else + TORCH_CHECK(false, "scaled_gemm with `torch.float` outer vector scaling is only supported for CUDA 12.9 and above"); +#endif // if CUDA_VERSION >= 12090 + } CuBlasLtMatmulPreference preference; auto ltworkspace = CublasLtWorkspace(); diff --git a/aten/src/ATen/cuda/CUDABlas.h b/aten/src/ATen/cuda/CUDABlas.h index 5021917fe0950..b1dac2162dc42 100644 --- a/aten/src/ATen/cuda/CUDABlas.h +++ b/aten/src/ATen/cuda/CUDABlas.h @@ -136,15 +136,6 @@ void int8_gemm( int32_t* result_ptr, int64_t result_ld); -enum class ScalingType : std::uint8_t { - TensorWise, // fp32 scales - RowWise, // fp32 scales - BlockWise1x16, // fp8_e4m3fn scales - BlockWise1x32, // fp8_e8m0fnu scales - BlockWise1x128, // fp32 scales - BlockWise128x128, // fp32 scales -}; - void scaled_gemm( char transa, char transb, @@ -156,20 +147,19 @@ void scaled_gemm( int64_t mat1_ld, ScalarType mat1_dtype, ScalarType mat1_scale_dtype, - ScalingType mat1_scaling_type, const void* mat2_ptr, const void* mat2_scale_ptr, int64_t mat2_ld, ScalarType mat2_dtype, ScalarType mat2_scale_dtype, - ScalingType mat2_scaling_type, const void* bias_ptr, ScalarType bias_dtype, void* result_ptr, const void* result_scale_ptr, int64_t result_ld, ScalarType result_dtype, - bool use_fast_accum); + bool use_fast_accum, + bool use_rowwise); #define CUDABLAS_BGEMM_ARGTYPES(Dtype) CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(Dtype, Dtype) diff --git a/aten/src/ATen/cuda/tunable/GemmCommon.h b/aten/src/ATen/cuda/tunable/GemmCommon.h index 6d19907aba4ad..6f896f1a22bfc 100644 --- a/aten/src/ATen/cuda/tunable/GemmCommon.h +++ b/aten/src/ATen/cuda/tunable/GemmCommon.h @@ -29,8 +29,6 @@ namespace at::cuda::tunable { -using at::cuda::blas::ScalingType; - enum class BlasOp { N = 0, T = 1 @@ -600,8 +598,7 @@ struct ScaledGemmParams : OpParams { // // In TunableOp, we must distinguish in param signature these two cases: with and without a bias vector. return fmt::sprintf("%c%c_%ld_%ld_%ld_ld_%ld_%ld_%ld_rw_%d_bias_%s", - transa, transb, m, n, k, lda, ldb, ldc, - a_scaling_type == ScalingType::RowWise && b_scaling_type == ScalingType::RowWise, + transa, transb, m, n, k, lda, ldb, ldc, use_rowwise, bias_ptr == nullptr ? "None" : at::toString(bias_dtype)); } @@ -676,13 +673,11 @@ struct ScaledGemmParams : OpParams { int64_t lda{}; ScalarType a_dtype{}; ScalarType a_scale_dtype{}; - ScalingType a_scaling_type{}; const void* b{}; const void* b_scale_ptr{}; int64_t ldb{}; ScalarType b_dtype{}; ScalarType b_scale_dtype{}; - ScalingType b_scaling_type{}; const void* bias_ptr{}; ScalarType bias_dtype{}; void* c{}; @@ -691,6 +686,7 @@ struct ScaledGemmParams : OpParams { ScalarType c_dtype{}; void* amax_ptr{}; bool use_fast_accum{}; + bool use_rowwise{}; private: bool duplicate_inputs_{false}; }; diff --git a/aten/src/ATen/cuda/tunable/GemmHipblaslt.h b/aten/src/ATen/cuda/tunable/GemmHipblaslt.h index 809ba51009f0a..32fb7c2774fff 100644 --- a/aten/src/ATen/cuda/tunable/GemmHipblaslt.h +++ b/aten/src/ATen/cuda/tunable/GemmHipblaslt.h @@ -206,43 +206,23 @@ float GetBetaFromParams(const ScaledGemmParams* params) { } template -ScalingType GetAScalingTypeFromParams(const GemmParams* params) { - return ScalingType::TensorWise; +bool GetUseRowwiseFromParams(const GemmParams* params) { + return false; } template -ScalingType GetBScalingTypeFromParams(const GemmParams* params) { - return ScalingType::TensorWise; +bool GetUseRowwiseFromParams(const GemmAndBiasParams* params) { + return false; } template -ScalingType GetAScalingTypeFromParams(const GemmAndBiasParams* params) { - return ScalingType::TensorWise; +bool GetUseRowwiseFromParams(const GemmStridedBatchedParams* params) { + return false; } template -ScalingType GetBScalingTypeFromParams(const GemmAndBiasParams* params) { - return ScalingType::TensorWise; -} - -template -ScalingType GetAScalingTypeFromParams(const GemmStridedBatchedParams* params) { - return ScalingType::TensorWise; -} - -template -ScalingType GetBScalingTypeFromParams(const GemmStridedBatchedParams* params) { - return ScalingType::TensorWise; -} - -template -ScalingType GetAScalingTypeFromParams(const ScaledGemmParams* params) { - return params->a_scaling_type; -} - -template -ScalingType GetBScalingTypeFromParams(const ScaledGemmParams* params) { - return params->b_scaling_type; +bool GetUseRowwiseFromParams(const ScaledGemmParams* params) { + return params->use_rowwise; } template @@ -509,24 +489,23 @@ class HipblasltGemmOp : public Callable { const void* mat2_scale_ptr = GetBScalePointerFromParams(params); const void* result_scale_ptr = GetDScalePointerFromParams(params); if (mat1_scale_ptr && mat2_scale_ptr) { - hipblasLtMatmulDescAttributes_t a_scale_ptr_desc = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER; - hipblasLtMatmulDescAttributes_t b_scale_ptr_desc = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER; - if (GetAScalingTypeFromParams(params) == ScalingType::RowWise) { -#if defined(HIPBLASLT_OUTER_VEC) - matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_MODE, HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F); -#elif defined(HIPBLASLT_VEC_EXT) - a_scale_ptr_desc = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT; +#ifdef HIPBLASLT_VEC_EXT + if (GetUseRowwiseFromParams(params)) { + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT, mat1_scale_ptr); + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT, mat2_scale_ptr); + } + else #endif + { + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER, mat1_scale_ptr); + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER, mat2_scale_ptr); } - if (GetBScalingTypeFromParams(params) == ScalingType::RowWise) { -#if defined(HIPBLASLT_OUTER_VEC) +#ifdef HIPBLASLT_OUTER_VEC + if (GetUseRowwiseFromParams(params)) { + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_MODE, HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F); matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_MODE, HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F); -#elif defined(HIPBLASLT_VEC_EXT) - b_scale_ptr_desc = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT; -#endif } - matmul.setAttribute(a_scale_ptr_desc, mat1_scale_ptr); - matmul.setAttribute(b_scale_ptr_desc, mat2_scale_ptr); +#endif } if (result_scale_ptr) { matmul.setAttribute(HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr); diff --git a/aten/src/ATen/cuda/tunable/TunableGemm.h b/aten/src/ATen/cuda/tunable/TunableGemm.h index d941c230630c4..d7e2835b1b109 100644 --- a/aten/src/ATen/cuda/tunable/TunableGemm.h +++ b/aten/src/ATen/cuda/tunable/TunableGemm.h @@ -96,20 +96,19 @@ class DefaultScaledGemmOp : public Callable> { params->lda, params->a_dtype, params->a_scale_dtype, - params->a_scaling_type, params->b, params->b_scale_ptr, params->ldb, params->b_dtype, params->b_scale_dtype, - params->b_scaling_type, params->bias_ptr, params->bias_dtype, params->c, params->c_scale_ptr, params->ldc, params->c_dtype, - params->use_fast_accum); + params->use_fast_accum, + params->use_rowwise); return OK; } }; diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp index 377be5d40aab8..c46e1cc633119 100644 --- a/aten/src/ATen/native/cuda/Blas.cpp +++ b/aten/src/ATen/native/cuda/Blas.cpp @@ -19,7 +19,6 @@ #include #include #include -#include #ifndef AT_PER_OPERATOR_HEADERS #include @@ -100,7 +99,6 @@ c10::MaybeOwned inline prepare_matrix_for_cublas(const Tensor& tensor, b } } -using at::cuda::blas::ScalingType; /** * @brief Prepares matrices for CUBLAS operation @@ -142,9 +140,7 @@ struct cublasCommonArgs { Tensor& c, const std::optional& scale_a = std::nullopt, const std::optional& scale_b = std::nullopt, - const std::optional& scale_result = std::nullopt, - const std::optional& scaling_choice_a = std::nullopt, - const std::optional& scaling_choice_b = std::nullopt) { + const std::optional& scale_result = std::nullopt) { bool transpose_result = false, transpose_a = false, transpose_b = false; result = prepare_matrix_for_cublas(c, transpose_result); mata = prepare_matrix_for_cublas(transpose_result ? mat2 : mat1, transpose_a, transpose_result); @@ -156,10 +152,8 @@ struct cublasCommonArgs { // as B.T @ A.T, check transpose_result to determine if we flip the scales scale_mata_ptr = transpose_result ? scale_b->data_ptr() : scale_a->data_ptr(); scale_mata_dtype = transpose_result ? scale_b->scalar_type() : scale_a->scalar_type(); - scaling_mata_type = transpose_result ? scaling_choice_b : scaling_choice_a; scale_matb_ptr = transpose_result ? scale_a->data_ptr() : scale_b->data_ptr(); scale_matb_dtype = transpose_result ? scale_a->scalar_type() : scale_b->scalar_type(); - scaling_matb_type = transpose_result ? scaling_choice_a : scaling_choice_b; } if (scale_result) { @@ -205,9 +199,7 @@ struct cublasCommonArgs { void* scale_matb_ptr = nullptr; void* scale_result_ptr = nullptr; std::optional scale_mata_dtype; - std::optional scaling_mata_type; std::optional scale_matb_dtype; - std::optional scaling_matb_type; std::optional scale_result_dtype; }; } // namespace @@ -1083,114 +1075,133 @@ static bool _scaled_mm_is_fnuz() { namespace{ +enum class ScalingType : std::uint8_t { + TensorWise, + RowWise, + BlockWise, + Error +}; /* * Scaling Type Determination: * --------------------------- * Conditions and corresponding Scaling Types: * - * - If scale tensor is `Float8_e8m0fnu` or `Float8_e4m3fn`: + * - If scale tensors are both `Float8_e8m0fnu` or `Float8_e4m3fn`: * - Returns BlockWise (with additional size checks). * - * - Else if scale.numel() == 1: + * - If scale_a.numel() == 1 && scale_b.numel() == 1: * - Returns TensorWise. * - * - Else if scale.dim() == 2 && scale.size(0) == outer_dim && scale.size(1) == 1: + * - Else if scale_a.dim() == 2 && scale_a.size(0) == dim_m && scale_b.size(0) == dim_n: * - Returns RowWise. * - * - Else if scale.dim() == 2 && scale.size(0) == outer_dim && scale.size(1) == inner_dim / 128: - * - Returns BlockWise 1x128. - * - * - Else if scale.dim() == 2 && scale.size(0) == outer_dim / 128 && scale.size(1) == inner_dim / 128: - * - Returns BlockWise 128x128. - * * - Otherwise: * - Returns Error. */ -using at::cuda::blas::ScalingType; - -bool is_tensorwise_scaling(const at::Tensor& t, const at::Tensor& scale) { - return isFloat8Type(t.scalar_type()) && scale.scalar_type() == kFloat && scale.numel() == 1; -} - -bool is_rowwise_scaling(const at::Tensor& t, const at::Tensor& scale) { - return (isFloat8Type(t.scalar_type()) && scale.scalar_type() == kFloat && scale.dim() == 2 - && scale.size(0) == t.size(0) && scale.size(1) == 1 - && scale.is_contiguous()); -} - -// 1x16 blocks for packed nvfp4 data and fp8_e4m3fn scales -bool is_blockwise_1x16_scaling(const at::Tensor& t, const at::Tensor& scale) { - // Multiply t.size(1) by 2 to adjust for fp4x2 packing - // TODO: We might want to enforce some structure on the shapes of the scale - // tensors - return (t.scalar_type() == ScalarType::Float4_e2m1fn_x2 && scale.scalar_type() == at::kFloat8_e4m3fn - && scale.numel() == round_up(t.size(0), 128) * round_up(ceil_div(t.size(1) * 2, 16), 4) - && scale.is_contiguous()); -} - -// 1x16 blocks for microscaled fp8 data and fp8_e8m0fnu scales -bool is_blockwise_1x32_scaling(const at::Tensor& t, const at::Tensor& scale) { - // TODO: We might want to enforce some structure on the shapes of the scale - // tensors - return (isFloat8Type(t.scalar_type()) && scale.scalar_type() == at::kFloat8_e8m0fnu - && scale.numel() == round_up(t.size(0), 128) * round_up(ceil_div(t.size(1), 32), 4) - && scale.is_contiguous()); -} - -bool is_blockwise_1x128_scaling(const at::Tensor& t, const at::Tensor& scale) { - return (isFloat8Type(t.scalar_type()) && scale.scalar_type() == kFloat && scale.dim() == 2 - && scale.size(0) == t.size(0) && scale.size(1) == ceil_div(t.size(1), 128) - && scale.stride(0) == 1 && scale.stride(1) == t.size(0)); -} - -bool is_blockwise_128x128_scaling(const at::Tensor& t, const at::Tensor& scale) { - return (isFloat8Type(t.scalar_type()) && scale.scalar_type() == kFloat && scale.dim() == 2 - && scale.size(0) == ceil_div(t.size(0), 128) && scale.size(1) == ceil_div(t.size(1), 128) - && scale.stride(0) == round_up(ceil_div(t.size(1), 128), 4) && scale.stride(1) == 1); -} +// Validates the scale tensors to scaled_mm +// And returns the type of scaling/which kernel to use +ScalingType get_scaling_type( + const at::Tensor& scale_a, + const at::Tensor& scale_b, + int64_t dim_m, + int64_t dim_k, + int64_t dim_n) { + // Check for BlockWise scaling (FP8_E8M0 and FP8_E4M3 types) + if ((scale_a.scalar_type() == scale_b.scalar_type()) && + ((scale_a.scalar_type() == at::kFloat8_e8m0fnu) || (scale_a.scalar_type() == at::kFloat8_e4m3fn))) { + const bool is_nvfp4 = scale_a.scalar_type() == at::kFloat8_e4m3fn; + + // cuBLAS's mxfp8 gemm: block_size is 1 scale per 32 elements + // cuBLAS's nvfp4 gemm: block_size is 1 scale per 16 unpacked elements. + const auto BLOCK_SIZE_K = is_nvfp4 ? 16 : 32; + + constexpr int64_t BLOCK_SIZE_MN = 128; + + // adjust for fp4x2 packing if necessary + const auto dim_k_unpacked = is_nvfp4 ? dim_k * 2 : dim_k; + + auto ceil_div = [](auto a, auto b) { return (a + b - 1) / b; }; + auto num_k_blocks = ceil_div(dim_k_unpacked, BLOCK_SIZE_K); + auto padded_num_k_blocks = ceil_div(num_k_blocks, 4) * 4; + + // TODO: We might want to enforce some structure on the shapes of the scale + // tensors + + // Check expected sizes for block-wise scaling + auto expected_a_size = + BLOCK_SIZE_MN * ceil_div(dim_m, BLOCK_SIZE_MN) * padded_num_k_blocks; + auto expected_b_size = + BLOCK_SIZE_MN * ceil_div(dim_n, BLOCK_SIZE_MN) * padded_num_k_blocks; + + TORCH_CHECK(scale_a.numel() == expected_a_size, + "For BlockWise scaling: Expected scale_a size to be ", + expected_a_size, " but got ", scale_a.numel()); + TORCH_CHECK(scale_b.numel() == expected_b_size, + "For BlockWise scaling: Expected scale_b size to be ", + expected_b_size, " but got ", scale_b.numel()); + + TORCH_CHECK( + scale_a.is_contiguous() && scale_b.is_contiguous(), + "For BlockWise scaling: Both scale_a and scale_b must be contiguous"); + + return ScalingType::BlockWise; + } + // Both Per-Tensor and Row-wise scaling expect fp32 tensors + TORCH_CHECK( + scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat, + "Both scale_a and scale_b must be float (fp32) tensors."); -bool is_desired_scaling(const at::Tensor& t, const at::Tensor& scale, ScalingType desired_scaling) { - switch (desired_scaling) { - case ScalingType::TensorWise: - return is_tensorwise_scaling(t, scale); - case ScalingType::RowWise: - return is_rowwise_scaling(t, scale); - case ScalingType::BlockWise1x16: - return is_blockwise_1x16_scaling(t, scale); - case ScalingType::BlockWise1x32: - return is_blockwise_1x32_scaling(t, scale); - case ScalingType::BlockWise1x128: - return is_blockwise_1x128_scaling(t, scale); - case ScalingType::BlockWise128x128: - return is_blockwise_128x128_scaling(t, scale); - default: - TORCH_CHECK(false); - return false; + // Check the singluar scale case for per-tensor scaling + if (scale_a.numel() == 1 && scale_b.numel() == 1) { + return ScalingType::TensorWise; } -} -std::pair get_joint_scaling( - std::initializer_list> options, - const at::Tensor& a, const at::Tensor& b, - const at::Tensor& scale_a, const at::Tensor& scale_b) { - for (auto [lhs, rhs] : options) { - if (is_desired_scaling(a, scale_a, lhs) && is_desired_scaling(b.t(), scale_b.t(), rhs)) { - return {lhs, rhs}; - } + // For non-TensorWise scaling, enforce 2D input tensors + TORCH_CHECK( + scale_a.dim() == 2 && scale_b.dim() == 2, + "For non-TensorWise scaling, scale tensors must be 2-dimensional, " + "but got scale_a.dim()=", + scale_a.dim(), + " and scale_b.dim()=", + scale_b.dim()); + + // Check for RowWise scaling + if (scale_a.size(0) == dim_m && scale_a.size(1) == 1 && + scale_b.size(0) == 1 && scale_b.size(1) == dim_n) { +#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || \ + (defined(USE_ROCM) && (defined(HIPBLASLT_VEC_EXT) || defined(HIPBLASLT_OUTER_VEC))) + TORCH_CHECK( + scale_a.is_contiguous() && scale_b.is_contiguous(), + "Both scale_a and scale_b must be contiguous for RowWise scaling."); + return ScalingType::RowWise; +#else + TORCH_CHECK(false, "Per-row scaling is not supported for this platform!"); + return ScalingType::Error; +#endif } + + // If we reach here, the input doesn't match any valid scaling type TORCH_CHECK( - false, - "Invalid scaling configuration.\n" - "- For TensorWise scaling, a and b should be float8, scales should be float and singletons.\n" - "- For RowWise scaling, a and b should be float8, scales should be float, scale_a should be (", a.size(0), ", 1) and scale_b should be (1, ", b.size(1), "), and both should be contiguous.\n" - "- For BlockWise 1x128 scaling, a and b should be float8, scales should be float, scale_a should be (", a.size(0), ", ", ceil_div(a.size(1), 128), ") and scale_b should be (", ceil_div(b.size(0), 128), ", ", b.size(1), "), and both should be outer-dim-major.\n" - "- For BlockWise 128x128 scaling, a and b should be float8, scales should be float, scale_a should be (", ceil_div(a.size(0), 128), ", ", ceil_div(a.size(1), 128), ") and scale_b should be (", ceil_div(b.size(0), 128), ", ", ceil_div(b.size(1), 128), "), and both should be near-inner-dim-major (with 16-byte aligned strides).\n" - "- For Blockwise 1x32 scaling, a and b should be float8, scales should be float8_e8m0fnu, scale_a should have ", round_up(a.size(0), 128) * round_up(ceil_div(a.size(1), 32), 4), " elements and scale_b should have ", round_up(b.size(1), 128) * round_up(ceil_div(b.size(0), 32), 4), " elements, and both should be contiguous.\n" - "- For Blockwise 1x16 scaling, a and b should be float4 (packed 2x), scales should be float8_e4m3fn, scale_a should have ", round_up(a.size(0), 128) * round_up(ceil_div(a.size(1) * 2, 16), 4), " elements and scale_b should have ", round_up(b.size(1), 128) * round_up(ceil_div(b.size(0) * 2, 16), 4), " elements, and both should be contiguous.\n" - "Got a.dtype()=", a.scalar_type(), ", scale_a.dtype()=", scale_a.scalar_type(), ", scale_a.size()=", scale_a.sizes(), ", scale_a.stride()=", scale_a.strides(), ", ", - "b.dtype()=", b.scalar_type(), ", scale_b.dtype()=", scale_b.scalar_type(), ", scale_b.size()=", scale_b.sizes(), " and scale_b.stride()=", scale_b.strides() - ); + false, + "Invalid scaling configuration. For TensorWise scaling, both scales should be scalar. " + "For RowWise scaling, scale_a should be (", + dim_m, + ", 1) and scale_b should be (1, ", + dim_n, + "). " + "Got scale_a.size()=(", + scale_a.size(0), + ", ", + scale_a.size(1), + ") and ", + "scale_b.size()=(", + scale_b.size(0), + ", ", + scale_b.size(1), + ")"); + + return ScalingType::Error; } } // namespace @@ -1232,21 +1243,9 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (", mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")"); - // Check what type of scaling we are doing based on inputs. This list is sorted - // by decreasing priority. We prefer "simpler" schemes as they are supported - // more broadly (more GPU archs, more CUDA versions) and because they are more - // efficient. This tends to matter only for small matmuls (e.g., 1x1x128). - auto [scaling_choice_a, scaling_choice_b] = get_joint_scaling( - { - std::make_pair(ScalingType::TensorWise, ScalingType::TensorWise), - std::make_pair(ScalingType::RowWise, ScalingType::RowWise), - std::make_pair(ScalingType::BlockWise128x128, ScalingType::BlockWise1x128), - std::make_pair(ScalingType::BlockWise1x128, ScalingType::BlockWise128x128), - std::make_pair(ScalingType::BlockWise1x128, ScalingType::BlockWise1x128), - std::make_pair(ScalingType::BlockWise1x32, ScalingType::BlockWise1x32), - std::make_pair(ScalingType::BlockWise1x16, ScalingType::BlockWise1x16) - }, - mat1, mat2, scale_a, scale_b); + // Check what type of scaling we are doing based on inputs + ScalingType scaling_choice = get_scaling_type(scale_a, scale_b, mat1.size(0), mat1.size(1), mat2.size(1)); + TORCH_INTERNAL_ASSERT(scaling_choice != ScalingType::Error, "Scaling type not supported"); TORCH_CHECK(!scale_result || (scale_result->numel() == 1 && scale_result->scalar_type() == kFloat), "scale_result must be a float scalar"); @@ -1317,7 +1316,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, #ifndef USE_ROCM // We are doing row-wise scaling auto dprops = at::cuda::getCurrentDeviceProperties(); - if (scaling_choice_a == ScalingType::RowWise && scaling_choice_b == ScalingType::RowWise + if (scaling_choice == ScalingType::RowWise && (dprops->major < 9 || CUBLAS_VERSION < 120900 || cublasLtGetVersion() < 120900)) { TORCH_CHECK(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling."); at::cuda::detail::f8f8bf16_rowwise( @@ -1331,7 +1330,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, return out; } #else - if (scaling_choice_a == ScalingType::RowWise && scaling_choice_b == ScalingType::RowWise) { + if (scaling_choice == ScalingType::RowWise) { // For ROCm, match behavior of f8f8bf16_rowwise type checking, for unit test purposes. Tensor b = mat2; if (_scaled_mm_is_fnuz()) { @@ -1346,7 +1345,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, } #endif - cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, scale_result, scaling_choice_a, scaling_choice_b); + cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, scale_result); const auto out_dtype_ = args.result->scalar_type(); TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt"); @@ -1423,14 +1422,10 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, params.a_scale_ptr = args.scale_mata_ptr; params.lda = args.lda; params.a_dtype = args.mata->scalar_type(); - params.a_scale_dtype = args.scale_mata_dtype.value(); - params.a_scaling_type = args.scaling_mata_type.value(); params.b = args.matb->data_ptr(); params.b_scale_ptr = args.scale_matb_ptr; params.ldb = args.ldb; params.b_dtype = args.matb->scalar_type(); - params.b_scale_dtype = args.scale_matb_dtype.value(); - params.b_scaling_type = args.scaling_matb_type.value(); params.bias_ptr = bias ? bias->data_ptr(): nullptr; params.bias_dtype = bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_; params.c = args.result->data_ptr(); @@ -1438,6 +1433,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, params.ldc = args.result_ld; params.c_dtype = out_dtype_; params.use_fast_accum = use_fast_accum; + params.use_rowwise = scaling_choice == ScalingType::RowWise; if (transa_ && transb_) { TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::T) } @@ -1471,20 +1467,19 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, args.lda, args.mata->scalar_type(), args.scale_mata_dtype.value(), - args.scaling_mata_type.value(), args.matb->data_ptr(), args.scale_matb_ptr, args.ldb, args.matb->scalar_type(), args.scale_matb_dtype.value(), - args.scaling_matb_type.value(), bias ? bias->data_ptr(): nullptr, bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_, args.result->data_ptr(), args.scale_result_ptr, args.result_ld, out_dtype_, - use_fast_accum); + use_fast_accum, + scaling_choice == ScalingType::RowWise); } return out; diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index 30526c2a84826..31f36681bc3a4 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -785,7 +785,7 @@ def amax_to_scale( if float8_dtype == e4m3_type: res = E4M3_MAX_POS / torch.clamp(amax, min=EPS) elif float8_dtype == e5m2_type: - res = E5M2_MAX_POS / torch.clamp(amax, min=EPS) + res = E4M3_MAX_POS / torch.clamp(amax, min=EPS) else: raise ValueError(f"Unsupported float8_dtype: {float8_dtype}") @@ -806,20 +806,6 @@ def tensor_to_scale(x: torch.Tensor, float8_dtype: torch.dtype, dim=None): return amax_to_scale(amax, float8_dtype, x.dtype) -def tensor_to_scale_block( - x: torch.Tensor, - float8_dtype: torch.dtype, - block_outer: int, - block_inner: int, -) -> tuple[torch.Tensor, torch.Tensor]: - x = x.unflatten(1, (-1, block_inner)).unflatten(0, (-1, block_outer)) - amax = x.abs().amax(dim=[1, 3], keepdim=True).float() - scale = torch.finfo(float8_dtype).max / amax - x = x.mul(scale).to(float8_dtype) - x = x.flatten(2, 3).flatten(0, 1) - scale = scale.flatten(2, 3).flatten(0, 1) - return x, scale - def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor: # naive implementation: dq -> op -> q x_fp32 = x.to(torch.float) / x_scale @@ -828,17 +814,6 @@ def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor: return out_fp32.to(out_dtype) -def mm_float8_emulated_block(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor: - x = x.unflatten(1, (x_scale.shape[1], -1)).unflatten(0, (x_scale.shape[0], -1)) - y = y.unflatten(1, (y_scale.shape[1], -1)).unflatten(0, (y_scale.shape[0], -1)) - x_fp32 = x.to(torch.float) / x_scale[:, None, :, None] - y_fp32 = y.to(torch.float) / y_scale[:, None, :, None] - x_fp32 = x_fp32.flatten(2, 3).flatten(0, 1) - y_fp32 = y_fp32.flatten(2, 3).flatten(0, 1) - out_fp32 = torch.mm(x_fp32, y_fp32) - - return out_fp32.to(out_dtype) - def addmm_float8_unwrapped( a_data: torch.Tensor, a_scale: torch.Tensor, @@ -1262,7 +1237,11 @@ def test_float8_error_messages(self, device) -> None: y_fp8 = y.to(e4m3_type).t() with self.assertRaisesRegex( - RuntimeError, re.escape("Invalid scaling configuration") + RuntimeError, + re.escape( + "For RowWise scaling, scale_a should be (1024, 1) and scale_b " + "should be (1, 2048). Got scale_a.size()=(1, 1) and scale_b.size()=(1, 2)" + ), ): torch._scaled_mm( x_fp8, @@ -1273,7 +1252,11 @@ def test_float8_error_messages(self, device) -> None: ) with self.assertRaisesRegex( - RuntimeError, re.escape("Invalid scaling configuration") + RuntimeError, + re.escape( + " For RowWise scaling, scale_a should be (1024, 1) and scale_b " + "should be (1, 2048). Got scale_a.size()=(1024, 1) and scale_b.size()=(1, 2049)" + ), ): torch._scaled_mm( x_fp8, @@ -1283,18 +1266,22 @@ def test_float8_error_messages(self, device) -> None: out_dtype=torch.bfloat16, ) with self.assertRaisesRegex( - RuntimeError, re.escape("Invalid scaling configuration") + RuntimeError, + re.escape("For non-TensorWise scaling, scale tensors must be 2-dimensional"), ): torch._scaled_mm( x_fp8, y_fp8, scale_a=torch.ones((M), device="cuda"), - scale_b=torch.ones((N, N, 1), device="cuda"), + scale_b=torch.ones((N, N), device="cuda"), out_dtype=torch.bfloat16, ) with self.assertRaisesRegex( - RuntimeError, re.escape("Invalid scaling configuration") + RuntimeError, + re.escape( + "Both scale_a and scale_b must be contiguous for RowWise scaling." + ), ): torch._scaled_mm( x_fp8, @@ -1359,58 +1346,6 @@ def test_scaled_mm_vs_emulated_row_wise(self, base_dtype): torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol) - @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg) - @unittest.skipIf(not SM90OrLater, "cuBLAS blockwise scaling requires sm90+") - @unittest.skipIf( - _get_torch_cuda_version() < (12, 9), - "cuBLAS blockwise scaling added in CUDA 12.9", - ) - @parametrize("output_dtype", [torch.bfloat16, torch.float32]) - @parametrize("lhs_block,rhs_block", [(1, 1), (128, 1), (1, 128)]) - def test_scaled_mm_vs_emulated_block_wise(self, output_dtype, lhs_block, rhs_block): - torch.manual_seed(42) - - x = torch.randn(256, 512, device="cuda", dtype=output_dtype).pow(3) - y = torch.randn(768, 512, device="cuda", dtype=output_dtype).pow(3) - - x_fp8, x_scales = tensor_to_scale_block(x, e4m3_type, lhs_block, 128) - y_fp8, y_scales = tensor_to_scale_block(y, e4m3_type, rhs_block, 128) - - # 1x128 blocks need scales to be outer-dim-major - if lhs_block == 1: - x_scales = x_scales.t().contiguous().t() - if rhs_block == 1: - y_scales = y_scales.t().contiguous().t() - - # Calculate actual F8 mm - out_scaled_mm = mm_float8( - x_fp8, y_fp8.t(), a_scale=x_scales, b_scale=y_scales.t(), output_dtype=output_dtype - ) - - # Calculate emulated F8 mm - out_emulated = mm_float8_emulated_block( - x_fp8, x_scales, y_fp8.t(), y_scales.t(), output_dtype - ) - - cosine_sim = torch.nn.functional.cosine_similarity( - out_scaled_mm.flatten().float(), out_emulated.flatten().float(), dim=0 - ) - self.assertGreaterEqual(float(cosine_sim), 0.999) - - if output_dtype in {torch.bfloat16, torch.float16}: - atol, rtol = 6e-1, 7e-2 - else: - atol, rtol = 7e-1, 2e-3 - - self.assertEqual(out_scaled_mm, out_emulated, atol=atol, rtol=rtol) - - # One last check against the full-precision reference, to ensure we - # didn't mess up the scaling itself and made the test trivial. - cosine_sim = torch.nn.functional.cosine_similarity( - out_scaled_mm.flatten().float(), (x @ y.t()).flatten().float(), dim=0 - ) - self.assertGreaterEqual(float(cosine_sim), 0.999) - @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @parametrize("which_dim_zero", [0, 1, 2]) @parametrize("use_torch_compile", [False, True]) From 06a67a8948dac9d02f22b3e2591a43b60981cdb4 Mon Sep 17 00:00:00 2001 From: Jithun Nair <37884920+jithunnair-amd@users.noreply.github.com> Date: Wed, 16 Jul 2025 15:24:20 +0000 Subject: [PATCH 115/457] Fix sha256 for aotriton ROCm7.0 tarball (#158420) Fixes following issue of building PyTorch with ROCm7.0: ``` -- verifying file... file='/var/lib/jenkins/pytorch/build/aotriton_external-prefix/src/aotriton-0.10b-manylinux_2_28_x86_64-rocm7.0-shared.tar.gz' -- SHA256 hash of /var/lib/jenkins/pytorch/build/aotriton_external-prefix/src/aotriton-0.10b-manylinux_2_28_x86_64-rocm7.0-shared.tar.gz does not match expected value expected: '7e29c325d5bd33ba896ddb106f5d4fc7d715274dca7fe937f724fffa82017838' actual: '1e9b3dddf0c7fc07131c6f0f5266129e83ce2331f459fa2be8c63f4ae91b0f5b' -- Hash mismatch, removing... CMake Error at aotriton_external-prefix/src/aotriton_external-stamp/download-aotriton_external.cmake:163 (message): Each download failed! ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/158420 Approved by: https://github.com/jeffdaily --- cmake/External/aotriton.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/External/aotriton.cmake b/cmake/External/aotriton.cmake index 8004b0f400a8d..8b380d24f6c8c 100644 --- a/cmake/External/aotriton.cmake +++ b/cmake/External/aotriton.cmake @@ -24,7 +24,7 @@ if(NOT __AOTRITON_INCLUDED) set(__AOTRITON_SHA256_LIST "861cd9f7479eec943933c27cb86920247e5b5dd139bc7c1376c81808abb7d7fe" # rocm6.3 "acea7d811a2d3bbe718b6e07fc2a9f739e49eecd60b4b6a36fcb3fe8edf85d78" # rocm6.4 - "7e29c325d5bd33ba896ddb106f5d4fc7d715274dca7fe937f724fffa82017838" # rocm7.0 + "1e9b3dddf0c7fc07131c6f0f5266129e83ce2331f459fa2be8c63f4ae91b0f5b" # rocm7.0 ) set(__AOTRITON_Z "gz") From a23f4471b952d8cd630b860639e0aaa9be957d60 Mon Sep 17 00:00:00 2001 From: tvukovic-amd Date: Wed, 16 Jul 2025 15:31:40 +0000 Subject: [PATCH 116/457] [ROCm][Windows] Fix finding ROCm/HIP version (#156486) This commit fixes Windows build issue related to trying to use rocm-core (rocm-core doesn't exist on HIP SDK) Pull Request resolved: https://github.com/pytorch/pytorch/pull/156486 Approved by: https://github.com/jeffdaily, https://github.com/stellaraccident --- cmake/public/LoadHIP.cmake | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/cmake/public/LoadHIP.cmake b/cmake/public/LoadHIP.cmake index cae0ca62f2361..132f9670ff34f 100644 --- a/cmake/public/LoadHIP.cmake +++ b/cmake/public/LoadHIP.cmake @@ -93,19 +93,16 @@ if(HIP_FOUND) # hip (lower-case) package. Both are probed above and will be in # ROCM_INCLUDE_DIRS if available. find_file(ROCM_VERSION_HEADER_PATH - NAMES rocm-core/rocm_version.h + NAMES rocm-core/rocm_version.h hip/hip_version.h NO_DEFAULT_PATH PATHS ${ROCM_INCLUDE_DIRS} ) - set(ROCM_LIB_NAME "ROCM") - if(NOT ROCM_VERSION_HEADER_PATH) - find_file(ROCM_VERSION_HEADER_PATH - NAMES hip/hip_version.h - NO_DEFAULT_PATH - PATHS ${ROCM_INCLUDE_DIRS} - ) + if(ROCM_VERSION_HEADER_PATH MATCHES "rocm-core/rocm_version.h$") + set(ROCM_LIB_NAME "ROCM") + else() set(ROCM_LIB_NAME "HIP") endif() + if(NOT ROCM_VERSION_HEADER_PATH) message(FATAL_ERROR "Could not find hip/hip_version.h or rocm-core/rocm_version.h in ${ROCM_INCLUDE_DIRS}") endif() From a04a13c44908fe0ace4f76a228d045dbf5c015bc Mon Sep 17 00:00:00 2001 From: Sam Larsen Date: Mon, 14 Jul 2025 11:03:02 -0700 Subject: [PATCH 117/457] [BE][testing] Skip test_triton_interpret internally (#158260) Summary: Subprocesses in fbcode are tricky because of .par files. I'm thinking it's not an important enough test to get it running and skipping is fine. Test Plan: `buck test` Pull Request resolved: https://github.com/pytorch/pytorch/pull/158260 Approved by: https://github.com/eellison --- test/inductor/test_cuda_repro.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py index 7f41613646d4f..b36df5058c27a 100644 --- a/test/inductor/test_cuda_repro.py +++ b/test/inductor/test_cuda_repro.py @@ -1845,6 +1845,7 @@ def fn(x): self.assertEqual(graph.disable_cudagraphs_reason, None) self.assertEqual(graph.device_types, {"cuda"}) + @unittest.skipIf(IS_FBCODE, "Not runnable in fbcode") def test_triton_interpret(self): import subprocess From 4b11428cb5b3d97f3068a2dc4c55cee6ddd41979 Mon Sep 17 00:00:00 2001 From: Sam Larsen Date: Tue, 15 Jul 2025 12:54:51 -0700 Subject: [PATCH 118/457] [BE][testing] Skip test_repeated_masked_load internally (#158355) Summary: Test is failing internally because of the import from functorch.einops. _Maybe_ there's a way to get this dependence in the TARGETS file, but the obvious things didn't work. I'm wondering if this test is that important to have running in OSS and internally anyway? Test Plan: `buck test '@fbcode//mode/opt' fbcode//caffe2/test/inductor:cuda_repro -- --exact 'caffe2/test/inductor:cuda_repro - test_repeated_masked_load (caffe2.test.inductor.test_cuda_repro.CudaReproTests)' --run-disabled` Pull Request resolved: https://github.com/pytorch/pytorch/pull/158355 Approved by: https://github.com/eellison --- test/inductor/test_cuda_repro.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py index b36df5058c27a..6007e3f3171f5 100644 --- a/test/inductor/test_cuda_repro.py +++ b/test/inductor/test_cuda_repro.py @@ -2099,6 +2099,7 @@ def get_input() -> torch.Tensor: self.assertIn("znumel", code) @xfailIfPy312Plus # https://github.com/pytorch/pytorch/issues/142032 + @unittest.skipIf(config.is_fbcode(), "Dependence on functorch.einops") def test_repeated_masked_load(self): target_size = (8, 2) mem_eff_temporal_upsampling_interp_chunks = 2 From a8b973673798ca79dfe616c9080415d09f9e990d Mon Sep 17 00:00:00 2001 From: Sam Larsen Date: Tue, 15 Jul 2025 12:52:45 -0700 Subject: [PATCH 119/457] [BE][testing] disable test_custom_op_square internally (#158367) Summary: test is failing with `ld.lld: error: unable to find library -laoti_custom_ops` Test Plan: `buck test '@fbcode//mode/opt' fbcode//caffe2/test/inductor:test_aot_inductor_custom_ops -- --exact 'caffe2/test/inductor:test_aot_inductor_custom_ops - test_custom_op_square_cuda (caffe2.test.inductor.test_aot_inductor_custom_ops.AOTInductorTestABICompatibleCuda)' --run-disabled` Differential Revision: [D78364617](https://our.internmc.facebook.com/intern/diff/D78364617) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158367 Approved by: https://github.com/desertfire --- test/inductor/test_aot_inductor_custom_ops.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/inductor/test_aot_inductor_custom_ops.py b/test/inductor/test_aot_inductor_custom_ops.py index 31de9ac4c71d0..fcbaeed297a33 100644 --- a/test/inductor/test_aot_inductor_custom_ops.py +++ b/test/inductor/test_aot_inductor_custom_ops.py @@ -416,6 +416,7 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): @skipIfXpu @skipIfRocm + @unittest.skipIf(IS_FBCODE, "unable to find library -laoti_custom_ops") def test_custom_op_square(self) -> None: class Model(torch.nn.Module): def forward(self, x): From 4805a6ead6f1e7f32351056e2602be4e908f69b7 Mon Sep 17 00:00:00 2001 From: "Han, Xu" Date: Wed, 16 Jul 2025 16:19:30 +0000 Subject: [PATCH 120/457] [aot][XPU] switch xpu to use consts cpp build. (#158425) Intel compiler is not support `format_consts_to_asm`, let's use `format_consts_to_cpp`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158425 Approved by: https://github.com/jansel --- torch/_inductor/codecache.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 78c47dcb082f6..8c95847d87f42 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -1807,6 +1807,11 @@ def _compile_consts(consts: bytes, platform: str) -> str: else: raise RuntimeError(f"Unsupported platform: {platform}") + # Intel compiler failed to compile this manually constructed assembly file. + # Switch XPU to use consts cpp build. + if device_type == "xpu": + use_asm_build = False + is_large_consts = len(consts) > 1024 def format_consts_to_asm( @@ -1837,6 +1842,7 @@ def format_consts_to_asm( def format_consts_to_cpp( consts: bytes, align_bytes: int, symbol_prefix: str ) -> tuple[str, str]: + consts_size = len(consts) asan_attr = """#if defined(__clang__) || defined (__GNUC__)\t\n\ #define ATTRIBUTE_NO_SANITIZE_ADDRESS __attribute__((no_sanitize("address")))\t\n\ #else\t\n\ @@ -1846,7 +1852,7 @@ def format_consts_to_cpp( ATTRIBUTE_NO_SANITIZE_ADDRESS\t\n""" const_cpp = asan_attr const_cpp += f"alignas({align_bytes}) extern " - const_cpp += f"const unsigned char {symbol_prefix}_binary_constants_bin_start[] = {{\t\n" + const_cpp += f"const unsigned char {symbol_prefix}_binary_constants_bin_start[{consts_size}] = {{\t\n" count_bytes = 0 for c in consts: const_cpp += f"{c}, " @@ -1873,9 +1879,7 @@ def format_consts_to_cpp( ) consts_s = Path(consts_s) object_build_options = CppTorchDeviceOptions( - # Intel compiler failed to compile this manually constructed assembly file. - # it is ok to use gcc to compile the .S to a .o and linked with Intel compiler . - device_type=device_type if device_type != "xpu" else "cpu", + device_type=device_type, aot_mode=graph.aot_mode, compile_only=True, use_relative_path=use_relative_path, From ff611d971fe5362a71c15109cf020d30e6c4b2b9 Mon Sep 17 00:00:00 2001 From: "Nichols A. Romero" Date: Wed, 16 Jul 2025 17:17:34 +0000 Subject: [PATCH 121/457] [ROCm] check stream graph capture status in memcpy_and_sync inline function (#158165) Check for stream graph capture when using hipMemcpyWithStream. Fixes https://github.com/pytorch/pytorch/issues/155684, https://github.com/pytorch/pytorch/issues/155231 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158165 Approved by: https://github.com/jeffdaily --- c10/cuda/CUDAFunctions.h | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/c10/cuda/CUDAFunctions.h b/c10/cuda/CUDAFunctions.h index 2c7aa99feeb3a..9379e626d2cfd 100644 --- a/c10/cuda/CUDAFunctions.h +++ b/c10/cuda/CUDAFunctions.h @@ -90,8 +90,17 @@ C10_CUDA_API void __inline__ memcpy_and_sync( (*interp)->trace_gpu_stream_synchronization( c10::kCUDA, reinterpret_cast(stream)); } -#if defined(TORCH_HIP_VERSION) && (TORCH_HIP_VERSION >= 301) - C10_CUDA_CHECK(hipMemcpyWithStream(dst, src, nbytes, kind, stream)); +#if USE_ROCM + // As of ROCm 6.4.1, HIP runtime does not raise an error during capture of + // hipMemcpyWithStream which is a synchronous call. Thus, we add a check + // here explicitly. + hipStreamCaptureStatus captureStatus; + C10_CUDA_CHECK(hipStreamGetCaptureInfo(stream, &captureStatus, nullptr)); + if (C10_LIKELY(captureStatus == hipStreamCaptureStatusNone)) { + C10_CUDA_CHECK(hipMemcpyWithStream(dst, src, nbytes, kind, stream)); + } else { + C10_CUDA_CHECK(hipErrorStreamCaptureUnsupported); + } #else C10_CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, kind, stream)); C10_CUDA_CHECK(cudaStreamSynchronize(stream)); From a369350065493109d1abfbb994695777ab11bcf4 Mon Sep 17 00:00:00 2001 From: yuchengliu1 Date: Wed, 16 Jul 2025 17:22:33 +0000 Subject: [PATCH 122/457] enable compiled autograd on CPU windows (#158432) compiled autograd on windows is disabled in PR #144707 because cuda windows cannot compile this code. However these code can be compiled on CPU. This PR enable these code on CPU windows. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158432 Approved by: https://github.com/jansel, https://github.com/xmfan Co-authored-by: Xu Han --- torch/csrc/dynamo/compiled_autograd.h | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/torch/csrc/dynamo/compiled_autograd.h b/torch/csrc/dynamo/compiled_autograd.h index cba8158213c66..87689f34dfae6 100644 --- a/torch/csrc/dynamo/compiled_autograd.h +++ b/torch/csrc/dynamo/compiled_autograd.h @@ -1106,14 +1106,16 @@ struct IValuePacker { // That's what the TypePtr is for: it contains the information to do the // parsing. See torch::jit::toIValue for more information. static at::TypePtr packed_type() { -#ifdef _WIN32 +#if defined(_WIN32) +#if defined(USE_CUDA) || defined(USE_ROCM) // NB: the if-constexpr usage triggers compilation errors on Windows // with certain compiler settings // (see https://github.com/pytorch/pytorch/pull/144707 for examples). // It's not clear what the problem is, so we're going to ignore it for now. TORCH_CHECK_NOT_IMPLEMENTED( - false, "torch.compile not supported on Windows"); -#else + false, "torch.compile not supported on Windows GPU"); +#endif +#endif if constexpr (::std::is_same_v) { return at::TensorType::get(); } else if constexpr (::std::is_same_v) { @@ -1153,7 +1155,6 @@ struct IValuePacker { false, "IValuePacker not implemented for type"); return at::NoneType::get(); } -#endif } }; From 82b1c482929c359d8db6b844e4321db0a6477f0c Mon Sep 17 00:00:00 2001 From: Yidi Wu Date: Tue, 15 Jul 2025 11:18:23 -0700 Subject: [PATCH 123/457] [hop] add supports_higher_order_operators flag to TorchDispatchMode (#158077) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158077 Approved by: https://github.com/zou3519 --- test/dynamo/test_higher_order_ops.py | 2 +- test/test_python_dispatch.py | 33 ++++++++++++++++++++++++++++ torch/_ops.py | 27 ++++++++++++++++------- torch/utils/_python_dispatch.py | 6 +++++ 4 files changed, 59 insertions(+), 9 deletions(-) diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py index e7b6d426247c2..8bd86e55a8a27 100644 --- a/test/dynamo/test_higher_order_ops.py +++ b/test/dynamo/test_higher_order_ops.py @@ -1186,7 +1186,7 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): pred = a.sum() > 0 with self.assertRaisesRegex( NotImplementedError, - "no rule registered for HOP cond and mode .*MyMode", + "no rule registered for HigherOrderOperator cond and mode .*MyMode", ): with MyMode(): res = cond_op(pred, torch.sin, torch.cos, (a,)) diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py index aef4cb0e69171..e0480ba6a6842 100644 --- a/test/test_python_dispatch.py +++ b/test/test_python_dispatch.py @@ -2480,6 +2480,39 @@ def __torch_dispatch__(self, func, types, args, kwargs=None): self.assertEqual(res, t.a) self.assertIs(type(res), torch.Tensor) + def test_custom_dispatch_mode_supports_higher_order_operators(self): + class Mode(TorchDispatchMode): + supports_higher_order_operators = True + + def __torch_dispatch__(self, func, types, args=..., kwargs=None): + if func is torch.ops.higher_order.cond: + return torch.ones(3, 3) + return NotImplemented + + pred = torch.tensor(True) + x = torch.randn(1, 1) + with Mode(): + out = torch.cond(pred, lambda x: x.sin(), lambda x: x.cos(), (x,)) + self.assertEqual(out, torch.ones(3, 3)) + + def test_custom_dispatch_mode_not_supports_higher_order_operators(self): + class Mode(TorchDispatchMode): + supports_higher_order_operators = False + + def __torch_dispatch__(self, func, types, args=..., kwargs=None): + if func is torch.ops.higher_order.cond: + return torch.ones(3, 3) + return NotImplemented + + pred = torch.tensor(True) + x = torch.randn(1, 1) + with self.assertRaisesRegex( + NotImplementedError, + "There was no rule registered for HigherOrderOperator cond and mode", + ): + with Mode(): + torch.cond(pred, lambda x: x.sin(), lambda x: x.cos(), (x,)) + class TestPythonDispatcher(TestCase): def test_basic(self): diff --git a/torch/_ops.py b/torch/_ops.py index eeeb1dfc71130..9995aafb249a5 100644 --- a/torch/_ops.py +++ b/torch/_ops.py @@ -415,10 +415,19 @@ def check_overloaded(arg): # TODO(rzou): we should support torch_dispatch calling convention too. result = handler(mode, *args, **kwargs) else: - raise NotImplementedError( - f"There was no rule registered for HOP {self._name} and mode {curr_mode}. " - f"We recommend filing an issue." - ) + if curr_mode.supports_higher_order_operators: + with _pop_mode_temporarily() as mode: + return curr_mode.__torch_dispatch__(self, [], args, kwargs) + else: + raise NotImplementedError( + f"There was no rule registered for HigherOrderOperator {self._name} and mode {curr_mode}." + f"Hint: set {curr_mode}'s supports_higher_order_operators to True." + f" This causes all higher order operators to pass through {curr_mode}'s __torch_dispatch__," + f" so handle them accordingly by" + f" adding support for HigerOrderOperators (in this case, {self._name}) in" + f" {curr_mode}.__torch_dispatch__ or" + f" returning NotImplemented when not supported." + ) if result is not NotImplemented: return result @@ -457,10 +466,12 @@ def check_overloaded(arg): # All handlers returned NotImplemented raise TypeError( - f"Multiple dispatch failed for {self._name}. There was no registered that " - f"did not return NotImplemented. Use HOP.py_impl to register some. " - f"Tried mode: {curr_mode}) and subclasses: " - f"{[type(a) for a in overloaded_args]}" + f"HigherOrderOperator '{self._name}' is not supported for the given input types. " + f"This typically happens when using custom tensor types or dispatch modes that don't " + f"have implementations for this operation.\n\n" + f"Current mode: {curr_mode}\n" + f"Input types: {[type(a).__name__ for a in overloaded_args]}\n\n" + f"To fix this, can add support for '{self._name}' in {curr_mode}'s __torch_dispatch__\n" ) functionality_key = torch._C._to_functionality_key(dispatch_key) # type: ignore[attr-defined] diff --git a/torch/utils/_python_dispatch.py b/torch/utils/_python_dispatch.py index 3fab41d82bc46..664994e6fe38f 100644 --- a/torch/utils/_python_dispatch.py +++ b/torch/utils/_python_dispatch.py @@ -68,6 +68,12 @@ class TorchDispatchMode: API self-referential (beware of infinite loops, in this case!) """ + # - When False, custom torch dispatch mode will error out explicitly when a hop + # is called under the mode. + # - When True, custom torch dispatch mode's __torch_dispatch__ will be triggered. + # Mode authors can implement how the mode interacts with higher order operators. + supports_higher_order_operators = False + def __init__(self, _dispatch_key=None): if _dispatch_key is not None: assert isinstance(_dispatch_key, torch._C.DispatchKey) From da05b7fb94fa6382c43e165a525a76d8ae62cadd Mon Sep 17 00:00:00 2001 From: Yidi Wu Date: Tue, 15 Jul 2025 11:18:24 -0700 Subject: [PATCH 124/457] [cond] add _FlopCounterMode support for cond (#158067) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158067 Approved by: https://github.com/zou3519 ghstack dependencies: #158077 --- test/dynamo/test_higher_order_ops.py | 97 ++++++++++++++++++++++++++++ torch/utils/flop_counter.py | 75 ++++++++++++++++++++- 2 files changed, 171 insertions(+), 1 deletion(-) diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py index 8bd86e55a8a27..b9c1ff3a61fe9 100644 --- a/test/dynamo/test_higher_order_ops.py +++ b/test/dynamo/test_higher_order_ops.py @@ -7106,6 +7106,103 @@ def test_non_aliasing_util(self): ): _assert_tensors_nonaliasing(a, a) + def test_flop_counter_for_cond(self): + from torch.utils.flop_counter import FlopCounterMode + + class Mod(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4, 4) + + def forward(self, x): + return torch.cond( + torch.tensor(True), + lambda x: self.linear(x), + lambda x: self.linear(self.linear(x)), + (x,), + ) + + mod = Mod() + with FlopCounterMode(mod, display=False) as mode: + mod(torch.randn(4, 4)) + + self.assertEqual( + mode.get_flop_counts(), + { + "Global": {torch.ops.aten.addmm: 256}, + "Mod": {torch.ops.aten.addmm: 256}, + "Mod.linear": {torch.ops.aten.addmm: 256}, + }, + ) + + def test_flop_counter_for_nested_cond(self): + from torch.utils.flop_counter import FlopCounterMode + + class Mod(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(4, 4) + self.linear2 = torch.nn.Linear(4, 4) + + def forward(self, x): + def true_branch(x): + # Nested cond inside true branch + return torch.cond( + torch.tensor(True), + lambda x: self.linear1(x), + lambda x: self.linear2(x), + (x,), + ) + + def false_branch(x): + return self.linear1(self.linear2(x)) + + return torch.cond(torch.tensor(True), true_branch, false_branch, (x,)) + + mod = Mod() + with FlopCounterMode(mod, display=False) as mode: + mod(torch.randn(4, 4)) + + self.assertEqual( + mode.get_flop_counts(), + { + "Global": {torch.ops.aten.addmm: 256}, + "Mod": {torch.ops.aten.addmm: 256}, + "Mod.linear1": {torch.ops.aten.addmm: 128}, + "Mod.linear2": {torch.ops.aten.addmm: 128}, + }, + ) + + def test_flop_counter_for_cond_unbalanced_branches(self): + from torch.utils.flop_counter import FlopCounterMode + + class Mod(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4, 4) + + def forward(self, x): + def true_branch(x): + return self.linear(x) + + def false_branch(x): + return x.clone() + + return torch.cond(torch.tensor(True), true_branch, false_branch, (x,)) + + mod = Mod() + with FlopCounterMode(mod, display=False) as mode: + mod(torch.randn(4, 4)) + + self.assertEqual( + mode.get_flop_counts(), + { + "Global": {torch.ops.aten.addmm: 128}, + "Mod": {torch.ops.aten.addmm: 128}, + "Mod.linear": {torch.ops.aten.addmm: 128}, + }, + ) + xfail_hops_compile = { # aot_eager diff --git a/torch/utils/flop_counter.py b/torch/utils/flop_counter.py index b2053439140bb..348e40eb62546 100644 --- a/torch/utils/flop_counter.py +++ b/torch/utils/flop_counter.py @@ -755,11 +755,81 @@ def _count_flops(self, func_packet, out, args, kwargs): return out - class _FlopCounterMode(TorchDispatchMode): + supports_higher_order_operators = True + def __init__(self, counter: FlopCounterMode): self.counter = counter + def _execute_with_isolated_flop_counting(self, branch_fn, operands): + """Execute a branch function and capture its FLOP counts without + affecting self.counter.flop_counts + + Args: + branch_fn: The branch function to execute + operands: Arguments to pass to the branch function + + Returns: + Tuple of (result, flop_counts) where result is the branch output + and flop_counts is a copy of the FLOP counts after execution + """ + import copy + checkpointed_flop_counts = copy.copy(self.counter.flop_counts) + with self: + result = branch_fn(*operands) + flop_counts = copy.copy(self.counter.flop_counts) + self.counter.flop_counts = checkpointed_flop_counts + return result, flop_counts + + def _handle_higher_order_ops(self, func, types, args, kwargs): + if func not in {torch.ops.higher_order.cond, }: + return NotImplemented + + # The flop counter for cond counts the upper bound of flops. + # For example, if a matmul is executed 2 times in true branch + # but only 1 time in the false branch, the flop counter will + # record the larger number of flops, i.e. 2 times. + if func is torch.ops.higher_order.cond: + + pred, true_branch, false_branch, operands = args + # Step 1: Count flops for true branch and false branch separately + true_out, true_flop_counts = self._execute_with_isolated_flop_counting( + true_branch, operands + ) + if true_out is NotImplemented: + return NotImplemented + + false_out, false_flop_counts = self._execute_with_isolated_flop_counting( + false_branch, operands + ) + if false_out is NotImplemented: + return NotImplemented + + # Step 2: merge flop counts + all_mod_keys = set(true_flop_counts.keys()) | set(false_flop_counts.keys()) + merged_flop_counts = {} + for outer_key in all_mod_keys: + true_func_counts = true_flop_counts[outer_key] + false_func_counts = false_flop_counts[outer_key] + + merged_func_counts = {} + all_func_keys = set(true_func_counts.keys()) | set(false_func_counts.keys()) + + for func_key in all_func_keys: + true_val = true_func_counts.get(func_key, 0) + false_val = false_func_counts.get(func_key, 0) + merged_func_counts[func_key] = max(true_val, false_val) + + merged_flop_counts[outer_key] = merged_func_counts + + # Step 3: update the counter with merged counts + for outer_key, inner_dict in merged_flop_counts.items(): + self.counter.flop_counts[outer_key].update(inner_dict) + + # It doesn't matter which one we return since true_fn and false_fn return + # output with the same structure. + return true_out + def __torch_dispatch__(self, func, types, args=(), kwargs=None): kwargs = kwargs if kwargs else {} @@ -781,6 +851,9 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): return NotImplemented + if isinstance(func, torch._ops.HigherOrderOperator): + return self._handle_higher_order_ops(func, types, args, kwargs) + # If we don't have func in flop_registry, see if it can decompose if func not in self.counter.flop_registry and func is not torch.ops.prim.device.default: with self: From a26bf3892778ca7cc457c772a1f5194c11b6f33c Mon Sep 17 00:00:00 2001 From: Denghui Dong Date: Wed, 16 Jul 2025 18:00:05 +0000 Subject: [PATCH 125/457] Don't need to handle PyTrace_EXCEPTION in pyProfileFn (#154392) According to the [document](https://python.readthedocs.io/fr/stable/c-api/init.html#c.PyTrace_EXCEPTION) and [comment](https://github.com/python/cpython/blob/3.9/Modules/_lsprof.c#L407), we don't need to handle PyTrace_EXCEPTION in pyProfileFn. Pull Request resolved: https://github.com/pytorch/pytorch/pull/154392 Approved by: https://github.com/sraikund16, https://github.com/cyyever --- torch/csrc/autograd/profiler_python.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/torch/csrc/autograd/profiler_python.cpp b/torch/csrc/autograd/profiler_python.cpp index a02e5dda5d992..69e8831936b0c 100644 --- a/torch/csrc/autograd/profiler_python.cpp +++ b/torch/csrc/autograd/profiler_python.cpp @@ -1252,7 +1252,6 @@ int PythonTracer::pyProfileFn( local_results.active_tracer_->recordCCall(local_results, frame, arg); break; - case PyTrace_EXCEPTION: case PyTrace_RETURN: local_results.exit_times_.emplace_back(c10::getApproximateTime()); break; From bc9091a524a1ebe4de16af4dd8f442db7d1cb138 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Wed, 16 Jul 2025 18:30:53 +0000 Subject: [PATCH 126/457] Fix indexing with multi-dimensional boolean mask (#158369) Fixes #71673 This fixes a bug in PyTorch indexing, that shows up when mixing multi-dimensional boolean masks with other forms of indexing. Examples: ```python >>> import torch >>> x = torch.ones([2, 2, 3]) >>> m = torch.tensor(((True, False), (False, False))) # (2x2 boolean mask) >>> x[m].shape # this works fine (the boolean mask acts on the 2x2 subspace selecting one row) torch.Size([1, 3]) >>> x[m, 0] # this should produce a tensor of shape (1,) Traceback (most recent call last): File "", line 1, in IndexError: The shape of the mask [2, 2] at index 1 does not match the shape of the indexed tensor [2, 3] at index 1 >>> x[m, ::2] # this should produce a tensor of shape (1, 2) Traceback (most recent call last): File "", line 1, in IndexError: The shape of the mask [2, 2] at index 1 does not match the shape of the indexed tensor [2, 1, 3] at index 1 >>> x[m, None] # this should produce a tensor of shape (1, 1, 3) Traceback (most recent call last): File "", line 1, in IndexError: The shape of the mask [2, 2] at index 1 does not match the shape of the indexed tensor [2, 1, 2, 3] at index 1 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/158369 Approved by: https://github.com/ngimel --- aten/src/ATen/TensorIndexing.h | 35 ++++++++++++++++++--- test/test_indexing.py | 57 ++++++++++++++++++++++++++++++++++ 2 files changed, 87 insertions(+), 5 deletions(-) diff --git a/aten/src/ATen/TensorIndexing.h b/aten/src/ATen/TensorIndexing.h index d9d8554abc795..3648862c12241 100644 --- a/aten/src/ATen/TensorIndexing.h +++ b/aten/src/ATen/TensorIndexing.h @@ -315,10 +315,17 @@ inline void recordTensorIndex( const Tensor& tensor, std::vector& outIndices, int64_t* dim_ptr) { - // TODO: check scalarType - outIndices.resize(*dim_ptr + 1); - outIndices[*dim_ptr] = tensor; - (*dim_ptr)++; + if (outIndices.empty()) { + outIndices.resize(*dim_ptr + 1); + outIndices[*dim_ptr] = tensor; + } else { + outIndices.push_back(tensor); + } + if (tensor.scalar_type() == kByte || tensor.scalar_type() == kBool) { + *dim_ptr += tensor.dim(); + } else { + *dim_ptr += 1; + } } inline c10::List<::std::optional> typeConvertIndices( @@ -458,13 +465,23 @@ inline Tensor handleDimInMultiDimIndexing( original_tensor_device, prev_dim_result_sizes); (*dim_ptr)++; + if (!outIndices.empty()) { + outIndices.resize(outIndices.size() + 1); + } return result; } else if (index.is_ellipsis()) { - (*dim_ptr) += original_tensor.dim() - (*specified_dims_ptr); + auto ellipsis_ndims = original_tensor.dim() - *specified_dims_ptr; + (*dim_ptr) += ellipsis_ndims; + if (!outIndices.empty()) { + outIndices.resize(outIndices.size() + ellipsis_ndims); + } return prev_dim_result; } else if (index.is_none()) { Tensor result = prev_dim_result.unsqueeze(*dim_ptr); (*dim_ptr)++; + if (!outIndices.empty()) { + outIndices.resize(outIndices.size() + 1); + } return result; } else if (index.is_boolean()) { Tensor result = prev_dim_result.unsqueeze(*dim_ptr); @@ -560,6 +577,10 @@ inline Tensor applySlicing( inline Tensor dispatch_index( const Tensor& self, std::vector&& indices) { + // Remove trailing null elements from indices + while (!indices.empty() && !indices.back().defined()) { + indices.pop_back(); + } return self.index(impl::typeConvertIndices(self, std::move(indices))); } @@ -567,6 +588,10 @@ inline Tensor dispatch_index_put_( Tensor& self, std::vector&& indices, const Tensor& value) { + // Remove trailing null elements from indices + while (!indices.empty() && !indices.back().defined()) { + indices.pop_back(); + } return self.index_put_( impl::typeConvertIndices(self, std::move(indices)), value); } diff --git a/test/test_indexing.py b/test/test_indexing.py index fa7de92b98290..a57f658025a3b 100644 --- a/test/test_indexing.py +++ b/test/test_indexing.py @@ -908,6 +908,63 @@ def test_multiple_bool_indices(self, device): mask2 = torch.tensor([1, 1, 1], dtype=torch.bool, device=device) self.assertEqual(v[mask1, :, mask2].shape, (3, 7)) + def test_multi_dimensional_bool_mask(self, device): + x = torch.randn(2, 2, 3, device=device) + b = ((True, False), (False, False)) + m = torch.tensor(b, dtype=torch.bool, device=device) + z = torch.tensor(0) + t = torch.tensor(True) + f = torch.tensor(False) + + # Using boolean sequence + self.assertEqual(x[b,].shape, (1, 3)) + self.assertEqual(x[b, ::2].shape, (1, 2)) + self.assertEqual(x[b, None].shape, (1, 1, 3)) + self.assertEqual(x[b, 0].shape, (1,)) + self.assertEqual(x[b, z].shape, (1,)) + self.assertEqual(x[b, True].shape, (1, 3)) + self.assertEqual(x[b, True, True, True, True].shape, (1, 3)) + self.assertEqual(x[b, False].shape, (0, 3)) + self.assertEqual(x[b, True, True, False, True].shape, (0, 3)) + self.assertEqual(x[b, t].shape, (1, 3)) + self.assertEqual(x[b, f].shape, (0, 3)) + + # Using boolean tensor + self.assertEqual(x[m].shape, (1, 3)) + self.assertEqual(x[m, ::2].shape, (1, 2)) + self.assertEqual(x[m, None].shape, (1, 1, 3)) + self.assertEqual(x[m, 0].shape, (1,)) + self.assertEqual(x[m, z].shape, (1,)) + self.assertEqual(x[m, True].shape, (1, 3)) + self.assertEqual(x[m, True, True, True, True].shape, (1, 3)) + self.assertEqual(x[m, False].shape, (0, 3)) + self.assertEqual(x[m, True, True, False, True].shape, (0, 3)) + self.assertEqual(x[m, t].shape, (1, 3)) + self.assertEqual(x[m, f].shape, (0, 3)) + + # Boolean mask in the middle of indices array + x = torch.randn(3, 2, 2, 5, device=device) + self.assertEqual(x[:, m, :].shape, (3, 1, 5)) + self.assertEqual(x[0, m, ::2].shape, (1, 3)) + self.assertEqual(x[..., m, ::2].shape, (3, 1, 3)) + self.assertEqual(x[None, ..., m, ::2].shape, (1, 3, 1, 3)) + + def test_bool_mask_assignment(self, device): + v = torch.tensor([[1, 2], [3, 4]], device=device) + mask = torch.tensor([1, 0], dtype=torch.bool, device=device) + v[mask, :] = 0 + self.assertEqual(v, torch.tensor([[0, 0], [3, 4]], device=device)) + + v = torch.tensor([[1, 2], [3, 4]], device=device) + v[:, mask] = 0 + self.assertEqual(v, torch.tensor([[0, 2], [0, 4]], device=device)) + + def test_multi_dimensional_bool_mask_assignment(self, device): + v = torch.tensor([[[[1], [2]], [[3], [4]]]], device=device) + mask = torch.tensor([[1, 0], [0, 1]], dtype=torch.bool, device=device) + v[:, mask, :] = 0 + self.assertEqual(v, torch.tensor([[[[0], [2]], [[3], [0]]]], device=device)) + def test_byte_mask(self, device): v = torch.randn(5, 7, 3, device=device) mask = torch.ByteTensor([1, 0, 1, 1, 0]).to(device) From 4d055982e38f59fdb2a4c9d8855e58548bc42c12 Mon Sep 17 00:00:00 2001 From: Songhao Jia Date: Wed, 16 Jul 2025 18:46:09 +0000 Subject: [PATCH 127/457] recovering node source from dict (#158373) Summary: this diff recovers NodeSource object from its dict representation, which is crucial for NodeSource serde. Test Plan: ci Rollback Plan: Differential Revision: D78363882 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158373 Approved by: https://github.com/yushangdi --- test/fx/test_fx_traceback.py | 28 ++++++++++++++++++----- torch/fx/traceback.py | 43 ++++++++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 6 deletions(-) diff --git a/test/fx/test_fx_traceback.py b/test/fx/test_fx_traceback.py index f02bc5a2e1592..05369d17078ba 100644 --- a/test/fx/test_fx_traceback.py +++ b/test/fx/test_fx_traceback.py @@ -32,6 +32,8 @@ def test_node_source(self): dummy_source_dict, ) + self.assertEqual(node_source, NodeSource._from_dict(node_source.to_dict())) + # Dummy node node = torch.fx.Node( graph=torch.fx.Graph(), @@ -179,14 +181,28 @@ def forward(self, x): if node_name_1 in same_ancestor_nodes else None, }: - self.assertTrue( - node_name_to_from_node[node_name_1] - == node_name_to_from_node[node_name_2] + self.assertEqual( + node_name_to_from_node[node_name_1], + node_name_to_from_node[node_name_2], + ) + self.assertEqual( + [ + NodeSource._from_dict(ns.to_dict()) + for ns in node_name_to_from_node[node_name_1] + ], + node_name_to_from_node[node_name_2], ) else: - self.assertTrue( - node_name_to_from_node[node_name_1] - != node_name_to_from_node[node_name_2] + self.assertNotEqual( + node_name_to_from_node[node_name_1], + node_name_to_from_node[node_name_2], + ) + self.assertNotEqual( + [ + NodeSource._from_dict(ns.to_dict()) + for ns in node_name_to_from_node[node_name_1] + ], + node_name_to_from_node[node_name_2], ) gm = ep.module() diff --git a/torch/fx/traceback.py b/torch/fx/traceback.py index 836b41d661859..bcf759c3db4c5 100644 --- a/torch/fx/traceback.py +++ b/torch/fx/traceback.py @@ -152,6 +152,49 @@ def _make_hashable(obj): return hash(_make_hashable(self.to_dict())) + @classmethod + def _from_dict(cls, d: Optional[dict]) -> Optional["NodeSource"]: + """ + Recursively deserialize from_node metadata from dictionary data. + It is used to deserialize the from_node field from serialized metadata. + Please use contructor NodeSource(node, ...) to create a NodeSource object. + """ + if d is None: + return None + + assert isinstance(d, dict), f"Expected a dict, got {type(d)}" + + # Create a NodeSource object directly without going through the constructor + # to avoid issues with graph ID and node creation + node_source = NodeSource.__new__(NodeSource) + + # Set the basic attributes + node_source.pass_name = d.get("pass_name", "") + + # Parse action string back to NodeSourceAction enum list + action_str = d.get("action", "") + actions = [] + if action_str: + for action_name in action_str.split("+"): + if action_name.upper() == "CREATE": + actions.append(NodeSourceAction.CREATE) + elif action_name.upper() == "REPLACE": + actions.append(NodeSourceAction.REPLACE) + node_source.action = actions + + # Create the NodeInfo object directly + if "name" in d and "target" in d and "graph_id" in d: + node_info = NodeSource.NodeInfo( + d.get("name", ""), d.get("target", ""), d.get("graph_id", -1) + ) + node_source.node_info = node_info + else: + node_source.node_info = None + + # Recursively deserialize nested from_node + node_source.from_node = [cls._from_dict(fn) for fn in d.get("from_node", [])] + return node_source + @compatibility(is_backward_compatible=False) @contextmanager From b40f48d19186fcd9543f8d2217017eff8b723a9f Mon Sep 17 00:00:00 2001 From: Jane Xu Date: Wed, 16 Jul 2025 07:39:25 -0700 Subject: [PATCH 128/457] Move the rest of c10/macros/Export.h (#158358) Differential Revision: [D78356975](https://our.internmc.facebook.com/intern/diff/D78356975/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158358 Approved by: https://github.com/swolchok --- c10/macros/Export.h | 77 -------------------------------- torch/headeronly/macros/Export.h | 76 +++++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+), 77 deletions(-) diff --git a/c10/macros/Export.h b/c10/macros/Export.h index 3d91266102613..1b8a6811c53f5 100644 --- a/c10/macros/Export.h +++ b/c10/macros/Export.h @@ -1,78 +1 @@ -#ifndef C10_MACROS_EXPORT_H_ -#define C10_MACROS_EXPORT_H_ - -#ifndef C10_USING_CUSTOM_GENERATED_MACROS -#include -#endif // C10_USING_CUSTOM_GENERATED_MACROS - #include - -// This one is being used by libtorch.so -#ifdef CAFFE2_BUILD_MAIN_LIB -#define TORCH_API C10_EXPORT -#else -#define TORCH_API C10_IMPORT -#endif - -// You may be wondering: Whose brilliant idea was it to split torch_cuda into -// two pieces with confusing names? -// Once upon a time, there _was_ only TORCH_CUDA_API. All was happy until we -// tried to compile PyTorch for CUDA 11.1, which ran into relocation marker -// issues when linking big binaries. -// (https://github.com/pytorch/pytorch/issues/39968) We had two choices: -// (1) Stop supporting so many GPU architectures -// (2) Do something else -// We chose #2 and decided to split the behemoth that was torch_cuda into two -// smaller libraries, one with most of the core kernel functions (torch_cuda_cu) -// and the other that had..well..everything else (torch_cuda_cpp). The idea was -// this: instead of linking our static libraries (like the hefty -// libcudnn_static.a) with another huge library, torch_cuda, and run into pesky -// relocation marker issues, we could link our static libraries to a smaller -// part of torch_cuda (torch_cuda_cpp) and avoid the issues. - -// libtorch_cuda_cu.so -#ifdef TORCH_CUDA_CU_BUILD_MAIN_LIB -#define TORCH_CUDA_CU_API C10_EXPORT -#elif defined(BUILD_SPLIT_CUDA) -#define TORCH_CUDA_CU_API C10_IMPORT -#endif - -// libtorch_cuda_cpp.so -#ifdef TORCH_CUDA_CPP_BUILD_MAIN_LIB -#define TORCH_CUDA_CPP_API C10_EXPORT -#elif defined(BUILD_SPLIT_CUDA) -#define TORCH_CUDA_CPP_API C10_IMPORT -#endif - -// libtorch_cuda.so (where torch_cuda_cu and torch_cuda_cpp are a part of the -// same api) -#ifdef TORCH_CUDA_BUILD_MAIN_LIB -#define TORCH_CUDA_CPP_API C10_EXPORT -#define TORCH_CUDA_CU_API C10_EXPORT -#elif !defined(BUILD_SPLIT_CUDA) -#define TORCH_CUDA_CPP_API C10_IMPORT -#define TORCH_CUDA_CU_API C10_IMPORT -#endif - -#if defined(TORCH_HIP_BUILD_MAIN_LIB) -#define TORCH_HIP_CPP_API C10_EXPORT -#define TORCH_HIP_API C10_EXPORT -#else -#define TORCH_HIP_CPP_API C10_IMPORT -#define TORCH_HIP_API C10_IMPORT -#endif - -#if defined(TORCH_XPU_BUILD_MAIN_LIB) -#define TORCH_XPU_API C10_EXPORT -#else -#define TORCH_XPU_API C10_IMPORT -#endif - -// Enums only need to be exported on windows for non-CUDA files -#if defined(_WIN32) && defined(__CUDACC__) -#define C10_API_ENUM C10_API -#else -#define C10_API_ENUM -#endif - -#endif // C10_MACROS_EXPORT_H_ diff --git a/torch/headeronly/macros/Export.h b/torch/headeronly/macros/Export.h index 183aeab563445..8c4e207d0dada 100644 --- a/torch/headeronly/macros/Export.h +++ b/torch/headeronly/macros/Export.h @@ -1,5 +1,12 @@ #pragma once +#ifndef C10_MACROS_EXPORT_H_ +#define C10_MACROS_EXPORT_H_ + +#ifndef C10_USING_CUSTOM_GENERATED_MACROS +#include +#endif // C10_USING_CUSTOM_GENERATED_MACROS + /* Header file to define the common scaffolding for exported symbols. * * Export is by itself a quite tricky situation to deal with, and if you are @@ -85,3 +92,72 @@ #else #define C10_API C10_IMPORT #endif + +// This one is being used by libtorch.so +#ifdef CAFFE2_BUILD_MAIN_LIB +#define TORCH_API C10_EXPORT +#else +#define TORCH_API C10_IMPORT +#endif + +// You may be wondering: Whose brilliant idea was it to split torch_cuda into +// two pieces with confusing names? +// Once upon a time, there _was_ only TORCH_CUDA_API. All was happy until we +// tried to compile PyTorch for CUDA 11.1, which ran into relocation marker +// issues when linking big binaries. +// (https://github.com/pytorch/pytorch/issues/39968) We had two choices: +// (1) Stop supporting so many GPU architectures +// (2) Do something else +// We chose #2 and decided to split the behemoth that was torch_cuda into two +// smaller libraries, one with most of the core kernel functions (torch_cuda_cu) +// and the other that had..well..everything else (torch_cuda_cpp). The idea was +// this: instead of linking our static libraries (like the hefty +// libcudnn_static.a) with another huge library, torch_cuda, and run into pesky +// relocation marker issues, we could link our static libraries to a smaller +// part of torch_cuda (torch_cuda_cpp) and avoid the issues. + +// libtorch_cuda_cu.so +#ifdef TORCH_CUDA_CU_BUILD_MAIN_LIB +#define TORCH_CUDA_CU_API C10_EXPORT +#elif defined(BUILD_SPLIT_CUDA) +#define TORCH_CUDA_CU_API C10_IMPORT +#endif + +// libtorch_cuda_cpp.so +#ifdef TORCH_CUDA_CPP_BUILD_MAIN_LIB +#define TORCH_CUDA_CPP_API C10_EXPORT +#elif defined(BUILD_SPLIT_CUDA) +#define TORCH_CUDA_CPP_API C10_IMPORT +#endif + +// libtorch_cuda.so (where torch_cuda_cu and torch_cuda_cpp are a part of the +// same api) +#ifdef TORCH_CUDA_BUILD_MAIN_LIB +#define TORCH_CUDA_CPP_API C10_EXPORT +#define TORCH_CUDA_CU_API C10_EXPORT +#elif !defined(BUILD_SPLIT_CUDA) +#define TORCH_CUDA_CPP_API C10_IMPORT +#define TORCH_CUDA_CU_API C10_IMPORT +#endif + +#if defined(TORCH_HIP_BUILD_MAIN_LIB) +#define TORCH_HIP_CPP_API C10_EXPORT +#define TORCH_HIP_API C10_EXPORT +#else +#define TORCH_HIP_CPP_API C10_IMPORT +#define TORCH_HIP_API C10_IMPORT +#endif + +#if defined(TORCH_XPU_BUILD_MAIN_LIB) +#define TORCH_XPU_API C10_EXPORT +#else +#define TORCH_XPU_API C10_IMPORT +#endif + +// Enums only need to be exported on windows for non-CUDA files +#if defined(_WIN32) && defined(__CUDACC__) +#define C10_API_ENUM C10_API +#else +#define C10_API_ENUM +#endif +#endif // C10_MACROS_EXPORT_H_ From 2b0f9b1f6172a0d5817c7ac7406200897311da5f Mon Sep 17 00:00:00 2001 From: Jane Xu Date: Wed, 16 Jul 2025 07:39:27 -0700 Subject: [PATCH 129/457] Move c10/macros/Macros.h to headeronly (#158365) ^ Differential Revision: [D78361893](https://our.internmc.facebook.com/intern/diff/D78361893/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158365 Approved by: https://github.com/swolchok ghstack dependencies: #158358 --- .lintrunner.toml | 4 +- c10/macros/Macros.h | 549 +----------------------------- torch/headeronly/macros/Macros.h | 548 +++++++++++++++++++++++++++++ torch/headeronly/macros/build.bzl | 1 + 4 files changed, 552 insertions(+), 550 deletions(-) create mode 100644 torch/headeronly/macros/Macros.h diff --git a/.lintrunner.toml b/.lintrunner.toml index 7e9b7ebd5d2c1..4da6616e08cff 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -500,7 +500,7 @@ include_patterns = [ '**/*.h', ] exclude_patterns = [ - 'c10/macros/Macros.h', + 'torch/headeronly/macros/Macros.h', ] command = [ 'python3', @@ -523,7 +523,7 @@ include_patterns = [ '**/*.h', ] exclude_patterns = [ - 'c10/macros/Macros.h', + 'torch/headeronly/macros/Macros.h', ] command = [ 'python3', diff --git a/c10/macros/Macros.h b/c10/macros/Macros.h index 55a79ee67430c..87ebc4f422c4c 100644 --- a/c10/macros/Macros.h +++ b/c10/macros/Macros.h @@ -1,548 +1 @@ -#ifndef C10_MACROS_MACROS_H_ -#define C10_MACROS_MACROS_H_ -#include - -/* Main entry for c10/macros. - * - * In your code, include c10/macros/Macros.h directly, instead of individual - * files in this folder. - */ - -// For build systems that do not directly depend on CMake and directly build -// from the source directory (such as Buck), one may not have a cmake_macros.h -// file at all. In this case, the build system is responsible for providing -// correct macro definitions corresponding to the cmake_macros.h.in file. -// -// In such scenarios, one should define the macro -// C10_USING_CUSTOM_GENERATED_MACROS -// to inform this header that it does not need to include the cmake_macros.h -// file. - -#ifndef C10_USING_CUSTOM_GENERATED_MACROS -#include -#endif // C10_USING_CUSTOM_GENERATED_MACROS - -#include - -#if defined(__clang__) -#define __ubsan_ignore_float_divide_by_zero__ \ - __attribute__((no_sanitize("float-divide-by-zero"))) -#define __ubsan_ignore_undefined__ __attribute__((no_sanitize("undefined"))) -#define __ubsan_ignore_signed_int_overflow__ \ - __attribute__((no_sanitize("signed-integer-overflow"))) -#define __ubsan_ignore_pointer_overflow__ \ - __attribute__((no_sanitize("pointer-overflow"))) -#define __ubsan_ignore_function__ __attribute__((no_sanitize("function"))) -#define __ubsan_ignore_float_cast_overflow__ \ - __attribute__((no_sanitize("float-cast-overflow"))) -#else -#define __ubsan_ignore_float_divide_by_zero__ -#define __ubsan_ignore_undefined__ -#define __ubsan_ignore_signed_int_overflow__ -#define __ubsan_ignore_pointer_overflow__ -#define __ubsan_ignore_function__ -#define __ubsan_ignore_float_cast_overflow__ -#endif - -// Detect address sanitizer as some stuff doesn't work with it -#undef C10_ASAN_ENABLED - -// for clang -#if defined(__has_feature) -#if ((__has_feature(address_sanitizer))) -#define C10_ASAN_ENABLED 1 -#endif -#endif - -// for gcc -#if defined(__SANITIZE_ADDRESS__) -#if __SANITIZE_ADDRESS__ -#if !defined(C10_ASAN_ENABLED) -#define C10_ASAN_ENABLED 1 -#endif -#endif -#endif - -#if !defined(C10_ASAN_ENABLED) -#define C10_ASAN_ENABLED 0 -#endif - -// Detect undefined-behavior sanitizer (UBSAN) -#undef C10_UBSAN_ENABLED - -// for clang or gcc >= 14 -// NB: gcc 14 adds support for Clang's __has_feature -// https://gcc.gnu.org/gcc-14/changes.html -// gcc < 14 doesn't have a macro for UBSAN -// (e.g. __SANITIZE_UNDEFINED__ does not exist in gcc) -// https://github.com/google/sanitizers/issues/765 -#if defined(__has_feature) -#if ((__has_feature(undefined_behavior_sanitizer))) -#define C10_UBSAN_ENABLED 1 -#endif -#endif - -#if !defined(C10_UBSAN_ENABLED) -#define C10_UBSAN_ENABLED 0 -#endif - -// Disable the copy and assignment operator for a class. Note that this will -// disable the usage of the class in std containers. -#define C10_DISABLE_COPY_AND_ASSIGN(classname) \ - classname(const classname&) = delete; \ - classname& operator=(const classname&) = delete - -#define C10_CONCATENATE_IMPL(s1, s2) s1##s2 -#define C10_CONCATENATE(s1, s2) C10_CONCATENATE_IMPL(s1, s2) - -#define C10_MACRO_EXPAND(args) args - -#define C10_STRINGIZE_IMPL(x) #x -#define C10_STRINGIZE(x) C10_STRINGIZE_IMPL(x) - -/** - * C10_ANONYMOUS_VARIABLE(str) introduces a new identifier which starts with - * str and ends with a unique number. - */ -#ifdef __COUNTER__ -#define C10_UID __COUNTER__ -#define C10_ANONYMOUS_VARIABLE(str) C10_CONCATENATE(str, __COUNTER__) -#else -#define C10_UID __LINE__ -#define C10_ANONYMOUS_VARIABLE(str) C10_CONCATENATE(str, __LINE__) -#endif - -#ifdef __has_cpp_attribute -#define C10_HAS_CPP_ATTRIBUTE(x) __has_cpp_attribute(x) -#else -#define C10_HAS_CPP_ATTRIBUTE(x) (0) -#endif - -#ifndef FBCODE_CAFFE2 -/// DEPRECATED: Warn if a type or return value is discarded. -#define C10_NODISCARD [[nodiscard]] - -/// DEPRECATED: Suppress an unused variable. -#define C10_UNUSED [[maybe_unused]] -#endif - -#if !defined(__has_attribute) -#define __has_attribute(x) 0 -#endif - -// Direct port of LLVM_ATTRIBUTE_USED. -#if __has_attribute(used) -#define C10_USED __attribute__((__used__)) -#else -#define C10_USED -#endif - -#define C10_RESTRICT __restrict - -// Simply define the namespace, in case a dependent library want to refer to -// the c10 namespace but not any nontrivial files. -namespace c10 {} -namespace c10::cuda {} -namespace c10::hip {} -namespace c10::xpu {} - -// Since C10 is the core library for caffe2 (and aten), we will simply reroute -// all abstractions defined in c10 to be available in caffe2 as well. -// This is only for backwards compatibility. Please use the symbols from the -// c10 namespace where possible. -namespace caffe2 { -using namespace c10; -} -namespace at { -using namespace c10; -} -namespace at::cuda { -using namespace c10::cuda; -} // namespace at::cuda - -// WARNING!!! THIS IS A GIANT HACK!!! -// This line means you cannot simultaneously include c10/hip -// and c10/cuda and then use them from the at::cuda namespace. -// This is true in practice, because HIPIFY works inplace on -// files in ATen/cuda, so it assumes that c10::hip is available -// from at::cuda. This namespace makes that happen. When -// HIPIFY is no longer out-of-place, we can switch the cuda -// here to hip and everyone is happy. -namespace at::cuda { -using namespace c10::hip; -} // namespace at::cuda - -namespace at::xpu { -using namespace c10::xpu; -} // namespace at::xpu - -// C10_LIKELY/C10_UNLIKELY -// -// These macros provide parentheses, so you can use these macros as: -// -// if C10_LIKELY(some_expr) { -// ... -// } -// -// NB: static_cast to boolean is mandatory in C++, because __builtin_expect -// takes a long argument, which means you may trigger the wrong conversion -// without it. -// -#if defined(__GNUC__) || defined(__ICL) || defined(__clang__) -#define C10_LIKELY(expr) (__builtin_expect(static_cast(expr), 1)) -#define C10_UNLIKELY(expr) (__builtin_expect(static_cast(expr), 0)) -#else -#define C10_LIKELY(expr) (expr) -#define C10_UNLIKELY(expr) (expr) -#endif - -/// C10_NOINLINE - Functions whose declaration is annotated with this will not -/// be inlined. -#ifdef __GNUC__ -#define C10_NOINLINE __attribute__((noinline)) -#elif _MSC_VER -#define C10_NOINLINE __declspec(noinline) -#else -#define C10_NOINLINE -#endif - -#if defined(_MSC_VER) -#define C10_ALWAYS_INLINE __forceinline -#elif __has_attribute(always_inline) || defined(__GNUC__) -#define C10_ALWAYS_INLINE __attribute__((__always_inline__)) inline -#else -#define C10_ALWAYS_INLINE inline -#endif - -// Unlike C10_ALWAYS_INLINE, C10_ALWAYS_INLINE_ATTRIBUTE can be used -// on a lambda. -#if defined(_MSC_VER) -// MSVC 14.39 is reasonably recent and doesn't like -// [[msvc::forceinline]] on a lambda, so don't try to use it. -#define C10_ALWAYS_INLINE_ATTRIBUTE -#elif __has_attribute(always_inline) || defined(__GNUC__) -#define C10_ALWAYS_INLINE_ATTRIBUTE __attribute__((__always_inline__)) -#else -#define C10_ALWAYS_INLINE_ATTRIBUTE -#endif - -#if defined(_MSC_VER) -#define C10_ATTR_VISIBILITY_HIDDEN -#elif defined(__GNUC__) -#define C10_ATTR_VISIBILITY_HIDDEN __attribute__((__visibility__("hidden"))) -#else -#define C10_ATTR_VISIBILITY_HIDDEN -#endif - -#define C10_ERASE C10_ALWAYS_INLINE C10_ATTR_VISIBILITY_HIDDEN - -#include - -#ifdef __HIPCC__ -// Unlike CUDA, HIP requires a HIP header to be included for __host__ to work. -// We do this #include here so that C10_HOST_DEVICE and friends will Just Work. -// See https://github.com/ROCm/hip/issues/441 -#include -#endif - -#if defined(__CUDACC__) || defined(__HIPCC__) -// Designates functions callable from the host (CPU) and the device (GPU) -#define C10_HOST_DEVICE __host__ __device__ -#define C10_DEVICE __device__ -#define C10_HOST __host__ -// constants from -// (https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications) -// The maximum number of threads per multiprocessor is 1024 for Turing -// architecture (7.5), 1536 for Geforce Ampere (8.6)/Jetson Orin (8.7), and -// 2048 for all other architectures. You'll get warnings if you exceed these -// constants. Hence, the following macros adjust the input values from the user -// to resolve potential warnings. -#if __CUDA_ARCH__ == 750 -constexpr uint32_t CUDA_MAX_THREADS_PER_SM = 1024; -#elif __CUDA_ARCH__ == 860 || __CUDA_ARCH__ == 870 || __CUDA_ARCH__ == 890 -constexpr uint32_t CUDA_MAX_THREADS_PER_SM = 1536; -#else -constexpr uint32_t CUDA_MAX_THREADS_PER_SM = 2048; -#endif -// CUDA_MAX_THREADS_PER_BLOCK is same for all architectures currently -constexpr uint32_t CUDA_MAX_THREADS_PER_BLOCK = 1024; -// CUDA_THREADS_PER_BLOCK_FALLBACK is the "canonical fallback" choice of block -// size. 256 is a good number for this fallback and should give good occupancy -// and versatility across all architectures. -constexpr uint32_t CUDA_THREADS_PER_BLOCK_FALLBACK = 256; -// NOTE: if you are thinking of constexpr-ify the inputs to launch bounds, it -// turns out that although __launch_bounds__ can take constexpr, it -// can't take a constexpr that has anything to do with templates. -// Currently we use launch_bounds that depend on template arguments in -// Loops.cuh, Reduce.cuh and LossCTC.cuh. Hence, C10_MAX_THREADS_PER_BLOCK -// and C10_MIN_BLOCKS_PER_SM are kept as macros. -// Suppose you were planning to write __launch_bounds__(a, b), based on your -// performance tuning on a modern GPU. Instead, you should write -// __launch_bounds__(C10_MAX_THREADS_PER_BLOCK(a), C10_MIN_BLOCKS_PER_SM(a, b)), -// which will also properly respect limits on old architectures. -#define C10_MAX_THREADS_PER_BLOCK(val) \ - (((val) <= CUDA_MAX_THREADS_PER_BLOCK) ? (val) \ - : CUDA_THREADS_PER_BLOCK_FALLBACK) -#define C10_MIN_BLOCKS_PER_SM(threads_per_block, blocks_per_sm) \ - ((((threads_per_block) * (blocks_per_sm) <= CUDA_MAX_THREADS_PER_SM) \ - ? (blocks_per_sm) \ - : ((CUDA_MAX_THREADS_PER_SM + (threads_per_block) - 1) / \ - (threads_per_block)))) -// C10_LAUNCH_BOUNDS is analogous to __launch_bounds__ -#define C10_LAUNCH_BOUNDS_0 \ - __launch_bounds__( \ - 256, 4) // default launch bounds that should give good occupancy and - // versatility across all architectures. -#define C10_LAUNCH_BOUNDS_1(max_threads_per_block) \ - __launch_bounds__((C10_MAX_THREADS_PER_BLOCK((max_threads_per_block)))) -#define C10_LAUNCH_BOUNDS_2(max_threads_per_block, min_blocks_per_sm) \ - __launch_bounds__( \ - (C10_MAX_THREADS_PER_BLOCK((max_threads_per_block))), \ - (C10_MIN_BLOCKS_PER_SM((max_threads_per_block), (min_blocks_per_sm)))) -#else -#define C10_HOST_DEVICE -#define C10_HOST -#define C10_DEVICE -#endif - -#if defined(USE_ROCM) -#define C10_HIP_HOST_DEVICE __host__ __device__ -#else -#define C10_HIP_HOST_DEVICE -#endif - -#if defined(USE_ROCM) -// C10_WARP_SIZE is only allowed for device code. -// Host code _must_ use at::cuda::warp_size() -// HIP header used to define warpSize as a constexpr that was either 32 or 64 -// depending on the target device, and then always set it to 64 for host code. -// Host pass of HIP compiler needs C10_WARP_SIZE defined to _something_ so we -// set it to something unreasonable to trigger obvious host code errors. -#if defined(__HIP_DEVICE_COMPILE__) -#if defined(__GFX9__) -static constexpr int C10_WARP_SIZE = 64; -#else // __GFX9__ -static constexpr int C10_WARP_SIZE = 32; -#endif // __GFX9__ -#else -static constexpr int C10_WARP_SIZE = 1; -#endif // __HIP_DEVICE_COMPILE__ -#else -#define C10_WARP_SIZE 32 -#endif - -#if defined(_MSC_VER) && _MSC_VER <= 1900 -#define __func__ __FUNCTION__ -#endif - -// CUDA_KERNEL_ASSERT checks the assertion -// even when NDEBUG is defined. This is useful for important assertions in CUDA -// code that would otherwise be suppressed when building Release. -#if defined(__ANDROID__) || defined(__APPLE__) || defined(__FreeBSD__) -// Those platforms do not support assert() -#define CUDA_KERNEL_ASSERT(cond) -#define CUDA_KERNEL_ASSERT_MSG(cond, msg) -#define SYCL_KERNEL_ASSERT(cond) -#elif defined(_MSC_VER) -#if defined(NDEBUG) -extern "C" { -C10_IMPORT -#if defined(__SYCL_DEVICE_ONLY__) -extern SYCL_EXTERNAL void _wassert( - const wchar_t* wexpr, - const wchar_t* wfile, - unsigned line); -#else -#if defined(__CUDA_ARCH__) -__host__ __device__ -#endif // __CUDA_ARCH__ - void - _wassert(wchar_t const* _Message, wchar_t const* _File, unsigned _Line); -#endif // __SYCL_DEVICE_ONLY__ -} -#endif // NDEBUG -#define CUDA_KERNEL_ASSERT(cond) \ - if (C10_UNLIKELY(!(cond))) { \ - (void)(_wassert( \ - _CRT_WIDE(#cond), \ - _CRT_WIDE(__FILE__), \ - static_cast(__LINE__)), \ - 0); \ - } -// TODO: This doesn't assert the message because I (chilli) couldn't figure out -// a nice way to convert a char* to a wchar_t* -#define CUDA_KERNEL_ASSERT_MSG(cond, msg) \ - if (C10_UNLIKELY(!(cond))) { \ - (void)(_wassert( \ - _CRT_WIDE(#cond), \ - _CRT_WIDE(__FILE__), \ - static_cast(__LINE__)), \ - 0); \ - } -#define SYCL_KERNEL_ASSERT(cond) \ - if (C10_UNLIKELY(!(cond))) { \ - (void)(_wassert( \ - _CRT_WIDE(#cond), \ - _CRT_WIDE(__FILE__), \ - static_cast(__LINE__)), \ - 0); \ - } -#else // __APPLE__, _MSC_VER -#if defined(NDEBUG) -extern "C" { -#if defined(__SYCL_DEVICE_ONLY__) -extern SYCL_EXTERNAL void __assert_fail( - const char* expr, - const char* file, - unsigned int line, - const char* func); -#else // __SYCL_DEVICE_ONLY__ -#if (defined(__CUDA_ARCH__) && !(defined(__clang__) && defined(__CUDA__))) -// CUDA supports __assert_fail function which are common for both device -// and host side code. -__host__ __device__ -#endif - - // This forward declaration matching the declaration of __assert_fail - // exactly how it is in glibc in case parts of the program are compiled with - // different NDEBUG settings. Otherwise we might get 'ambiguous declaration' - // error. Note: On ROCm - this declaration serves for host side compilation. - void - __assert_fail( - const char* assertion, - const char* file, - unsigned int line, - const char* function) noexcept __attribute__((__noreturn__)); - -#endif // __SYCL_DEVICE_ONLY__ -} -#endif // NDEBUG -// ROCm disables kernel assert by default for performance considerations. -// Though ROCm supports __assert_fail, it uses kernel printf which has -// a non-negligible performance impact even if the assert condition is -// never triggered. We choose to use abort() instead which will still -// terminate the application but without a more useful error message. -#if !defined(C10_USE_ROCM_KERNEL_ASSERT) and defined(USE_ROCM) -#define CUDA_KERNEL_ASSERT(cond) \ - if C10_UNLIKELY (!(cond)) { \ - abort(); \ - } -#define CUDA_KERNEL_ASSERT_MSG(cond, msg) \ - if C10_UNLIKELY (!(cond)) { \ - abort(); \ - } -#define SYCL_KERNEL_ASSERT(cond) \ - if C10_UNLIKELY (!(cond)) { \ - abort(); \ - } -#else -#define CUDA_KERNEL_ASSERT(cond) \ - if (C10_UNLIKELY(!(cond))) { \ - __assert_fail( \ - #cond, __FILE__, static_cast(__LINE__), __func__); \ - } -#define CUDA_KERNEL_ASSERT_MSG(cond, msg) \ - if (C10_UNLIKELY(!(cond))) { \ - __assert_fail( \ - msg, __FILE__, static_cast(__LINE__), __func__); \ - } -#define SYCL_KERNEL_ASSERT(cond) \ - if (C10_UNLIKELY(!(cond))) { \ - __assert_fail( \ - #cond, __FILE__, static_cast(__LINE__), __func__); \ - } -#endif // C10_USE_ROCM_KERNEL_ASSERT and USE_ROCM -#endif // __APPLE__ - -#ifdef __APPLE__ -#include -#endif - -#if defined(__ANDROID__) -#define C10_ANDROID 1 -#define C10_MOBILE 1 -#elif ( \ - defined(__APPLE__) && \ - (TARGET_IPHONE_SIMULATOR || TARGET_OS_SIMULATOR || TARGET_OS_IPHONE)) -#define C10_IOS 1 -#define C10_MOBILE 1 -#endif // ANDROID / IOS - -#if defined(C10_MOBILE) && C10_MOBILE -#define C10_ALWAYS_INLINE_UNLESS_MOBILE inline -#else -#define C10_ALWAYS_INLINE_UNLESS_MOBILE C10_ALWAYS_INLINE -#endif - -#if !defined(FBCODE_CAFFE2) && !defined(C10_NODEPRECATED) -#define CONSTEXPR_EXCEPT_WIN_CUDA constexpr -#define C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA constexpr - -#define STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(field, val) \ - static constexpr const char field[] = val; -#define STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA(cls, field, val) -#endif // !defined(FBCODE_CAFFE2) && !defined(C10_NODEPRECATED) - -#ifndef HAS_DEMANGLE -#if defined(__ANDROID__) || defined(_WIN32) || defined(__EMSCRIPTEN__) -#define HAS_DEMANGLE 0 -#elif defined(__APPLE__) && \ - (TARGET_IPHONE_SIMULATOR || TARGET_OS_SIMULATOR || TARGET_OS_IPHONE) -#define HAS_DEMANGLE 0 -#else -#define HAS_DEMANGLE 1 -#endif -#endif // HAS_DEMANGLE - -#define _C10_PRAGMA__(string) _Pragma(#string) -#define _C10_PRAGMA_(string) _C10_PRAGMA__(string) - -#ifdef __clang__ -#define C10_CLANG_DIAGNOSTIC_PUSH() _Pragma("clang diagnostic push") -#define C10_CLANG_DIAGNOSTIC_POP() _Pragma("clang diagnostic pop") -#define C10_CLANG_DIAGNOSTIC_IGNORE(flag) \ - _C10_PRAGMA_(clang diagnostic ignored flag) -#define C10_CLANG_HAS_WARNING(flag) __has_warning(flag) -#else -#define C10_CLANG_DIAGNOSTIC_PUSH() -#define C10_CLANG_DIAGNOSTIC_POP() -#define C10_CLANG_DIAGNOSTIC_IGNORE(flag) -#define C10_CLANG_HAS_WARNING(flag) 0 -#endif - -#ifdef __clang__ - -#define C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED(warning) \ - _C10_PRAGMA_(clang diagnostic push) \ - _C10_PRAGMA_(clang diagnostic ignored "-Wunknown-warning-option") \ - _C10_PRAGMA_(clang diagnostic ignored warning) - -#define C10_DIAGNOSTIC_POP() _C10_PRAGMA_(clang diagnostic pop) - -#elif __GNUC__ - -#define C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED(warning) \ - _C10_PRAGMA_(GCC diagnostic push) \ - _C10_PRAGMA_(GCC diagnostic ignored "-Wpragmas") \ - _C10_PRAGMA_(GCC diagnostic ignored warning) - -#define C10_DIAGNOSTIC_POP() _C10_PRAGMA_(GCC diagnostic pop) - -#else - -#define C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED(warning) -#define C10_DIAGNOSTIC_POP() - -#endif - -// This macro is used to find older C++ compilers -// that don't support move optimization for return values. - -#if (defined(__GNUC__) && __GNUC__ < 13) || \ - (defined(__clang_major__) && __clang_major__ < 13) -#define C10_RETURN_MOVE_IF_OLD_COMPILER 1 -#else -#define C10_RETURN_MOVE_IF_OLD_COMPILER 0 -#endif - -#endif // C10_MACROS_MACROS_H_ +#include diff --git a/torch/headeronly/macros/Macros.h b/torch/headeronly/macros/Macros.h new file mode 100644 index 0000000000000..0c02cce309dc8 --- /dev/null +++ b/torch/headeronly/macros/Macros.h @@ -0,0 +1,548 @@ +#ifndef C10_MACROS_MACROS_H_ +#define C10_MACROS_MACROS_H_ +#include + +/* Main entry for torch/headeronly/macros (used to be c10/macros). + * + * In your code, include torch/headeronly/macros/Macros.h directly, instead of + * individual files in this folder. + */ + +// For build systems that do not directly depend on CMake and directly build +// from the source directory (such as Buck), one may not have a cmake_macros.h +// file at all. In this case, the build system is responsible for providing +// correct macro definitions corresponding to the cmake_macros.h.in file. +// +// In such scenarios, one should define the macro +// C10_USING_CUSTOM_GENERATED_MACROS +// to inform this header that it does not need to include the cmake_macros.h +// file. + +#ifndef C10_USING_CUSTOM_GENERATED_MACROS +#include +#endif // C10_USING_CUSTOM_GENERATED_MACROS + +#include + +#if defined(__clang__) +#define __ubsan_ignore_float_divide_by_zero__ \ + __attribute__((no_sanitize("float-divide-by-zero"))) +#define __ubsan_ignore_undefined__ __attribute__((no_sanitize("undefined"))) +#define __ubsan_ignore_signed_int_overflow__ \ + __attribute__((no_sanitize("signed-integer-overflow"))) +#define __ubsan_ignore_pointer_overflow__ \ + __attribute__((no_sanitize("pointer-overflow"))) +#define __ubsan_ignore_function__ __attribute__((no_sanitize("function"))) +#define __ubsan_ignore_float_cast_overflow__ \ + __attribute__((no_sanitize("float-cast-overflow"))) +#else +#define __ubsan_ignore_float_divide_by_zero__ +#define __ubsan_ignore_undefined__ +#define __ubsan_ignore_signed_int_overflow__ +#define __ubsan_ignore_pointer_overflow__ +#define __ubsan_ignore_function__ +#define __ubsan_ignore_float_cast_overflow__ +#endif + +// Detect address sanitizer as some stuff doesn't work with it +#undef C10_ASAN_ENABLED + +// for clang +#if defined(__has_feature) +#if ((__has_feature(address_sanitizer))) +#define C10_ASAN_ENABLED 1 +#endif +#endif + +// for gcc +#if defined(__SANITIZE_ADDRESS__) +#if __SANITIZE_ADDRESS__ +#if !defined(C10_ASAN_ENABLED) +#define C10_ASAN_ENABLED 1 +#endif +#endif +#endif + +#if !defined(C10_ASAN_ENABLED) +#define C10_ASAN_ENABLED 0 +#endif + +// Detect undefined-behavior sanitizer (UBSAN) +#undef C10_UBSAN_ENABLED + +// for clang or gcc >= 14 +// NB: gcc 14 adds support for Clang's __has_feature +// https://gcc.gnu.org/gcc-14/changes.html +// gcc < 14 doesn't have a macro for UBSAN +// (e.g. __SANITIZE_UNDEFINED__ does not exist in gcc) +// https://github.com/google/sanitizers/issues/765 +#if defined(__has_feature) +#if ((__has_feature(undefined_behavior_sanitizer))) +#define C10_UBSAN_ENABLED 1 +#endif +#endif + +#if !defined(C10_UBSAN_ENABLED) +#define C10_UBSAN_ENABLED 0 +#endif + +// Disable the copy and assignment operator for a class. Note that this will +// disable the usage of the class in std containers. +#define C10_DISABLE_COPY_AND_ASSIGN(classname) \ + classname(const classname&) = delete; \ + classname& operator=(const classname&) = delete + +#define C10_CONCATENATE_IMPL(s1, s2) s1##s2 +#define C10_CONCATENATE(s1, s2) C10_CONCATENATE_IMPL(s1, s2) + +#define C10_MACRO_EXPAND(args) args + +#define C10_STRINGIZE_IMPL(x) #x +#define C10_STRINGIZE(x) C10_STRINGIZE_IMPL(x) + +/** + * C10_ANONYMOUS_VARIABLE(str) introduces a new identifier which starts with + * str and ends with a unique number. + */ +#ifdef __COUNTER__ +#define C10_UID __COUNTER__ +#define C10_ANONYMOUS_VARIABLE(str) C10_CONCATENATE(str, __COUNTER__) +#else +#define C10_UID __LINE__ +#define C10_ANONYMOUS_VARIABLE(str) C10_CONCATENATE(str, __LINE__) +#endif + +#ifdef __has_cpp_attribute +#define C10_HAS_CPP_ATTRIBUTE(x) __has_cpp_attribute(x) +#else +#define C10_HAS_CPP_ATTRIBUTE(x) (0) +#endif + +#ifndef FBCODE_CAFFE2 +/// DEPRECATED: Warn if a type or return value is discarded. +#define C10_NODISCARD [[nodiscard]] + +/// DEPRECATED: Suppress an unused variable. +#define C10_UNUSED [[maybe_unused]] +#endif + +#if !defined(__has_attribute) +#define __has_attribute(x) 0 +#endif + +// Direct port of LLVM_ATTRIBUTE_USED. +#if __has_attribute(used) +#define C10_USED __attribute__((__used__)) +#else +#define C10_USED +#endif + +#define C10_RESTRICT __restrict + +// Simply define the namespace, in case a dependent library want to refer to +// the c10 namespace but not any nontrivial files. +namespace c10 {} +namespace c10::cuda {} +namespace c10::hip {} +namespace c10::xpu {} + +// Since C10 is the core library for caffe2 (and aten), we will simply reroute +// all abstractions defined in c10 to be available in caffe2 as well. +// This is only for backwards compatibility. Please use the symbols from the +// c10 namespace where possible. +namespace caffe2 { +using namespace c10; +} +namespace at { +using namespace c10; +} +namespace at::cuda { +using namespace c10::cuda; +} // namespace at::cuda + +// WARNING!!! THIS IS A GIANT HACK!!! +// This line means you cannot simultaneously include c10/hip +// and c10/cuda and then use them from the at::cuda namespace. +// This is true in practice, because HIPIFY works inplace on +// files in ATen/cuda, so it assumes that c10::hip is available +// from at::cuda. This namespace makes that happen. When +// HIPIFY is no longer out-of-place, we can switch the cuda +// here to hip and everyone is happy. +namespace at::cuda { +using namespace c10::hip; +} // namespace at::cuda + +namespace at::xpu { +using namespace c10::xpu; +} // namespace at::xpu + +// C10_LIKELY/C10_UNLIKELY +// +// These macros provide parentheses, so you can use these macros as: +// +// if C10_LIKELY(some_expr) { +// ... +// } +// +// NB: static_cast to boolean is mandatory in C++, because __builtin_expect +// takes a long argument, which means you may trigger the wrong conversion +// without it. +// +#if defined(__GNUC__) || defined(__ICL) || defined(__clang__) +#define C10_LIKELY(expr) (__builtin_expect(static_cast(expr), 1)) +#define C10_UNLIKELY(expr) (__builtin_expect(static_cast(expr), 0)) +#else +#define C10_LIKELY(expr) (expr) +#define C10_UNLIKELY(expr) (expr) +#endif + +/// C10_NOINLINE - Functions whose declaration is annotated with this will not +/// be inlined. +#ifdef __GNUC__ +#define C10_NOINLINE __attribute__((noinline)) +#elif _MSC_VER +#define C10_NOINLINE __declspec(noinline) +#else +#define C10_NOINLINE +#endif + +#if defined(_MSC_VER) +#define C10_ALWAYS_INLINE __forceinline +#elif __has_attribute(always_inline) || defined(__GNUC__) +#define C10_ALWAYS_INLINE __attribute__((__always_inline__)) inline +#else +#define C10_ALWAYS_INLINE inline +#endif + +// Unlike C10_ALWAYS_INLINE, C10_ALWAYS_INLINE_ATTRIBUTE can be used +// on a lambda. +#if defined(_MSC_VER) +// MSVC 14.39 is reasonably recent and doesn't like +// [[msvc::forceinline]] on a lambda, so don't try to use it. +#define C10_ALWAYS_INLINE_ATTRIBUTE +#elif __has_attribute(always_inline) || defined(__GNUC__) +#define C10_ALWAYS_INLINE_ATTRIBUTE __attribute__((__always_inline__)) +#else +#define C10_ALWAYS_INLINE_ATTRIBUTE +#endif + +#if defined(_MSC_VER) +#define C10_ATTR_VISIBILITY_HIDDEN +#elif defined(__GNUC__) +#define C10_ATTR_VISIBILITY_HIDDEN __attribute__((__visibility__("hidden"))) +#else +#define C10_ATTR_VISIBILITY_HIDDEN +#endif + +#define C10_ERASE C10_ALWAYS_INLINE C10_ATTR_VISIBILITY_HIDDEN + +#include + +#ifdef __HIPCC__ +// Unlike CUDA, HIP requires a HIP header to be included for __host__ to work. +// We do this #include here so that C10_HOST_DEVICE and friends will Just Work. +// See https://github.com/ROCm/hip/issues/441 +#include +#endif + +#if defined(__CUDACC__) || defined(__HIPCC__) +// Designates functions callable from the host (CPU) and the device (GPU) +#define C10_HOST_DEVICE __host__ __device__ +#define C10_DEVICE __device__ +#define C10_HOST __host__ +// constants from +// (https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications) +// The maximum number of threads per multiprocessor is 1024 for Turing +// architecture (7.5), 1536 for Geforce Ampere (8.6)/Jetson Orin (8.7), and +// 2048 for all other architectures. You'll get warnings if you exceed these +// constants. Hence, the following macros adjust the input values from the user +// to resolve potential warnings. +#if __CUDA_ARCH__ == 750 +constexpr uint32_t CUDA_MAX_THREADS_PER_SM = 1024; +#elif __CUDA_ARCH__ == 860 || __CUDA_ARCH__ == 870 || __CUDA_ARCH__ == 890 +constexpr uint32_t CUDA_MAX_THREADS_PER_SM = 1536; +#else +constexpr uint32_t CUDA_MAX_THREADS_PER_SM = 2048; +#endif +// CUDA_MAX_THREADS_PER_BLOCK is same for all architectures currently +constexpr uint32_t CUDA_MAX_THREADS_PER_BLOCK = 1024; +// CUDA_THREADS_PER_BLOCK_FALLBACK is the "canonical fallback" choice of block +// size. 256 is a good number for this fallback and should give good occupancy +// and versatility across all architectures. +constexpr uint32_t CUDA_THREADS_PER_BLOCK_FALLBACK = 256; +// NOTE: if you are thinking of constexpr-ify the inputs to launch bounds, it +// turns out that although __launch_bounds__ can take constexpr, it +// can't take a constexpr that has anything to do with templates. +// Currently we use launch_bounds that depend on template arguments in +// Loops.cuh, Reduce.cuh and LossCTC.cuh. Hence, C10_MAX_THREADS_PER_BLOCK +// and C10_MIN_BLOCKS_PER_SM are kept as macros. +// Suppose you were planning to write __launch_bounds__(a, b), based on your +// performance tuning on a modern GPU. Instead, you should write +// __launch_bounds__(C10_MAX_THREADS_PER_BLOCK(a), C10_MIN_BLOCKS_PER_SM(a, b)), +// which will also properly respect limits on old architectures. +#define C10_MAX_THREADS_PER_BLOCK(val) \ + (((val) <= CUDA_MAX_THREADS_PER_BLOCK) ? (val) \ + : CUDA_THREADS_PER_BLOCK_FALLBACK) +#define C10_MIN_BLOCKS_PER_SM(threads_per_block, blocks_per_sm) \ + ((((threads_per_block) * (blocks_per_sm) <= CUDA_MAX_THREADS_PER_SM) \ + ? (blocks_per_sm) \ + : ((CUDA_MAX_THREADS_PER_SM + (threads_per_block) - 1) / \ + (threads_per_block)))) +// C10_LAUNCH_BOUNDS is analogous to __launch_bounds__ +#define C10_LAUNCH_BOUNDS_0 \ + __launch_bounds__( \ + 256, 4) // default launch bounds that should give good occupancy and + // versatility across all architectures. +#define C10_LAUNCH_BOUNDS_1(max_threads_per_block) \ + __launch_bounds__((C10_MAX_THREADS_PER_BLOCK((max_threads_per_block)))) +#define C10_LAUNCH_BOUNDS_2(max_threads_per_block, min_blocks_per_sm) \ + __launch_bounds__( \ + (C10_MAX_THREADS_PER_BLOCK((max_threads_per_block))), \ + (C10_MIN_BLOCKS_PER_SM((max_threads_per_block), (min_blocks_per_sm)))) +#else +#define C10_HOST_DEVICE +#define C10_HOST +#define C10_DEVICE +#endif + +#if defined(USE_ROCM) +#define C10_HIP_HOST_DEVICE __host__ __device__ +#else +#define C10_HIP_HOST_DEVICE +#endif + +#if defined(USE_ROCM) +// C10_WARP_SIZE is only allowed for device code. +// Host code _must_ use at::cuda::warp_size() +// HIP header used to define warpSize as a constexpr that was either 32 or 64 +// depending on the target device, and then always set it to 64 for host code. +// Host pass of HIP compiler needs C10_WARP_SIZE defined to _something_ so we +// set it to something unreasonable to trigger obvious host code errors. +#if defined(__HIP_DEVICE_COMPILE__) +#if defined(__GFX9__) +static constexpr int C10_WARP_SIZE = 64; +#else // __GFX9__ +static constexpr int C10_WARP_SIZE = 32; +#endif // __GFX9__ +#else +static constexpr int C10_WARP_SIZE = 1; +#endif // __HIP_DEVICE_COMPILE__ +#else +#define C10_WARP_SIZE 32 +#endif + +#if defined(_MSC_VER) && _MSC_VER <= 1900 +#define __func__ __FUNCTION__ +#endif + +// CUDA_KERNEL_ASSERT checks the assertion +// even when NDEBUG is defined. This is useful for important assertions in CUDA +// code that would otherwise be suppressed when building Release. +#if defined(__ANDROID__) || defined(__APPLE__) || defined(__FreeBSD__) +// Those platforms do not support assert() +#define CUDA_KERNEL_ASSERT(cond) +#define CUDA_KERNEL_ASSERT_MSG(cond, msg) +#define SYCL_KERNEL_ASSERT(cond) +#elif defined(_MSC_VER) +#if defined(NDEBUG) +extern "C" { +C10_IMPORT +#if defined(__SYCL_DEVICE_ONLY__) +extern SYCL_EXTERNAL void _wassert( + const wchar_t* wexpr, + const wchar_t* wfile, + unsigned line); +#else +#if defined(__CUDA_ARCH__) +__host__ __device__ +#endif // __CUDA_ARCH__ + void + _wassert(wchar_t const* _Message, wchar_t const* _File, unsigned _Line); +#endif // __SYCL_DEVICE_ONLY__ +} +#endif // NDEBUG +#define CUDA_KERNEL_ASSERT(cond) \ + if (C10_UNLIKELY(!(cond))) { \ + (void)(_wassert( \ + _CRT_WIDE(#cond), \ + _CRT_WIDE(__FILE__), \ + static_cast(__LINE__)), \ + 0); \ + } +// TODO: This doesn't assert the message because I (chilli) couldn't figure out +// a nice way to convert a char* to a wchar_t* +#define CUDA_KERNEL_ASSERT_MSG(cond, msg) \ + if (C10_UNLIKELY(!(cond))) { \ + (void)(_wassert( \ + _CRT_WIDE(#cond), \ + _CRT_WIDE(__FILE__), \ + static_cast(__LINE__)), \ + 0); \ + } +#define SYCL_KERNEL_ASSERT(cond) \ + if (C10_UNLIKELY(!(cond))) { \ + (void)(_wassert( \ + _CRT_WIDE(#cond), \ + _CRT_WIDE(__FILE__), \ + static_cast(__LINE__)), \ + 0); \ + } +#else // __APPLE__, _MSC_VER +#if defined(NDEBUG) +extern "C" { +#if defined(__SYCL_DEVICE_ONLY__) +extern SYCL_EXTERNAL void __assert_fail( + const char* expr, + const char* file, + unsigned int line, + const char* func); +#else // __SYCL_DEVICE_ONLY__ +#if (defined(__CUDA_ARCH__) && !(defined(__clang__) && defined(__CUDA__))) +// CUDA supports __assert_fail function which are common for both device +// and host side code. +__host__ __device__ +#endif + + // This forward declaration matching the declaration of __assert_fail + // exactly how it is in glibc in case parts of the program are compiled with + // different NDEBUG settings. Otherwise we might get 'ambiguous declaration' + // error. Note: On ROCm - this declaration serves for host side compilation. + void + __assert_fail( + const char* assertion, + const char* file, + unsigned int line, + const char* function) noexcept __attribute__((__noreturn__)); + +#endif // __SYCL_DEVICE_ONLY__ +} +#endif // NDEBUG +// ROCm disables kernel assert by default for performance considerations. +// Though ROCm supports __assert_fail, it uses kernel printf which has +// a non-negligible performance impact even if the assert condition is +// never triggered. We choose to use abort() instead which will still +// terminate the application but without a more useful error message. +#if !defined(C10_USE_ROCM_KERNEL_ASSERT) and defined(USE_ROCM) +#define CUDA_KERNEL_ASSERT(cond) \ + if C10_UNLIKELY (!(cond)) { \ + abort(); \ + } +#define CUDA_KERNEL_ASSERT_MSG(cond, msg) \ + if C10_UNLIKELY (!(cond)) { \ + abort(); \ + } +#define SYCL_KERNEL_ASSERT(cond) \ + if C10_UNLIKELY (!(cond)) { \ + abort(); \ + } +#else +#define CUDA_KERNEL_ASSERT(cond) \ + if (C10_UNLIKELY(!(cond))) { \ + __assert_fail( \ + #cond, __FILE__, static_cast(__LINE__), __func__); \ + } +#define CUDA_KERNEL_ASSERT_MSG(cond, msg) \ + if (C10_UNLIKELY(!(cond))) { \ + __assert_fail( \ + msg, __FILE__, static_cast(__LINE__), __func__); \ + } +#define SYCL_KERNEL_ASSERT(cond) \ + if (C10_UNLIKELY(!(cond))) { \ + __assert_fail( \ + #cond, __FILE__, static_cast(__LINE__), __func__); \ + } +#endif // C10_USE_ROCM_KERNEL_ASSERT and USE_ROCM +#endif // __APPLE__ + +#ifdef __APPLE__ +#include +#endif + +#if defined(__ANDROID__) +#define C10_ANDROID 1 +#define C10_MOBILE 1 +#elif ( \ + defined(__APPLE__) && \ + (TARGET_IPHONE_SIMULATOR || TARGET_OS_SIMULATOR || TARGET_OS_IPHONE)) +#define C10_IOS 1 +#define C10_MOBILE 1 +#endif // ANDROID / IOS + +#if defined(C10_MOBILE) && C10_MOBILE +#define C10_ALWAYS_INLINE_UNLESS_MOBILE inline +#else +#define C10_ALWAYS_INLINE_UNLESS_MOBILE C10_ALWAYS_INLINE +#endif + +#if !defined(FBCODE_CAFFE2) && !defined(C10_NODEPRECATED) +#define CONSTEXPR_EXCEPT_WIN_CUDA constexpr +#define C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA constexpr + +#define STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(field, val) \ + static constexpr const char field[] = val; +#define STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA(cls, field, val) +#endif // !defined(FBCODE_CAFFE2) && !defined(C10_NODEPRECATED) + +#ifndef HAS_DEMANGLE +#if defined(__ANDROID__) || defined(_WIN32) || defined(__EMSCRIPTEN__) +#define HAS_DEMANGLE 0 +#elif defined(__APPLE__) && \ + (TARGET_IPHONE_SIMULATOR || TARGET_OS_SIMULATOR || TARGET_OS_IPHONE) +#define HAS_DEMANGLE 0 +#else +#define HAS_DEMANGLE 1 +#endif +#endif // HAS_DEMANGLE + +#define _C10_PRAGMA__(string) _Pragma(#string) +#define _C10_PRAGMA_(string) _C10_PRAGMA__(string) + +#ifdef __clang__ +#define C10_CLANG_DIAGNOSTIC_PUSH() _Pragma("clang diagnostic push") +#define C10_CLANG_DIAGNOSTIC_POP() _Pragma("clang diagnostic pop") +#define C10_CLANG_DIAGNOSTIC_IGNORE(flag) \ + _C10_PRAGMA_(clang diagnostic ignored flag) +#define C10_CLANG_HAS_WARNING(flag) __has_warning(flag) +#else +#define C10_CLANG_DIAGNOSTIC_PUSH() +#define C10_CLANG_DIAGNOSTIC_POP() +#define C10_CLANG_DIAGNOSTIC_IGNORE(flag) +#define C10_CLANG_HAS_WARNING(flag) 0 +#endif + +#ifdef __clang__ + +#define C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED(warning) \ + _C10_PRAGMA_(clang diagnostic push) \ + _C10_PRAGMA_(clang diagnostic ignored "-Wunknown-warning-option") \ + _C10_PRAGMA_(clang diagnostic ignored warning) + +#define C10_DIAGNOSTIC_POP() _C10_PRAGMA_(clang diagnostic pop) + +#elif __GNUC__ + +#define C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED(warning) \ + _C10_PRAGMA_(GCC diagnostic push) \ + _C10_PRAGMA_(GCC diagnostic ignored "-Wpragmas") \ + _C10_PRAGMA_(GCC diagnostic ignored warning) + +#define C10_DIAGNOSTIC_POP() _C10_PRAGMA_(GCC diagnostic pop) + +#else + +#define C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED(warning) +#define C10_DIAGNOSTIC_POP() + +#endif + +// This macro is used to find older C++ compilers +// that don't support move optimization for return values. + +#if (defined(__GNUC__) && __GNUC__ < 13) || \ + (defined(__clang_major__) && __clang_major__ < 13) +#define C10_RETURN_MOVE_IF_OLD_COMPILER 1 +#else +#define C10_RETURN_MOVE_IF_OLD_COMPILER 0 +#endif + +#endif // C10_MACROS_MACROS_H_ diff --git a/torch/headeronly/macros/build.bzl b/torch/headeronly/macros/build.bzl index 5217c2f7d37d6..9b136951ad139 100644 --- a/torch/headeronly/macros/build.bzl +++ b/torch/headeronly/macros/build.bzl @@ -4,6 +4,7 @@ def define_targets(rules): srcs = [":cmake_macros_h"], hdrs = [ # Following the example from c10 + "Macros.h", "Export.h", ], linkstatic = True, From 79ab84e9b8fe561a55931b2108af45993a670276 Mon Sep 17 00:00:00 2001 From: cyy Date: Wed, 16 Jul 2025 18:47:04 +0000 Subject: [PATCH 130/457] Fix invalid formatting (#158436) It causes errors under C++20 ``` /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/mps/OperationUtils.mm:330:40: error: call to consteval function 'fmt::fstring<>::fstring' is not a constant expression ``` Indeed the printed value is treated as format string and it may contain special chars in some cases. While this is not true in our case, it can't be determined in compile time. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158436 Approved by: https://github.com/Skylion007 --- aten/src/ATen/native/mps/OperationUtils.mm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aten/src/ATen/native/mps/OperationUtils.mm b/aten/src/ATen/native/mps/OperationUtils.mm index 142186b748b17..a58b334307f2f 100644 --- a/aten/src/ATen/native/mps/OperationUtils.mm +++ b/aten/src/ATen/native/mps/OperationUtils.mm @@ -327,7 +327,7 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) { if (exclude_shape) { fmt::format_to(buf_iterator, "-1"); } else { - fmt::format_to(buf_iterator, getArrayRefString(tensor.sizes())); + fmt::format_to(buf_iterator, "{}", getArrayRefString(tensor.sizes())); } } fmt::format_to(buf_iterator, "]"); From 944a140e90389eced1ec38e14cb4345811ed0b1a Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 16 Jul 2025 19:15:33 +0000 Subject: [PATCH 131/457] Revert "[cuda][cupy] Improve cupy device placement when device is provided (#158320)" This reverts commit 59f9b25f3cfc635053843372ea29ff4bf754da3f. Reverted https://github.com/pytorch/pytorch/pull/158320 on behalf of https://github.com/wdvr due to reverting because most likely causing test/test_numba_integration.py::TestNumbaIntegration::test_from_cuda_array_interface_inferred_strides to fail ([comment](https://github.com/pytorch/pytorch/pull/158320#issuecomment-3079960616)) --- test/distributed/test_cupy_as_tensor.py | 104 ------------------------ torch/_torch_docs.py | 3 +- torch/csrc/utils/tensor_new.cpp | 2 +- torch/csrc/utils/tensor_numpy.cpp | 12 +-- torch/csrc/utils/tensor_numpy.h | 4 +- 5 files changed, 6 insertions(+), 119 deletions(-) delete mode 100644 test/distributed/test_cupy_as_tensor.py diff --git a/test/distributed/test_cupy_as_tensor.py b/test/distributed/test_cupy_as_tensor.py deleted file mode 100644 index e5b13adf32dde..0000000000000 --- a/test/distributed/test_cupy_as_tensor.py +++ /dev/null @@ -1,104 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -# To run: -# python test/distributed/test_cupy_as_tensor.py - -import os -from dataclasses import dataclass - -import torch -from torch.multiprocessing.reductions import reduce_tensor -from torch.testing._internal.common_distributed import MultiProcContinousTest -from torch.testing._internal.common_utils import ( - requires_cuda_p2p_access, - run_tests, - skipIfRocm, -) - - -os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" -# So that tests are written in device-agnostic way -device_type = "cuda" -device_module = torch.get_device_module(device_type) - - -@dataclass -class CupyWrapper: - data_ptr: int - size_in_bytes: int - - @property - def __cuda_array_interface__(self): - return { - "shape": (self.size_in_bytes,), - "typestr": "|u1", - "data": (self.data_ptr, False), - "version": 3, - } - - -def from_buffer( - data_ptr: int, size_in_bytes: int, device: str, dtype: torch.dtype -) -> torch.Tensor: - data = torch.as_tensor(CupyWrapper(data_ptr, size_in_bytes), device=device).view( - dtype - ) - assert data.data_ptr() == data_ptr - return data - - -@requires_cuda_p2p_access() -class CupyAsTensorTest(MultiProcContinousTest): - @classmethod - def backend_str(cls): - return "gloo" - - def _init_device(self) -> None: - # init and pin the process to the device - device_module.set_device(self.device) - torch.empty(1, device=self.device) - - @property - def device(self) -> torch.device: - return torch.device(device_type, self.rank) - - @skipIfRocm - def test_cupy_as_tensor(self) -> None: - """ - Test that torch.as_tensor works for cupy array interface - with zero-copy when the pointer is p2p-shared across processes. - """ - self._init_device() - - tensor: torch.Tensor - if self.rank == 1: - # it seems only error from rank non-zero will be caught by this test - tensor = torch.randn(2333, device=self.device) - tensor_meta = reduce_tensor(tensor) - torch.distributed.broadcast_object_list([tensor_meta], src=1) - else: - recv_list = [None] - torch.distributed.broadcast_object_list(recv_list, src=1) - tensor_meta = recv_list[0] - func, args = tensor_meta - args = list(args) - args[6] = self.rank - ipc_tensor = func(*args) - tensor = from_buffer( - ipc_tensor.data_ptr(), - ipc_tensor.numel() * ipc_tensor.element_size(), - self.device, - ipc_tensor.dtype, - ) - - torch.distributed.barrier() - if self.rank == 1: - tensor.fill_(1) - device_module.synchronize() - torch.distributed.barrier() - assert tensor.allclose(tensor, 1) - torch.distributed.barrier() - - -if __name__ == "__main__": - run_tests() diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 958b040f7f3ed..0766bf7742864 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -1006,8 +1006,7 @@ def merge_dicts(*dicts): tensor is constructed using :func:`torch.from_numpy`. If :attr:`data` is a CuPy array, the returned tensor will be located on the same device as the CuPy array unless -specifically overwritten by :attr:`device` or a default device. The device of the CuPy array is inferred from the -pointer of the array using `cudaPointerGetAttributes` unless :attr:`device` is provided. +specifically overwritten by :attr:`device` or a default device. .. seealso:: diff --git a/torch/csrc/utils/tensor_new.cpp b/torch/csrc/utils/tensor_new.cpp index 35511300f703e..45f58cde9a659 100644 --- a/torch/csrc/utils/tensor_new.cpp +++ b/torch/csrc/utils/tensor_new.cpp @@ -304,7 +304,7 @@ Tensor internal_new_from_data( TORCH_CHECK( !pin_memory, "Can't pin tensor constructed from __cuda_array_interface__"); - auto tensor = tensor_from_cuda_array_interface(data, device_opt); + auto tensor = tensor_from_cuda_array_interface(data); const auto& inferred_scalar_type = type_inference ? tensor.scalar_type() : scalar_type; diff --git a/torch/csrc/utils/tensor_numpy.cpp b/torch/csrc/utils/tensor_numpy.cpp index 2d9651748c315..c8548884692fd 100644 --- a/torch/csrc/utils/tensor_numpy.cpp +++ b/torch/csrc/utils/tensor_numpy.cpp @@ -27,9 +27,7 @@ bool is_numpy_int(PyObject* obj) { bool is_numpy_scalar(PyObject* obj) { throw std::runtime_error("PyTorch was compiled without NumPy support"); } -at::Tensor tensor_from_cuda_array_interface( - PyObject* obj, - std::optional device_opt) { +at::Tensor tensor_from_cuda_array_interface(PyObject* obj) { throw std::runtime_error("PyTorch was compiled without NumPy support"); } @@ -382,9 +380,7 @@ bool is_numpy_scalar(PyObject* obj) { PyArray_IsScalar(obj, ComplexFloating)); } -at::Tensor tensor_from_cuda_array_interface( - PyObject* obj, - std::optional device_opt) { +at::Tensor tensor_from_cuda_array_interface(PyObject* obj) { if (!is_numpy_available()) { throw std::runtime_error("Numpy is not available"); } @@ -493,9 +489,7 @@ at::Tensor tensor_from_cuda_array_interface( // ref: // https://numba.readthedocs.io/en/stable/cuda/cuda_array_interface.html#cuda-array-interface-version-3 if (data_ptr != nullptr) { - // if device_opt is provided and not nullopt, use it, otherwise infer from - // cudaPointerGetAttributes later in from_blob - return device_opt; + return {}; } else { const auto current_device = at::detail::getCUDAHooks().getCurrentDevice(); return Device( diff --git a/torch/csrc/utils/tensor_numpy.h b/torch/csrc/utils/tensor_numpy.h index 5f93cbb089c21..a7c1d8cf5476e 100644 --- a/torch/csrc/utils/tensor_numpy.h +++ b/torch/csrc/utils/tensor_numpy.h @@ -22,9 +22,7 @@ TORCH_API bool is_numpy_bool(PyObject* obj); TORCH_API bool is_numpy_scalar(PyObject* obj); void warn_numpy_not_writeable(); -at::Tensor tensor_from_cuda_array_interface( - PyObject* obj, - std::optional device_opt = std::nullopt); +at::Tensor tensor_from_cuda_array_interface(PyObject* obj); void validate_numpy_for_dlpack_deleter_bug(); bool is_numpy_dlpack_deleter_bugged(); From f58a680d09e13658a52c6ba05c63c15759846bcc Mon Sep 17 00:00:00 2001 From: fduwjj Date: Wed, 16 Jul 2025 07:13:57 -0700 Subject: [PATCH 132/457] [c10d]Prototype of remote_group_merge (#158287) Tentative implementation of merge_remote_group per the proposal here: [docs.google.com/document/d/13R-1t_yESTvmAjcCN-wQjQQadIEu0JNIdS65uZawZzY/edit?tab=t.0#heading=h.3ctbqqopzc89](https://docs.google.com/document/d/13R-1t_yESTvmAjcCN-wQjQQadIEu0JNIdS65uZawZzY/edit?tab=t.0#heading=h.3ctbqqopzc89) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158287 Approved by: https://github.com/d4l3k ghstack dependencies: #157716 --- test/distributed/test_dist2.py | 34 ++++++++++++++ torch/_C/_distributed_c10d.pyi | 8 ++++ torch/csrc/distributed/c10d/Backend.hpp | 16 ++++++- torch/csrc/distributed/c10d/ProcessGroup.cpp | 44 ++++++++++++++++++- torch/csrc/distributed/c10d/ProcessGroup.hpp | 22 ++++++++++ .../distributed/c10d/ProcessGroupGloo.cpp | 14 +++++- .../distributed/c10d/ProcessGroupGloo.hpp | 8 +++- .../distributed/c10d/ProcessGroupNCCL.cpp | 16 ++++++- .../distributed/c10d/ProcessGroupNCCL.hpp | 8 +++- .../csrc/distributed/c10d/PyProcessGroup.hpp | 13 ++++++ torch/csrc/distributed/c10d/init.cpp | 20 +++++++++ 11 files changed, 194 insertions(+), 9 deletions(-) diff --git a/test/distributed/test_dist2.py b/test/distributed/test_dist2.py index b4d6c6d02b35f..baaaf0550acda 100644 --- a/test/distributed/test_dist2.py +++ b/test/distributed/test_dist2.py @@ -5,6 +5,7 @@ from datetime import timedelta import torch +import torch.distributed as dist import torch.distributed._dist2 as dist2 from torch.testing._internal.common_distributed import ( MultiProcessTestCase, @@ -216,6 +217,39 @@ def test_group_split(self) -> None: else: self.assertEqual(subgroup, None) + def test_remote_group_merge(self) -> None: + group = self.new_group() + subgroup_1 = group.split_group([0], timeout=timedelta(seconds=30)) + subgroup_2 = group.split_group([1], timeout=timedelta(seconds=30)) + if self.rank == 0: + assert subgroup_1 is not None + tcp_store = dist.TCPStore( + host_name=os.environ["MASTER_ADDR"], + port=29781, + world_size=2, + is_master=True, + ) + merged_pg = subgroup_1.merge_remote_group( + tcp_store, 2, timedelta(seconds=40), "merged_pg" + ) + self.assertEqual(merged_pg.size(), 2) + backend = merged_pg._get_backend(self.device) + self.assertEqual(backend.options._timeout, timedelta(seconds=40)) + else: + assert subgroup_2 is not None + tcp_store = dist.TCPStore( + host_name=os.environ["MASTER_ADDR"], + port=29781, + world_size=2, + is_master=False, + ) + merged_pg = subgroup_2.merge_remote_group( + tcp_store, 2, timedelta(seconds=40), "merged_pg" + ) + self.assertEqual(merged_pg.size(), 2) + backend = merged_pg._get_backend(self.device) + self.assertEqual(backend.options._timeout, timedelta(seconds=40)) + class ProcessGroupGlooTest(Dist2MultiProcessTestCase): device = torch.device("cpu") diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index f57bcb3472cc4..20805d56e3702 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -357,6 +357,14 @@ class ProcessGroup: pg_options: Optional[Backend.Options] = None, group_desc: Optional[str] = None, ) -> Optional[ProcessGroup]: ... + def merge_remote_group( + self, + store: Store, + size: int, + timeout: timedelta, + group_name: Optional[str] = None, + group_desc: Optional[str] = None, + ) -> ProcessGroup: ... def abort(self) -> None: ... def set_timeout(self, timeout: timedelta) -> None: ... def shutdown(self) -> None: ... diff --git a/torch/csrc/distributed/c10d/Backend.hpp b/torch/csrc/distributed/c10d/Backend.hpp index 0f1c5116803f2..070cdb7234b4c 100644 --- a/torch/csrc/distributed/c10d/Backend.hpp +++ b/torch/csrc/distributed/c10d/Backend.hpp @@ -388,14 +388,26 @@ class TORCH_API Backend : public torch::CustomClassHolder { " is missing implementation of enableCollectivesTiming."); } - virtual c10::intrusive_ptr splitBackend( + virtual c10::intrusive_ptr split( const std::vector& ranks, const c10::intrusive_ptr opts) { TORCH_CHECK( false, "Backend ", getBackendName(), - " is missing implementation of splitBackend."); + " is missing implementation of split."); + } + + virtual c10::intrusive_ptr merge( + const c10::intrusive_ptr& store, + const c10::intrusive_ptr opts, + const int& rank, + const int& size) { + TORCH_CHECK( + false, + "Backend ", + getBackendName(), + " is missing implementation of merge."); } bool hasHooks() const { diff --git a/torch/csrc/distributed/c10d/ProcessGroup.cpp b/torch/csrc/distributed/c10d/ProcessGroup.cpp index 197fd9014b3a9..3f183d804129a 100644 --- a/torch/csrc/distributed/c10d/ProcessGroup.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroup.cpp @@ -5,7 +5,6 @@ #include #include #include -#include #include #include @@ -190,7 +189,7 @@ c10::intrusive_ptr ProcessGroup::splitGroup( backendOpts->group_name = groupName; backendOpts->timeout = timeout.has_value() ? timeout.value() : backendOpts->timeout; - auto splitBackend = parentBackend->splitBackend(sorted_ranks, backendOpts); + auto splitBackend = parentBackend->split(sorted_ranks, backendOpts); if (splitBackend == nullptr) { continue; } @@ -216,6 +215,47 @@ c10::intrusive_ptr ProcessGroup::splitGroup( return newGroup; } +c10::intrusive_ptr ProcessGroup::mergeRemoteGroup( + const c10::intrusive_ptr& store, + const MergeOptions& opts, + const int& size) { + c10::intrusive_ptr newGroup; + // We assume rank number is within the range of int32_t, so it won't overflow. + int rank = static_cast(store->add("mergeGroupRank", 1) - 1); + // TODO: Do we need to check all groups have same deviceTypeToBackendType_? + for (const auto& pair : deviceTypeToBackendType_) { + c10::DeviceType deviceType = pair.first; + BackendType backendType = pair.second; + + auto parentBackend = getBackend(deviceType); + auto backendOpts = parentBackend->getBackendOptions(); + std::string groupName = opts.group_name.has_value() + ? opts.group_name.value() + : c10::str(getGroupName(), ":merge"); + backendOpts->group_name = groupName; + backendOpts->timeout = opts.timeout; + auto mergedBackend = parentBackend->merge(store, backendOpts, rank, size); + + std::string groupDesc = opts.group_desc.has_value() + ? opts.group_desc.value() + : c10::str(getGroupDesc(), ":merge"); + mergedBackend->setGroupDesc(groupDesc); + + // Historically, we have been using one process_group to map to all + // backends. but in our new design, we will have one process_group per + // backend. This logic is mostly for backward compatibility. + if (!newGroup) { + newGroup = c10::make_intrusive(store, rank, size); + newGroup->setDefaultBackend(backendType_); + newGroup->setGroupName(groupName); + newGroup->setGroupDesc(groupDesc); + } + newGroup->setBackend(deviceType, backendType, mergedBackend); + } + + return newGroup; +} + } // namespace c10d namespace { diff --git a/torch/csrc/distributed/c10d/ProcessGroup.hpp b/torch/csrc/distributed/c10d/ProcessGroup.hpp index 5939f23e2972b..437564ff9ac69 100644 --- a/torch/csrc/distributed/c10d/ProcessGroup.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroup.hpp @@ -71,6 +71,21 @@ C10_EXPORT bool allow_inflight_collective_as_graph_input(); // class TORCH_API ProcessGroup : public torch::CustomClassHolder { public: + struct TORCH_API MergeOptions : torch::CustomClassHolder { + explicit MergeOptions( + const std::chrono::milliseconds timeout = kProcessGroupDefaultTimeout, + const std::optional group_name = std::nullopt, + const std::optional group_desc = std::nullopt) + : timeout(timeout), group_name(group_name), group_desc(group_desc) {} + ~MergeOptions() override = default; + MergeOptions(const MergeOptions&) = delete; + MergeOptions& operator=(const MergeOptions&) = delete; + + std::chrono::milliseconds timeout; + std::optional group_name; + std::optional group_desc; + }; + enum BackendType : uint8_t { UNDEFINED = 0, GLOO = 1, @@ -967,6 +982,13 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { const std::optional> opts, const std::optional& groupDesc); + // This creates a new subgroup using the specified ranks. + // The current rank must be included in the list of new_ranks. + virtual c10::intrusive_ptr mergeRemoteGroup( + const c10::intrusive_ptr& store, + const MergeOptions& opts, + const int& size); + protected: // Implementations of this interface need to call this to setup // appropriate logging etc. diff --git a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp index 30301524bc575..045e46f9129c9 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp @@ -697,7 +697,7 @@ const std::vector& ProcessGroupGloo::groupRanks() const { return options_->global_ranks_in_group; } -c10::intrusive_ptr ProcessGroupGloo::splitBackend( +c10::intrusive_ptr ProcessGroupGloo::split( const std::vector& ranks, const c10::intrusive_ptr opts) { auto it = std::find(ranks.begin(), ranks.end(), rank_); @@ -726,6 +726,18 @@ c10::intrusive_ptr ProcessGroupGloo::splitBackend( return c10::static_intrusive_pointer_cast(pg); } +c10::intrusive_ptr ProcessGroupGloo::merge( + const c10::intrusive_ptr& store, + const c10::intrusive_ptr opts, + const int& rank, + const int& size) { + auto glooOpts = c10::dynamic_intrusive_pointer_cast(opts); + TORCH_CHECK(glooOpts != nullptr, "opts not a ProcessGroupGloo::Options."); + auto pg = c10::make_intrusive( + store->clone(), rank, size, glooOpts); + return c10::static_intrusive_pointer_cast(pg); +} + void ProcessGroupGloo::enqueue(c10::intrusive_ptr work) { std::unique_lock lock(workMutex_); pgStatus_->lastEnqueuedSeq = static_cast(work->seq_); diff --git a/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp b/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp index 0ba2d416aedff..655679489adb5 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp @@ -308,10 +308,16 @@ class TORCH_API ProcessGroupGloo : public Backend { return c10::static_intrusive_pointer_cast(options_); } - c10::intrusive_ptr splitBackend( + c10::intrusive_ptr split( const std::vector& ranks, const c10::intrusive_ptr opts) override; + c10::intrusive_ptr merge( + const c10::intrusive_ptr& store, + const c10::intrusive_ptr opts, + const int& rank, + const int& size) override; + const std::vector& groupRanks() const; c10::intrusive_ptr broadcast( diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 3dc7abbb7e54c..a0c546a405f59 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -1311,13 +1311,13 @@ void ProcessGroupNCCL::enableCollectivesTiming() { enableTiming_.store(true); } -c10::intrusive_ptr ProcessGroupNCCL::splitBackend( +c10::intrusive_ptr ProcessGroupNCCL::split( const std::vector& ranks, const c10::intrusive_ptr opts) { auto deviceIdx = guessDeviceId(); TORCH_CHECK( deviceIdx >= 0, - "ProcessGroupNCCL::splitBackend: rank ", + "ProcessGroupNCCL::split: rank ", rank_, " has no device is bound to this rank."); auto device = at::Device(at::DeviceType::CUDA, deviceIdx); @@ -1350,6 +1350,18 @@ c10::intrusive_ptr ProcessGroupNCCL::splitBackend( return c10::static_intrusive_pointer_cast(pg); } +c10::intrusive_ptr ProcessGroupNCCL::merge( + const c10::intrusive_ptr& store, + const c10::intrusive_ptr opts, + const int& rank, + const int& size) { + auto ncclOpts = c10::dynamic_intrusive_pointer_cast(opts); + TORCH_CHECK(ncclOpts != nullptr, "opts not a ProcessGroupNCCL::Options."); + auto pg = c10::make_intrusive( + store->clone(), rank, size, ncclOpts); + return c10::static_intrusive_pointer_cast(pg); +} + bool ProcessGroupNCCL::waitForFutureOrTimeout( std::future& fut, const std::chrono::milliseconds& timeOutMilSec, diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index d7bb02e912c81..810f8db9fd7d8 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -975,10 +975,16 @@ class TORCH_API ProcessGroupNCCL : public Backend { void enableCollectivesTiming() override; - c10::intrusive_ptr splitBackend( + c10::intrusive_ptr split( const std::vector& ranks, const c10::intrusive_ptr opts) override; + c10::intrusive_ptr merge( + const c10::intrusive_ptr& store, + const c10::intrusive_ptr opts, + const int& rank, + const int& size) override; + // Helper function for iteratively aborting communicators in the provided map void abortCommsFromMap( std::unordered_map>& ncclCommsMap, diff --git a/torch/csrc/distributed/c10d/PyProcessGroup.hpp b/torch/csrc/distributed/c10d/PyProcessGroup.hpp index 854ea596aba8f..c54004310517f 100644 --- a/torch/csrc/distributed/c10d/PyProcessGroup.hpp +++ b/torch/csrc/distributed/c10d/PyProcessGroup.hpp @@ -166,6 +166,19 @@ class PyProcessGroup : public ProcessGroup { group_desc); } + c10::intrusive_ptr mergeRemoteGroup( + const c10::intrusive_ptr& store, + const MergeOptions& opts, + const int& size) override { + PYBIND11_OVERRIDE( + c10::intrusive_ptr, /* Return type */ + ProcessGroup, /* Parent class */ + mergeRemoteGroup, /* Name of function in C++ */ + store, + opts, + size); + } + c10::intrusive_ptr allgather( std::vector>& outputTensors, std::vector& inputTensors, diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 5dfc99a893c7d..8f617c269ff9d 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -2071,6 +2071,26 @@ communication mechanism. py::arg("opts") = std::nullopt, py::arg("groupDesc") = std::nullopt, py::call_guard()) + .def( + "merge_remote_group", + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, + const c10::intrusive_ptr<::c10d::Store>& store, + int size, + std::chrono::milliseconds timeout, + std::optional groupName, + std::optional groupDesc) { + ::c10d::ProcessGroup::MergeOptions opts; + opts.timeout = timeout; + opts.group_name = groupName; + opts.group_desc = groupDesc; + return self->mergeRemoteGroup(store, opts, size); + }, + py::arg("store"), + py::arg("size"), + py::arg("timeout") = kProcessGroupDefaultTimeout, + py::arg("group_name") = std::nullopt, + py::arg("group_desc") = std::nullopt, + py::call_guard()) .def( "abort", &::c10d::ProcessGroup::abort, From 1cc62c2cb91e56ae50494f88f369cd6ec466a118 Mon Sep 17 00:00:00 2001 From: angelayi Date: Wed, 16 Jul 2025 19:53:08 +0000 Subject: [PATCH 133/457] [export] Update docs (#157750) Preview: https://docs-preview.pytorch.org/pytorch/pytorch/157750/export.html Changes: * Rename draft_export.md -> export.draft_export.md for consistency. * Removed non-strict section in export, instead pointed to programming model doc. * Extended "Expressing Dynamism" section to include Dim hints, ShapeCollection, and AdditionalInputs. * Removed Specialization section in favor of programming model doc * Added pt2 archive doc * Cleaned up sidebar Pull Request resolved: https://github.com/pytorch/pytorch/pull/157750 Approved by: https://github.com/pianpwk --- docs/source/export.md | 734 ++++++------------ docs/source/{ => export}/draft_export.md | 15 +- .../{export.ir_spec.md => export/ir_spec.md} | 0 .../programming_model.md} | 14 +- docs/source/export/pt2_archive.md | 122 +++ docs/source/torch.compiler_aot_inductor.md | 4 +- docs/source/torch.compiler_ir.md | 2 + torch/export/dynamic_shapes.py | 33 +- torch/export/pt2_archive/_package.py | 15 +- torch/export/unflatten.py | 2 +- 10 files changed, 410 insertions(+), 531 deletions(-) rename docs/source/{ => export}/draft_export.md (97%) rename docs/source/{export.ir_spec.md => export/ir_spec.md} (100%) rename docs/source/{export.programming_model.md => export/programming_model.md} (98%) create mode 100644 docs/source/export/pt2_archive.md diff --git a/docs/source/export.md b/docs/source/export.md index 9d57614a14adc..0f0deebc65108 100644 --- a/docs/source/export.md +++ b/docs/source/export.md @@ -2,11 +2,6 @@ # torch.export -:::{warning} -This feature is a prototype under active development and there WILL BE -BREAKING CHANGES in the future. -::: - ## Overview {func}`torch.export.export` takes a {class}`torch.nn.Module` and produces a traced graph @@ -130,10 +125,10 @@ level). Note that users can still use {func}`torch.fx.symbolic_trace` as a preprocessing step before `torch.export`. Compared to {func}`torch.jit.script`, `torch.export` does not capture Python -control flow or data structures, but it supports more Python language -features due to its comprehensive coverage over Python bytecodes. -The resulting graphs are simpler and only have straight line control -flow, except for explicit control flow operators. +control flow or data structures, unless using explicit {ref}`control flow operators `, +but it supports more Python language features due to its comprehensive coverage +over Python bytecodes. The resulting graphs are simpler and only have straight +line control flow, except for explicit control flow operators. Compared to {func}`torch.jit.trace`, `torch.export` is sound: it can trace code that performs integer computation on sizes and records @@ -142,10 +137,8 @@ trace is valid for other inputs. ## Exporting a PyTorch Model -### An Example - The main entrypoint is through {func}`torch.export.export`, which takes a -callable ({class}`torch.nn.Module`, function, or method) and sample inputs, and +{class}`torch.nn.Module` and sample inputs, and captures the computation graph into an {class}`torch.export.ExportedProgram`. An example: @@ -236,187 +229,26 @@ Inspecting the `ExportedProgram`, we can note the following: - The {class}`torch.fx.Graph` contains the computation graph of the original program, along with records of the original code for easy debugging. - The graph contains only `torch.ops.aten` operators found [here](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml) - and custom operators, and is fully functional, without any inplace operators - such as `torch.add_`. + and custom operators. - The parameters (weight and bias to conv) are lifted as inputs to the graph, resulting in no `get_attr` nodes in the graph, which previously existed in the result of {func}`torch.fx.symbolic_trace`. - The {class}`torch.export.ExportGraphSignature` models the input and output signature, along with specifying which inputs are parameters. - The resulting shape and dtype of tensors produced by each node in the graph is - noted. For example, the `convolution` node will result in a tensor of dtype + noted. For example, the `conv2d` node will result in a tensor of dtype `torch.float32` and shape (1, 16, 256, 256). -(non-strict-export)= - -### Non-Strict Export - -In PyTorch 2.3, we introduced a new mode of tracing called **non-strict mode**. -It's still going through hardening, so if you run into any issues, please file -them to Github with the "oncall: export" tag. - -In *non-strict mode*, we trace through the program using the Python interpreter. -Your code will execute exactly as it would in eager mode; the only difference is -that all Tensor objects will be replaced by ProxyTensors, which will record all -their operations into a graph. - -In *strict* mode, which is currently the default, we first trace through the -program using TorchDynamo, a bytecode analysis engine. TorchDynamo does not -actually execute your Python code. Instead, it symbolically analyzes it and -builds a graph based on the results. This analysis allows torch.export to -provide stronger guarantees about safety, but not all Python code is supported. - -An example of a case where one might want to use non-strict mode is if you run -into a unsupported TorchDynamo feature that might not be easily solved, and you -know the python code is not exactly needed for computation. For example: - -```python -import contextlib -import torch - -class ContextManager(): - def __init__(self): - self.count = 0 - def __enter__(self): - self.count += 1 - def __exit__(self, exc_type, exc_value, traceback): - self.count -= 1 - -class M(torch.nn.Module): - def forward(self, x): - with ContextManager(): - return x.sin() + x.cos() - -export(M(), (torch.ones(3, 3),), strict=False) # Non-strict traces successfully -export(M(), (torch.ones(3, 3),)) # Strict mode fails with torch._dynamo.exc.Unsupported: ContextManager -``` - -In this example, the first call using non-strict mode (through the -`strict=False` flag) traces successfully whereas the second call using strict -mode (default) results with a failure, where TorchDynamo is unable to support -context managers. One option is to rewrite the code (see {ref}`Limitations of torch.export `), -but seeing as the context manager does not affect the tensor -computations in the model, we can go with the non-strict mode's result. - -(training-export)= - -### Export for Training and Inference - -In PyTorch 2.5, we introduced a new API called {func}`export_for_training`. -It's still going through hardening, so if you run into any issues, please file -them to Github with the "oncall: export" tag. - -In this API, we produce the most generic IR that contains all ATen operators -(including both functional and non-functional) which can be used to train in -eager PyTorch Autograd. This API is intended for eager training use cases such as PT2 Quantization -and will soon be the default IR of torch.export.export. To read further about -the motivation behind this change, please refer to - - -When this API is combined with {func}`run_decompositions()`, you should be able to get inference IR with -any desired decomposition behavior. - -To show some examples: - -```python -class ConvBatchnorm(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - self.conv = torch.nn.Conv2d(1, 3, 1, 1) - self.bn = torch.nn.BatchNorm2d(3) - - def forward(self, x): - x = self.conv(x) - x = self.bn(x) - return (x,) - -mod = ConvBatchnorm() -inp = torch.randn(1, 1, 3, 3) - -ep_for_training = torch.export.export_for_training(mod, (inp,)) -print(ep_for_training) -``` - -```python -ExportedProgram: - class GraphModule(torch.nn.Module): - def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"): - conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias) - add_: "i64[]" = torch.ops.aten.add_.Tensor(b_bn_num_batches_tracked, 1) - batch_norm: "f32[1, 3, 3, 3]" = torch.ops.aten.batch_norm.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05, True) - return (batch_norm,) -``` - -From the above output, you can see that {func}`export_for_training` produces pretty much the same ExportedProgram -as {func}`export` except for the operators in the graph. You can see that we captured batch_norm in the most general -form. This op is non-functional and will be lowered to different ops when running inference. - -You can also go from this IR to an inference IR via {func}`run_decompositions` with arbitrary customizations. - -```python -# Lower to core aten inference IR, but keep conv2d -decomp_table = torch.export.default_decompositions() -del decomp_table[torch.ops.aten.conv2d.default] -ep_for_inference = ep_for_training.run_decompositions(decomp_table) - -print(ep_for_inference) -``` - -```python -ExportedProgram: - class GraphModule(torch.nn.Module): - def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"): - conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias) - add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1) - _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05) - getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0] - getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3] - getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4] - return (getitem_3, getitem_4, add, getitem) -``` - -Here you can see that we kept `conv2d` op in the IR while decomposing the rest. Now the IR is a functional IR -containing core aten operators except for `conv2d`. - -You can do even more customization by directly registering your chosen decomposition behaviors. - -You can do even more customizations by directly registering custom decomp behaviour - -```python -# Lower to core aten inference IR, but customize conv2d -decomp_table = torch.export.default_decompositions() - -def my_awesome_custom_conv2d_function(x, weight, bias, stride=[1, 1], padding=[0, 0], dilation=[1, 1], groups=1): - return 2 * torch.ops.aten.convolution(x, weight, bias, stride, padding, dilation, False, [0, 0], groups) - -decomp_table[torch.ops.aten.conv2d.default] = my_awesome_conv2d_function -ep_for_inference = ep_for_training.run_decompositions(decomp_table) - -print(ep_for_inference) -``` - -```python -ExportedProgram: - class GraphModule(torch.nn.Module): - def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"): - convolution: "f32[1, 3, 3, 3]" = torch.ops.aten.convolution.default(x, p_conv_weight, p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1) - mul: "f32[1, 3, 3, 3]" = torch.ops.aten.mul.Tensor(convolution, 2) - add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1) - _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(mul, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05) - getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0] - getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3] - getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4]; - return (getitem_3, getitem_4, add, getitem) -``` - -### Expressing Dynamism +## Expressing Dynamism By default `torch.export` will trace the program assuming all input shapes are **static**, and specializing the exported program to those dimensions. However, some dimensions, such as a batch dimension, can be dynamic and vary from run to run. Such dimensions must be specified by using the {func}`torch.export.Dim` API to create them and by passing them into -{func}`torch.export.export` through the `dynamic_shapes` argument. An example: +{func}`torch.export.export` through the `dynamic_shapes` argument. + +An example: ```python import torch @@ -444,7 +276,7 @@ example_args = (torch.randn(32, 64), torch.randn(32, 128)) # Create a dynamic batch size batch = Dim("batch") # Specify that the first dimension of each input is that batch size -dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}} +dynamic_shapes = {"x1": {0: dim}, "x2": {0: batch}} exported_program: torch.export.ExportedProgram = export( M(), args=example_args, dynamic_shapes=dynamic_shapes @@ -488,211 +320,239 @@ Some additional things to note: [The 0/1 Specialization Problem](https://docs.google.com/document/d/16VPOa3d-Liikf48teAOmxLc92rgvJdfosIy-yoT38Io/edit?fbclid=IwAR3HNwmmexcitV0pbZm_x1a4ykdXZ9th_eJWK-3hBtVgKnrkmemz6Pm5jRQ#heading=h.ez923tomjvyk) for an in-depth discussion of this topic. -We can also specify more expressive relationships between input shapes, such as -where a pair of shapes might differ by one, a shape might be double of -another, or a shape is even. An example: -```python -class M(torch.nn.Module): - def forward(self, x, y): - return x + y[1:] +In the example, we used `Dim("batch")` to create a dynamic dimension. This is +the most explicit way to specify dynamism. We can also use `Dim.DYNAMIC` and +`Dim.AUTO` to specify dynamism. We will go over both methods in the next section. -x, y = torch.randn(5), torch.randn(6) -dimx = torch.export.Dim("dimx", min=3, max=6) -dimy = dimx + 1 +### Named Dims -exported_program = torch.export.export( - M(), (x, y), dynamic_shapes=({0: dimx}, {0: dimy}), -) -print(exported_program) +For every dimension specified with `Dim("name")`, we will allocate a symbolic +shape. Specifying a `Dim` with the same name will result in the same symbol +to be generated. This allows users to specify what symbols are allocated for +each input dimension. + +```python +batch = Dim("batch") +dynamic_shapes = {"x1": {0: dim}, "x2": {0: batch}} ``` +For each `Dim`, we can specify minimum and maximum values. We also allow +specifying relations between `Dim`s in univariate linear expressions: `A * dim + B`. +This allows users to specify more complex constraints like integer divisibility +for dynamic dimensions. These features allow for users to place explicit +restrictions on the dynamic behavior of the `ExportedProgram` produced. + ```python -ExportedProgram: -class GraphModule(torch.nn.Module): - def forward(self, x: "f32[s0]", y: "f32[s0 + 1]"): - # code: return x + y[1:] - slice_1: "f32[s0]" = torch.ops.aten.slice.Tensor(y, 0, 1, 9223372036854775807) - add: "f32[s0]" = torch.ops.aten.add.Tensor(x, slice_1) - return (add,) +dx = Dim("dx", min=4, max=256) +dh = Dim("dh", max=512) +dynamic_shapes = { + "x": (dx, None), + "y": (2 * dx, dh), +} +``` + +However, `ConstraintViolationErrors` will be raised if the while tracing, we emit guards +that conflict with the relations or static/dynamic specifications given. For +example, in the above specification, the following is asserted: + +* `x.shape[0]` is to have range `[4, 256]`, and related to `y.shape[0]` by `y.shape[0] == 2 * x.shape[0]`. +* `x.shape[1]` is static. +* `y.shape[1]` has range `[0, 512]`, and is unrelated to any other dimension. + +If any of these assertions are found to be incorrect while tracing (ex. +`x.shape[0]` is static, or `y.shape[1]` has a smaller range, or +`y.shape[0] != 2 * x.shape[0]`), then a `ConstraintViolationError` will be +raised, and the user will need to change their `dynamic_shapes` specification. + +### Dim Hints -Range constraints: {s0: VR[3, 6], s0 + 1: VR[4, 7]} +Instead of explicitly specifying dynamism using `Dim("name")`, we can let +`torch.export` infer the ranges and relationships of the dynamic values using +`Dim.DYNAMIC`. This is also a more convenient way to specify dynamism when you +don't know specifically *how* dynamic your dynamic values are. + +```python +dynamic_shapes = { + "x": (Dim.DYNAMIC, None), + "y": (Dim.DYNAMIC, Dim.DYNAMIC), +} ``` -Some things to note: +We can also specify min/max values for `Dim.DYNAMIC`, which will serve as hints +to export. But if while tracing export found the range to be different, it will +automatically update the range without raising an error. We also cannot specify +relationships between dynamic values. Instead, this will be inferred by export, +and exposed to users through an inspection of assertions within the graph. In +this method of specifying dynamism, `ConstraintViolationErrors` will **only** be +raised if the specified value is inferred to be **static**. -- By specifying `{0: dimx}` for the first input, we see that the resulting - shape of the first input is now dynamic, being `[s0]`. And now by specifying - `{0: dimy}` for the second input, we see that the resulting shape of the - second input is also dynamic. However, because we expressed `dimy = dimx + 1`, - instead of `y`'s shape containing a new symbol, we see that it is - now being represented with the same symbol used in `x`, `s0`. We can - see that relationship of `dimy = dimx + 1` is being shown through `s0 + 1`. -- Looking at the range constraints, we see that `s0` has the range [3, 6], - which is specified initially, and we can see that `s0 + 1` has the solved - range of [4, 7]. +An even more convenient way to specify dynamism is to use `Dim.AUTO`, which will +behave like `Dim.DYNAMIC`, but will **not** raise an error if the dimension is +inferred to be static. This is useful for when you have no idea what the dynamic +values are, and want to export the program with a "best effort" dynamic approach. -### Serialization +### ShapesCollection -To save the `ExportedProgram`, users can use the {func}`torch.export.save` and -{func}`torch.export.load` APIs. A convention is to save the `ExportedProgram` -using a `.pt2` file extension. +When specifying which inputs are dynamic via `dynamic_shapes`, we must specify +the dynamism of every input. For example, given the following inputs: -An example: +```python +args = {"x": tensor_x, "others": [tensor_y, tensor_z]} +``` + +we would need to specify the dynamism of `tensor_x`, `tensor_y`, and `tensor_z` +along with the dynamic shapes: ```python -import torch -import io +# With named-Dims +dim = torch.export.Dim(...) +dynamic_shapes = {"x": {0: dim, 1: dim + 1}, "others": [{0: dim * 2}, None]} -class MyModule(torch.nn.Module): - def forward(self, x): - return x + 10 +torch.export(..., args, dynamic_shapes=dynamic_shapes) +``` -exported_program = torch.export.export(MyModule(), torch.randn(5)) +However, this is particularly complicated as we need to specify the +`dynamic_shapes` specification in the same nested input structure as the input +arguments. Instead, an easier way to specify dynamic shapes is with the helper +utility {class}`torch.export.ShapesCollection`, where instead of specifying the +dynamism of every single input, we can just assign directly which input +dimensions are dynamic. -torch.export.save(exported_program, 'exported_program.pt2') -saved_exported_program = torch.export.load('exported_program.pt2') +```python +dim = torch.export.Dim(...) +dynamic_shapes = torch.export.ShapesCollection() +dynamic_shapes[tensor_x] = (dim, dim + 1, 8) +dynamic_shapes[tensor_y] = {0: dim * 2} + +torch.export(..., args, dynamic_shapes=dynamic_shapes) ``` -### Specializations +### AdditionalInputs -A key concept in understanding the behavior of `torch.export` is the -difference between *static* and *dynamic* values. +In the case where you don't know how dynamic your inputs are, but you have an +ample set of testing or profiling data that can provide a fair sense of +representative inputs for a model, you can use +{class}`torch.export.AdditionalInputs` in place of `dynamic_shapes`. You can +specify all the possible inputs used to trace the program, and +`AdditionalInputs` will infer which inputs are dynamic based on which input +shapes are changing. -A *dynamic* value is one that can change from run to run. These behave like -normal arguments to a Python function—you can pass different values for an -argument and expect your function to do the right thing. Tensor *data* is -treated as dynamic. +Example: -A *static* value is a value that is fixed at export time and cannot change -between executions of the exported program. When the value is encountered during -tracing, the exporter will treat it as a constant and hard-code it into the -graph. +```python +args0, kwargs0 = ... # example inputs for export -When an operation is performed (e.g. `x + y`) and all inputs are static, then -the output of the operation will be directly hard-coded into the graph, and the -operation won’t show up (i.e. it will get constant-folded). +# other representative inputs that the exported program will run on +dynamic_shapes = torch.export.AdditionalInputs() +dynamic_shapes.add(args1, kwargs1) +... +dynamic_shapes.add(argsN, kwargsN) -When a value has been hard-coded into the graph, we say that the graph has been -*specialized* to that value. +torch.export(..., args0, kwargs0, dynamic_shapes=dynamic_shapes) +``` -The following values are static: +## Serialization -#### Input Tensor Shapes +To save the `ExportedProgram`, users can use the {func}`torch.export.save` and +{func}`torch.export.load` APIs. The resulting file is a zipfile with a specific +structure. The details of the structure are defined in the +{ref}`PT2 Archive Spec `. -By default, `torch.export` will trace the program specializing on the input -tensors' shapes, unless a dimension is specified as dynamic via the -`dynamic_shapes` argument to `torch.export`. This means that if there exists -shape-dependent control flow, `torch.export` will specialize on the branch -that is being taken with the given sample inputs. For example: +An example: ```python import torch -from torch.export import export +import io -class Mod(torch.nn.Module): +class MyModule(torch.nn.Module): def forward(self, x): - if x.shape[0] > 5: - return x + 1 - else: - return x - 1 + return x + 10 -example_inputs = (torch.rand(10, 2),) -exported_program = export(Mod(), example_inputs) -print(exported_program) -``` +exported_program = torch.export.export(MyModule(), torch.randn(5)) -```python -ExportedProgram: -class GraphModule(torch.nn.Module): - def forward(self, x: "f32[10, 2]"): - # code: return x + 1 - add: "f32[10, 2]" = torch.ops.aten.add.Tensor(x, 1) - return (add,) +torch.export.save(exported_program, 'exported_program.pt2') +saved_exported_program = torch.export.load('exported_program.pt2') ``` -The conditional of (`x.shape[0] > 5`) does not appear in the -`ExportedProgram` because the example inputs have the static -shape of (10, 2). Since `torch.export` specializes on the inputs' static -shapes, the else branch (`x - 1`) will never be reached. To preserve the dynamic -branching behavior based on the shape of a tensor in the traced graph, -{func}`torch.export.Dim` will need to be used to specify the dimension -of the input tensor (`x.shape[0]`) to be dynamic, and the source code will -need to be {ref}`rewritten `. +(training-export)= + +## Export IR, Decompositions -Note that tensors that are part of the module state (e.g. parameters and -buffers) always have static shapes. +The graph produced by `torch.export` returns a graph containing only ATen +operators, which are the basic unit of computation in PyTorch. As there are over +3000 ATen operators, export provides a way to narrow down the operator set used +in the graph based on certain characteristics, creating different IRs. -#### Python Primitives +By default, export produces the most generic IR which contains all ATen +operators, including both functional and non-functional operators. A functional +operator is one that does not contain any mutations or aliasing of the inputs. +This operator set also allows you to train in eager PyTorch Autograd. -`torch.export` also specializes on Python primitives, -such as `int`, `float`, `bool`, and `str`. However they do have dynamic -variants such as `SymInt`, `SymFloat`, and `SymBool`. +However, if you want to use the IR for inference, or decrease the amount of +operators being used, you can lower the graph through the {func}`ExportedProgram.run_decompositions` API. -For example: +* By specifying an empty set to the `decomp_table` argument, we get rid of all + non-functional operators, reducing the operator set to ~2000 operators. This + is ideal for inference cases as there are no mutations or aliasing, making + it easy to write optimization passes. +* By specifying None to `decomp_table` argument, we can reduce the operator set + to just the {ref}`Core ATen Operator Set `, which is a + collection of only ~180 operators. This IR is optimal for backends who do + not want to reimplement all ATen operators. ```python -import torch -from torch.export import export +class ConvBatchnorm(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(1, 3, 1, 1) + self.bn = torch.nn.BatchNorm2d(3) -class Mod(torch.nn.Module): - def forward(self, x: torch.Tensor, const: int, times: int): - for i in range(times): - x = x + const - return x + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return (x,) -example_inputs = (torch.rand(2, 2), 1, 3) -exported_program = export(Mod(), example_inputs) -print(exported_program) -``` +mod = ConvBatchnorm() +inp = torch.randn(1, 1, 3, 3) -```python -ExportedProgram: - class GraphModule(torch.nn.Module): - def forward(self, x: "f32[2, 2]", const, times): - # code: x = x + const - add: "f32[2, 2]" = torch.ops.aten.add.Tensor(x, 1) - add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, 1) - add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 1) - return (add_2,) +ep_for_training = torch.export.export(mod, (inp,)) +ep_for_inference = ep_for_training.run_decompositions(decomp_table={}) ``` -Because integers are specialized, the `torch.ops.aten.add.Tensor` operations -are all computed with the hard-coded constant `1`, rather than `const`. If -a user passes a different value for `const` at runtime, like 2, than the one used -during export time, 1, this will result in an error. -Additionally, the `times` iterator used in the `for` loop is also "inlined" -in the graph through the 3 repeated `torch.ops.aten.add.Tensor` calls, and the -input `times` is never used. - -#### Python Containers - -Python containers (`List`, `Dict`, `NamedTuple`, etc.) are considered to -have static structure. +A tutorial on how to use this API can be found +[here](https://docs.pytorch.org/tutorials/intermediate/torch_export_tutorial.html#ir-decompositions). (limitations-of-torch-export)= ## Limitations of torch.export -### Graph Breaks - As `torch.export` is a one-shot process for capturing a computation graph from a PyTorch program, it might ultimately run into untraceable parts of programs as it is nearly impossible to support tracing all PyTorch and Python features. In the case of `torch.compile`, an unsupported operation will cause a "graph break" and the unsupported operation will be run with default Python evaluation. In contrast, `torch.export` will require users to provide additional -information or rewrite parts of their code to make it traceable. As the -tracing is based on TorchDynamo, which evaluates at the Python -bytecode level, there will be significantly fewer rewrites required compared to -previous tracing frameworks. +information or rewrite parts of their code to make it traceable. + +{ref}`Draft-export ` is a great resource for listing out +graphs breaks that will be encountered when tracing the program, along with +additional debug information to solve those errors. -When a graph break is encountered, {ref}`ExportDB ` is a great -resource for learning about the kinds of programs that are supported and -unsupported, along with ways to rewrite programs to make them traceable. +{ref}`ExportDB ` is also great resource for learning about the +kinds of programs that are supported and unsupported, along with ways to rewrite +programs to make them traceable. -An option to get past dealing with this graph breaks is by using -{ref}`non-strict export ` +### TorchDynamo unsupported + +When using `torch.export` with `strict=True`, this will use TorchDynamo to +evaluate the program at the Python bytecode level to trace the program into a +graph. Compared to previous tracing frameworks, there will be significantly +fewer rewrites required to make a program traceable, but there will still be +some Python features that are unsupported. An option to get past dealing with +this graph breaks is by using +{ref}`non-strict export ` through changing the `strict` flag +to `strict=False`. (data-shape-dependent-control-flow)= @@ -705,13 +565,18 @@ number of paths. In such cases, users will need to rewrite their code using special control flow operators. Currently, we support {ref}`torch.cond ` to express if-else like control flow (more coming soon!). -### Missing Fake/Meta/Abstract Kernels for Operators +You can also refer to this +[tutorial](https://docs.pytorch.org/tutorials/intermediate/torch_export_tutorial.html#data-dependent-errors) +for more ways of addressing data-dependent errors. + +### Missing Fake/Meta Kernels for Operators -When tracing, a FakeTensor kernel (aka meta kernel, abstract impl) is -required for all operators. This is used to reason about the input/output shapes -for this operator. +When tracing, a FakeTensor kernel (aka meta kernel) is required for all +operators. This is used to reason about the input/output shapes for this +operator. -Please see {func}`torch.library.register_fake` for more details. +Please see this [tutorial](https://docs.pytorch.org/tutorials/advanced/custom_ops_landing_page.html) +for more details. In the unfortunate case where your model uses an ATen operator that is does not have a FakeTensor kernel implementation yet, please file an issue. @@ -722,202 +587,89 @@ have a FakeTensor kernel implementation yet, please file an issue. :caption: Additional Links for Export Users :maxdepth: 1 -export.programming_model -export.ir_spec -draft_export -torch.compiler_transformations -torch.compiler_ir -generated/exportdb/index +export/programming_model +export/ir_spec +export/pt2_archive +export/draft_export cond +generated/exportdb/index +torch.compiler_aot_inductor +torch.compiler_ir ``` ```{toctree} :caption: Deep Dive for PyTorch Developers :maxdepth: 1 -torch.compiler_dynamo_overview -torch.compiler_dynamo_deepdive torch.compiler_dynamic_shapes torch.compiler_fake_tensor +torch.compiler_transformations ``` ## API Reference ```{eval-rst} .. automodule:: torch.export -``` -```{eval-rst} -.. autofunction:: export -``` +.. autofunction:: torch.export.export -```{eval-rst} -.. autofunction:: save -``` +.. autoclass:: torch.export.ExportedProgram + :members: + :exclude-members: __init__ -```{eval-rst} -.. autofunction:: load -``` +.. automodule:: torch.export.dynamic_shapes + :members: Dim, ShapesCollection, AdditionalInputs, refine_dynamic_shapes_from_suggested_fixes -```{eval-rst} -.. autofunction:: draft_export -``` +.. autofunction:: torch.export.save -```{eval-rst} -.. autofunction:: register_dataclass -``` +.. autofunction:: torch.export.load -```{eval-rst} -.. autoclass:: torch.export.dynamic_shapes.Dim -``` +.. autofunction:: torch.export.pt2_archive._package.package_pt2 -```{eval-rst} -.. autoclass:: torch.export.dynamic_shapes.ShapesCollection - - .. automethod:: dynamic_shapes -``` - -```{eval-rst} -.. autoclass:: torch.export.dynamic_shapes.AdditionalInputs - - .. automethod:: add - .. automethod:: dynamic_shapes - .. automethod:: verify -``` +.. autofunction:: torch.export.pt2_archive._package.load_pt2 -```{eval-rst} -.. autofunction:: torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes -``` - -```{eval-rst} -.. autoclass:: ExportedProgram - - .. attribute:: graph - .. attribute:: graph_signature - .. attribute:: state_dict - .. attribute:: constants - .. attribute:: range_constraints - .. attribute:: module_call_graph - .. attribute:: example_inputs - .. automethod:: module - .. automethod:: run_decompositions -``` - -```{eval-rst} -.. autoclass:: ExportGraphSignature -``` +.. autofunction:: torch.export.draft_export -```{eval-rst} -.. autoclass:: ModuleCallSignature -``` +.. automodule:: torch.export.unflatten + :members: -```{eval-rst} -.. autoclass:: ModuleCallEntry -``` +.. autofunction:: torch.export.register_dataclass -```{eval-rst} .. automodule:: torch.export.decomp_utils -``` - -```{eval-rst} -.. autoclass:: CustomDecompTable - - .. automethod:: copy - .. automethod:: items - .. automethod:: keys - .. automethod:: materialize - .. automethod:: pop - .. automethod:: update -``` - -```{eval-rst} -.. autofunction:: torch.export.exported_program.default_decompositions -``` - -```{eval-rst} -.. automodule:: torch.export.exported_program -``` - -```{eval-rst} -.. automodule:: torch.export.graph_signature -``` - -```{eval-rst} -.. autoclass:: ExportGraphSignature - - .. automethod:: replace_all_uses - .. automethod:: get_replace_hook -``` - -```{eval-rst} -.. autoclass:: ExportBackwardSignature -``` - -```{eval-rst} -.. autoclass:: InputKind -``` - -```{eval-rst} -.. autoclass:: InputSpec -``` - -```{eval-rst} -.. autoclass:: OutputKind -``` + :members: + :ignore-module-all: + :undoc-members: -```{eval-rst} -.. autoclass:: OutputSpec -``` - -```{eval-rst} -.. autoclass:: SymIntArgument -``` - -```{eval-rst} -.. autoclass:: SymBoolArgument -``` +.. automodule:: torch.export.experimental + :members: + :ignore-module-all: -```{eval-rst} -.. autoclass:: SymFloatArgument -``` +.. automodule:: torch.export.passes + :members: -```{eval-rst} -.. autoclass:: CustomObjArgument -``` +.. automodule:: torch.export.pt2_archive + :members: + :ignore-module-all: -```{eval-rst} -.. py:module:: torch.export.dynamic_shapes -``` +.. automodule:: torch.export.pt2_archive.constants + :members: + :ignore-module-all: -```{eval-rst} -.. py:module:: torch.export.custom_ops -``` +.. automodule:: torch.export.exported_program + :members: + :ignore-module-all: + :exclude-members: ExportedProgram -```{eval-rst} -.. automodule:: torch.export.unflatten - :members: -``` +.. automodule:: torch.export.custom_ops + :members: + :ignore-module-all: -```{eval-rst} .. automodule:: torch.export.custom_obj -``` + :members: + :ignore-module-all: -```{eval-rst} -.. automodule:: torch.export.experimental -``` - -```{eval-rst} -.. automodule:: torch.export.passes -``` - -```{eval-rst} -.. autofunction:: torch.export.passes.move_to_device_pass -``` - -```{eval-rst} -.. automodule:: torch.export.pt2_archive -``` - -```{eval-rst} -.. automodule:: torch.export.pt2_archive.constants +.. automodule:: torch.export.graph_signature + :members: + :ignore-module-all: + :undoc-members: ``` diff --git a/docs/source/draft_export.md b/docs/source/export/draft_export.md similarity index 97% rename from docs/source/draft_export.md rename to docs/source/export/draft_export.md index cc7247d3b526d..b1ec6ca5d44e6 100644 --- a/docs/source/draft_export.md +++ b/docs/source/export/draft_export.md @@ -1,4 +1,4 @@ -(draft-export)= +(export.draft_export)= # Draft Export @@ -126,7 +126,7 @@ Running the `tlparse` command in the terminal will generate a [tlparse](https://github.com/pytorch/tlparse) HTML report. Here is an example of the `tlparse` report: -```{image} _static/img/export/draft_export_report.png +```{image} ../_static/img/export/draft_export_report.png ``` Clicking into the Data Dependent Error, we will see the following page which @@ -136,7 +136,7 @@ contains information to help debug this error. Specifically, it contains: - A list of local variables and their shapes - Information for how this guard was created -```{image} _static/img/export/draft_export_report_dde.png +```{image} ../_static/img/export/draft_export_report_dde.png ``` ## The returned Exported Program @@ -251,12 +251,3 @@ and produce a runnable artifact. This optimized version can then be used for deployment. In parallel, we can utilize the report generated by draft-export to identify and fix `torch.export` errors that were encountered so that the original model can be directly traceable with `torch.export`. - -```{toctree} -:caption: Additional Links -:maxdepth: 1 - -torch.compiler_fake_tensor -torch.compiler_dynamic_shapes -torch.compiler_aot_inductor -``` diff --git a/docs/source/export.ir_spec.md b/docs/source/export/ir_spec.md similarity index 100% rename from docs/source/export.ir_spec.md rename to docs/source/export/ir_spec.md diff --git a/docs/source/export.programming_model.md b/docs/source/export/programming_model.md similarity index 98% rename from docs/source/export.programming_model.md rename to docs/source/export/programming_model.md index 9a21db78464aa..d4b81b223fa2e 100644 --- a/docs/source/export.programming_model.md +++ b/docs/source/export/programming_model.md @@ -1,4 +1,4 @@ -(export-programming-model)= +(export.programming_model)= # torch.export Programming Model @@ -15,7 +15,9 @@ on different inputs as long as they satisfy the same conditions. The basic output of {func}`torch.export.export` is a single graph of PyTorch operations, with associated metadata. The exact format of this output is -covered in the {ref}`export.ir_spec`. +covered in the {ref}`export IR spec `. + +(non-strict-export)= ### Strict vs. Non-Strict Tracing @@ -120,6 +122,9 @@ Whether a value is static or dynamic depends on its type: - There are dynamic variants for some primitive types (`SymInt`, `SymFloat`, `SymBool`). Typically users do not have to deal with them. + - Users can specify integer inputs as dynamic by specifying + a [dynamic shape](https://pytorch.org/docs/main/export.html#expressing-dynamism) + for it. - For Python *standard containers* (`list`, `tuple`, `dict`, `namedtuple`): @@ -150,7 +155,7 @@ By default, the types of inputs you can use for your program are: - Python primitives (`int`, `float`, `bool`, `str`, `None`) - Python standard containers (`list`, `tuple`, `dict`, `namedtuple`) -### Custom Input Types +### Custom Input Types (PyTree) In addition, you can also define your own (custom) class and use it as an input type, but you will need to register such a class as a PyTree. @@ -164,7 +169,8 @@ class Input: f: torch.Tensor p: torch.Tensor -torch.export.register_dataclass(Input) +import torch.utils._pytree as pytree +pytree.register_dataclass(Input) class M(torch.nn.Module): def forward(self, x: Input): diff --git a/docs/source/export/pt2_archive.md b/docs/source/export/pt2_archive.md new file mode 100644 index 0000000000000..cfb589f7bdfe4 --- /dev/null +++ b/docs/source/export/pt2_archive.md @@ -0,0 +1,122 @@ +(export.pt2_archive)= + +# PT2 Archive Spec + +The following specification defines the archive format which can be produced +through the following methods: + +* {ref}`torch.export ` through calling {func}`torch.export.save` +* {ref}`AOTInductor ` through calling {func}`torch._inductor.aoti_compile_and_package` + +The archive is a zipfile, and can be manipulated using standard zipfile APIs. + +The following is a sample archive. We will walk through the archive folder by folder. + +``` +. +├── archive_format +├── byteorder +├── .data +│ ├── serialization_id +│ └── version +├── data +│ ├── aotinductor +│ │ └── model1 +│ │ ├── aotinductor_pickle_data.json +│ │ ├── cf5ez6ifexr7i2hezzz4s7xfusj4wtisvu2gddeamh37bw6bghjw.cpp +│ │ ├── cf5ez6ifexr7i2hezzz4s7xfusj4wtisvu2gddeamh37bw6bghjw.so +│ │ ├── cg7domx3woam3nnliwud7yvtcencqctxkvvcafuriladwxw4nfiv.cubin +│ │ └── cubaaxppb6xmuqdm4bej55h2pftbce3bjyyvljxbtdfuolmv45ex.cubin +│ ├── weights +│ │ ├── model1_model_param_config.json +│ │ ├── weight_0 +│ │ ├── weight_1 +│ │ ├── weight_2 +│ └── constants +│ │ ├── model1_model_constants_config.json +│ │ ├── tensor_0 +│ │ ├── tensor_1 +│ │ ├── custom_obj_0 +│ │ ├── custom_obj_1 +│ └── sample_inputs +│ ├── model1.pt +│ └── model2.pt +├── extra +│ └── ....json +└── models + ├── model1.json + └── model2.json +``` + +## Contents + +### Archive Headers + +* `archive_format` declares the format used by this archive. Currently, it can only be “pt2”. +* `byteorder`. One of “little” or “big”, used by zip file reader +* `/.data/version` contains the archive version. (Notice that this is neither export serialization’s schema version, nor Aten Opset Version). +* `/.data/serialization_id` is a hash generated for the current archive, used for verification. + + +### AOTInductor Compiled Artifact + +Path: `/data/aotinductor/-/` + +AOTInductor compilation artifacts are saved for each model-backend pair. For +example, compilation artifacts for the `model1` model on A100 and H100 will be +saved in `model1-a100` and `model1-h100` folders separately. + +The folder typically contains +* `.so`: Dynamic library compiled from .cpp. +* `.cpp`: AOTInductor generated cpp wrapper file. +* `*.cubin`: Triton kernels compiled from triton codegen kernels +* (optional) `.json`: External fallback nodes for custom ops to be executed by `ProxyExecutor`, serialized according to `ExternKernelNode` struct. If the model doesn’t use custom ops/ProxyExecutor, this file would be omitted. +* `_metadata.json`: Metadata which was passed in from the `aot_inductor.metadata` inductor config + +### Weights + +Path: `/data/weights/*` + +Model parameters and buffers are saved in the `/data/weights/` folder. Each +tensor is saved as a separated file. The file only contains the raw data blob, +tensor metadata are saved separately in the +`_model_param_config.json`. + +### Constants + +Path: `/data/constants/*` + +TensorConstants, non-persistent buffers and TorchBind objects are saved in the +`/data/constants/` folder. Metadata is saved separately in the +`_model_constants_config.json` + +### Sample Inputs + +Path: `/data/sample_inputs/.pt` + +The `sample_input` used by `torch.export` could be included in the archive for +downstream use. Typically, it’s a flattened list of Tensors, combining both args +and kwargs of the forward() function. + +The .pt file is produced by `torch.save(sample_input)`, and can be loaded by +`torch.load()` in python and `torch::pickle_load()` in c++. + +When the model has multiple copies of sample input, it would be packaged as +`_.pt`. + +### Models Definitions + +Path: `/models/.json` + +Model definition is the serialized json of the ExportedProgram from +`torch.export.save`, and other model-level metadata. + +## Multiple Models + +This archive spec supports multiple model definitions coexisting in the same +file, with `` serving as a unique identifier for the models, and +will be used as reference in other folders of the archive. + +Lower level APIs like {func}`torch.export.pt2_archive._package.package_pt2` and +{func}`torch.export.pt2_archive._package.load_pt2` allow you to have +finer-grained control over the packaging and loading process. diff --git a/docs/source/torch.compiler_aot_inductor.md b/docs/source/torch.compiler_aot_inductor.md index d2a7c93392647..d8514a920848e 100644 --- a/docs/source/torch.compiler_aot_inductor.md +++ b/docs/source/torch.compiler_aot_inductor.md @@ -1,3 +1,5 @@ +(torch.compiler_aot_inductor)= + # AOTInductor: Ahead-Of-Time Compilation for Torch.Export-ed Models ```{warning} @@ -25,7 +27,7 @@ relies on. We will then use {func}`torch._inductor.aoti_compile_and_package` to compile the exported program using TorchInductor, and save the compiled artifacts into one -package. +package. The package is in the format of a {ref}`PT2 Archive Spec `. ```{note} If you have a CUDA-enabled device on your machine and you installed PyTorch with CUDA support, diff --git a/docs/source/torch.compiler_ir.md b/docs/source/torch.compiler_ir.md index ed920a064a68d..ff66b8cc7efce 100644 --- a/docs/source/torch.compiler_ir.md +++ b/docs/source/torch.compiler_ir.md @@ -1,3 +1,5 @@ +(torch.compiler_ir)= + # IRs PyTorch 2.0 offers two set of IRs for backends to interface with: Core Aten IR and Prims IR. diff --git a/torch/export/dynamic_shapes.py b/torch/export/dynamic_shapes.py index f951b5818afd1..ccc3660f7600c 100644 --- a/torch/export/dynamic_shapes.py +++ b/torch/export/dynamic_shapes.py @@ -85,15 +85,19 @@ def __call__(self, min=None, max=None) -> "_DimHint": class Dim: """ - The `Dim` class allows users to specify dynamism in their exported programs. By marking a dimension with a `Dim`, - the compiler associates the dimension with a symbolic integer containing a dynamic range. + The ``Dim`` class allows users to specify dynamism in their exported + programs. By marking a dimension with a ``Dim``, the compiler associates the + dimension with a symbolic integer containing a dynamic range. - The API can be used in 2 ways: Dim hints (i.e. automatic dynamic shapes: `Dim.AUTO`, `Dim.DYNAMIC`, `Dim.STATIC`), - or named Dims (i.e. `Dim("name", min=1, max=2)`). + The API can be used in 2 ways: Dim hints (i.e. automatic dynamic shapes: + ``Dim.AUTO``, ``Dim.DYNAMIC``, ``Dim.STATIC``), or named Dims (i.e. + ``Dim("name", min=1, max=2)``). - Dim hints provide the lowest barrier to exportability, with the user only needing to specify if a dimension - if dynamic, static, or left for the compiler to decide (`Dim.AUTO`). The export process will automatically - infer the remaining constraints on min/max ranges and relationships between dimensions. + Dim hints provide the lowest barrier to exportability, with the user only + needing to specify if a dimension if dynamic, static, or left for the + compiler to decide (``Dim.AUTO``). The export process will automatically + infer the remaining constraints on min/max ranges and relationships between + dimensions. Example:: @@ -112,19 +116,19 @@ def forward(self, x, y): } ep = torch.export(Foo(), (x, y), dynamic_shapes=dynamic_shapes) - Here, export would raise an exception if we replaced all uses of `Dim.AUTO` with `Dim.DYNAMIC`, - as x.shape[0] is constrained to be static by the model. + Here, export would raise an exception if we replaced all uses of ``Dim.AUTO`` with ``Dim.DYNAMIC``, + as ``x.shape[0]`` is constrained to be static by the model. More complex relations between dimensions may also be codegened as runtime assertion nodes by the compiler, - e.g. (x.shape[0] + y.shape[1]) % 4 == 0, to be raised if runtime inputs do not satisfy such constraints. + e.g. ``(x.shape[0] + y.shape[1]) % 4 == 0``, to be raised if runtime inputs do not satisfy such constraints. - You may also specify min-max bounds for Dim hints, e.g. `Dim.AUTO(min=16, max=32)`, `Dim.DYNAMIC(max=64)`, + You may also specify min-max bounds for Dim hints, e.g. ``Dim.AUTO(min=16, max=32)``, ``Dim.DYNAMIC(max=64)``, with the compiler inferring the remaining constraints within the ranges. An exception will be raised if the valid range is entirely outside the user-specified range. Named Dims provide a stricter way of specifying dynamism, where exceptions are raised if the compiler infers constraints that do not match the user specification. For example, exporting the previous - model, the user would need the following `dynamic_shapes` argument:: + model, the user would need the following ``dynamic_shapes`` argument:: s0 = Dim("s0") s1 = Dim("s1", min=16) @@ -134,8 +138,9 @@ def forward(self, x, y): } ep = torch.export(Foo(), (x, y), dynamic_shapes=dynamic_shapes) - Named Dims also allow specification of relationships between dimensions, up to univariate linear relations. - For example, the following indicates one dimension is a multiple of another plus 4:: + Named Dims also allow specification of relationships between dimensions, up + to univariate linear relations. For example, the following indicates one + dimension is a multiple of another plus 4:: s0 = Dim("s0") s1 = 3 * s0 + 4 diff --git a/torch/export/pt2_archive/_package.py b/torch/export/pt2_archive/_package.py index 7c97e6abe171c..f14087250d526 100644 --- a/torch/export/pt2_archive/_package.py +++ b/torch/export/pt2_archive/_package.py @@ -12,8 +12,8 @@ import torch import torch.utils._pytree as pytree from torch._export.serde.serialize import deserialize, serialize, SerializedArtifact +from torch.export import ExportedProgram from torch.export._tree_utils import reorder_kwargs -from torch.export.exported_program import ExportedProgram from torch.export.pt2_archive._package_weights import ( get_complete, group_weights, @@ -350,22 +350,21 @@ def package_pt2( opset_version: Optional[dict[str, int]] = None, pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL, ) -> FileLike: - """ - Saves the artifacts to a PT2Archive format - (https://docs.google.com/document/d/1RQ4cmywilnFUT1VE-4oTGxwXdc8vowCSZsrRgo3wFA8/edit?tab=t.0#heading=h.v2y2jgnwc56a). - The artifact can then be loaded using ``load_pt2``. + r""" + Saves the artifacts to a PT2Archive format. The artifact can then be loaded + using ``load_pt2``. Args: - f (str | os.PathLike[str] | IO[bytes]) A file-like object (has to + f (str | os.PathLike[str] | IO[bytes]): A file-like object (has to implement write and flush) or a string containing a file name. exported_programs (Union[ExportedProgram, dict[str, ExportedProgram]]): The exported program to save, or a dictionary mapping model name to an exported program to save. The exported program will be saved under - models/*.json. If only one ExportedProgram is specified, this will + models/\*.json. If only one ExportedProgram is specified, this will automatically be named "model". - aoti_files (Union[list[str], dict[str, list[str]]): A list of files + aoti_files (Union[list[str], dict[str, list[str]]]): A list of files generated by AOTInductor via ``torch._inductor.aot_compile(..., {"aot_inductor.package": True})``, or a dictionary mapping model name to its AOTInductor generated files. diff --git a/torch/export/unflatten.py b/torch/export/unflatten.py index 1a3cf9610a6de..3a741778d0d41 100644 --- a/torch/export/unflatten.py +++ b/torch/export/unflatten.py @@ -15,10 +15,10 @@ import torch.fx._pytree as fx_pytree import torch.utils._pytree as pytree from torch._library.fake_class_registry import FakeScriptObject +from torch.export import ExportedProgram from torch.export._tree_utils import reorder_kwargs from torch.export.exported_program import ( ConstantArgument, - ExportedProgram, ExportGraphSignature, InputKind, ModuleCallSignature, From 14ecc0336185f2ca5591858bc74cd4aadf2d1161 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 16 Jul 2025 19:55:21 +0000 Subject: [PATCH 134/457] Revert "recovering node source from dict (#158373)" This reverts commit 4d055982e38f59fdb2a4c9d8855e58548bc42c12. Reverted https://github.com/pytorch/pytorch/pull/158373 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/158373#issuecomment-3080093479)) --- test/fx/test_fx_traceback.py | 28 +++++------------------ torch/fx/traceback.py | 43 ------------------------------------ 2 files changed, 6 insertions(+), 65 deletions(-) diff --git a/test/fx/test_fx_traceback.py b/test/fx/test_fx_traceback.py index 05369d17078ba..f02bc5a2e1592 100644 --- a/test/fx/test_fx_traceback.py +++ b/test/fx/test_fx_traceback.py @@ -32,8 +32,6 @@ def test_node_source(self): dummy_source_dict, ) - self.assertEqual(node_source, NodeSource._from_dict(node_source.to_dict())) - # Dummy node node = torch.fx.Node( graph=torch.fx.Graph(), @@ -181,28 +179,14 @@ def forward(self, x): if node_name_1 in same_ancestor_nodes else None, }: - self.assertEqual( - node_name_to_from_node[node_name_1], - node_name_to_from_node[node_name_2], - ) - self.assertEqual( - [ - NodeSource._from_dict(ns.to_dict()) - for ns in node_name_to_from_node[node_name_1] - ], - node_name_to_from_node[node_name_2], + self.assertTrue( + node_name_to_from_node[node_name_1] + == node_name_to_from_node[node_name_2] ) else: - self.assertNotEqual( - node_name_to_from_node[node_name_1], - node_name_to_from_node[node_name_2], - ) - self.assertNotEqual( - [ - NodeSource._from_dict(ns.to_dict()) - for ns in node_name_to_from_node[node_name_1] - ], - node_name_to_from_node[node_name_2], + self.assertTrue( + node_name_to_from_node[node_name_1] + != node_name_to_from_node[node_name_2] ) gm = ep.module() diff --git a/torch/fx/traceback.py b/torch/fx/traceback.py index bcf759c3db4c5..836b41d661859 100644 --- a/torch/fx/traceback.py +++ b/torch/fx/traceback.py @@ -152,49 +152,6 @@ def _make_hashable(obj): return hash(_make_hashable(self.to_dict())) - @classmethod - def _from_dict(cls, d: Optional[dict]) -> Optional["NodeSource"]: - """ - Recursively deserialize from_node metadata from dictionary data. - It is used to deserialize the from_node field from serialized metadata. - Please use contructor NodeSource(node, ...) to create a NodeSource object. - """ - if d is None: - return None - - assert isinstance(d, dict), f"Expected a dict, got {type(d)}" - - # Create a NodeSource object directly without going through the constructor - # to avoid issues with graph ID and node creation - node_source = NodeSource.__new__(NodeSource) - - # Set the basic attributes - node_source.pass_name = d.get("pass_name", "") - - # Parse action string back to NodeSourceAction enum list - action_str = d.get("action", "") - actions = [] - if action_str: - for action_name in action_str.split("+"): - if action_name.upper() == "CREATE": - actions.append(NodeSourceAction.CREATE) - elif action_name.upper() == "REPLACE": - actions.append(NodeSourceAction.REPLACE) - node_source.action = actions - - # Create the NodeInfo object directly - if "name" in d and "target" in d and "graph_id" in d: - node_info = NodeSource.NodeInfo( - d.get("name", ""), d.get("target", ""), d.get("graph_id", -1) - ) - node_source.node_info = node_info - else: - node_source.node_info = None - - # Recursively deserialize nested from_node - node_source.from_node = [cls._from_dict(fn) for fn in d.get("from_node", [])] - return node_source - @compatibility(is_backward_compatible=False) @contextmanager From a9ee4250d55c6342b80e2d57a8ad9a1992ddcdce Mon Sep 17 00:00:00 2001 From: Yiming Zhou Date: Wed, 16 Jul 2025 20:01:30 +0000 Subject: [PATCH 135/457] [4/n] Remove references to TorchScript in PyTorch docs (#158317) Summary: jit.rst Test Plan: CI Rollback Plan: Differential Revision: D78309840 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158317 Approved by: https://github.com/svekars, https://github.com/zhxchen17 --- docs/source/jit.rst | 849 +------------------------- docs/source/jit_builtin_functions.rst | 2 - docs/source/jit_unsupported.md | 4 - torch/jit/_script.py | 8 +- 4 files changed, 9 insertions(+), 854 deletions(-) diff --git a/docs/source/jit.rst b/docs/source/jit.rst index 31c5c4dbf8249..5295f82f9ac19 100644 --- a/docs/source/jit.rst +++ b/docs/source/jit.rst @@ -1,48 +1,21 @@ TorchScript =========== -.. toctree:: - :maxdepth: 1 - :caption: Builtin Functions - :hidden: - - torch.jit.supported_ops - - .. toctree:: :maxdepth: 1 - :caption: Language Reference :hidden: jit_language_reference - - -.. toctree:: - :maxdepth: 1 - jit_language_reference_v2 + jit_python_reference + jit_unsupported + torch.jit.supported_ops - -.. contents:: :local: - :depth: 2 +.. warning:: + TorchScript is deprecated, please use + `torch.export `__ instead. .. automodule:: torch.jit -.. currentmodule:: torch.jit - -TorchScript is a way to create serializable and optimizable models from PyTorch code. -Any TorchScript program can be saved from a Python -process and loaded in a process where there is no Python dependency. - -We provide tools to incrementally transition a model from a pure Python program -to a TorchScript program that can be run independently from Python, such as in a standalone C++ program. -This makes it possible to train models in PyTorch using familiar tools in Python and then export -the model via TorchScript to a production environment where Python programs may be disadvantageous -for performance and multi-threading reasons. - -For a gentle introduction to TorchScript, see the `Introduction to TorchScript `_ tutorial. - -For an end-to-end example of converting a PyTorch model to TorchScript and running it in C++, see the -`Loading a PyTorch Model in C++ `_ tutorial. Creating TorchScript Code -------------------------- @@ -74,817 +47,11 @@ Creating TorchScript Code Attribute annotate -Mixing Tracing and Scripting ----------------------------- - -In many cases either tracing or scripting is an easier approach for converting a model to TorchScript. -Tracing and scripting can be composed to suit the particular requirements -of a part of a model. - -Scripted functions can call traced functions. This is particularly useful when you need -to use control-flow around a simple feed-forward model. For instance the beam search -of a sequence to sequence model will typically be written in script but can call an -encoder module generated using tracing. - - -.. testsetup:: - - # These are hidden from the docs, but these are necessary for `doctest` - # since the `inspect` module doesn't play nicely with the execution - # environment for `doctest` - import torch - - original_script = torch.jit.script - def script_wrapper(obj, *args, **kwargs): - obj.__module__ = 'FakeMod' - return original_script(obj, *args, **kwargs) - - torch.jit.script = script_wrapper - - original_trace = torch.jit.trace - def trace_wrapper(obj, *args, **kwargs): - obj.__module__ = 'FakeMod' - return original_trace(obj, *args, **kwargs) - - torch.jit.trace = trace_wrapper - - -Example (calling a traced function in script): - -.. testcode:: - - import torch - - def foo(x, y): - return 2 * x + y - - traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3))) - - @torch.jit.script - def bar(x): - return traced_foo(x, x) - -Traced functions can call script functions. This is useful when a small part of -a model requires some control-flow even though most of the model is just a feed-forward -network. Control-flow inside of a script function called by a traced function is -preserved correctly. - -Example (calling a script function in a traced function): - -.. testcode:: - - import torch - - @torch.jit.script - def foo(x, y): - if x.max() > y.max(): - r = x - else: - r = y - return r - - - def bar(x, y, z): - return foo(x, y) + z - - traced_bar = torch.jit.trace(bar, (torch.rand(3), torch.rand(3), torch.rand(3))) - -This composition also works for ``nn.Module``\s as well, where it can be used to generate -a submodule using tracing that can be called from the methods of a script module. - -Example (using a traced module): - -.. testcode:: - :skipif: torchvision is None - - import torch - import torchvision - - class MyScriptModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.means = torch.nn.Parameter(torch.tensor([103.939, 116.779, 123.68]) - .resize_(1, 3, 1, 1)) - self.resnet = torch.jit.trace(torchvision.models.resnet18(), - torch.rand(1, 3, 224, 224)) - - def forward(self, input): - return self.resnet(input - self.means) - - my_script_module = torch.jit.script(MyScriptModule()) - - -TorchScript Language --------------------- - -TorchScript is a statically typed subset of Python, so many Python features apply -directly to TorchScript. See the full :ref:`language-reference` for details. - - -.. _builtin functions: - -Built-in Functions and Modules ------------------------------- - -TorchScript supports the use of most PyTorch functions and many Python built-ins. -See :ref:`builtin-functions` for a full reference of supported functions. - -PyTorch Functions and Modules -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -TorchScript supports a subset of the tensor and neural network -functions that PyTorch provides. Most methods on Tensor as well as functions in -the ``torch`` namespace, all functions in ``torch.nn.functional`` and -most modules from ``torch.nn`` are supported in TorchScript. - -See :ref:`jit_unsupported` for a list of unsupported PyTorch functions and modules. - - -Python Functions and Modules -~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Many of Python's `built-in functions `_ are supported in TorchScript. -The :any:`math` module is also supported, but no other Python modules -(built-in or third party) are supported. - - -Python Language Reference Comparison -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -For a full listing of supported Python features, see :ref:`python-language-reference`. - -Debugging ---------- - -.. _`disable TorchScript`: - -Disable JIT for Debugging -~~~~~~~~~~~~~~~~~~~~~~~~~ -.. envvar:: PYTORCH_JIT - -Setting the environment variable ``PYTORCH_JIT=0`` will disable all script -and tracing annotations. If there is hard-to-debug error in one of your -TorchScript models, you can use this flag to force everything to run using native -Python. Since TorchScript (scripting and tracing) is disabled with this flag, -you can use tools like ``pdb`` to debug the model code. For example:: - - @torch.jit.script - def scripted_fn(x : torch.Tensor): - for i in range(12): - x = x + x - return x - - def fn(x): - x = torch.neg(x) - import pdb; pdb.set_trace() - return scripted_fn(x) - - traced_fn = torch.jit.trace(fn, (torch.rand(4, 5),)) - traced_fn(torch.rand(3, 4)) - -Debugging this script with ``pdb`` works except for when we invoke the -:func:`@torch.jit.script ` function. We can globally disable -JIT, so that we can call the :func:`@torch.jit.script ` -function as a normal Python function and not compile it. If the above script -is called ``disable_jit_example.py``, we can invoke it like so:: - - $ PYTORCH_JIT=0 python disable_jit_example.py - -and we will be able to step into the :func:`@torch.jit.script -` function as a normal Python function. To disable the -TorchScript compiler for a specific function, see -:func:`@torch.jit.ignore `. - -.. _inspecting-code: - -Inspecting Code -~~~~~~~~~~~~~~~ - -TorchScript provides a code pretty-printer for all :class:`ScriptModule` instances. This -pretty-printer gives an interpretation of the script method's code as valid -Python syntax. For example: - -.. testcode:: - - @torch.jit.script - def foo(len): - # type: (int) -> torch.Tensor - rv = torch.zeros(3, 4) - for i in range(len): - if i < 10: - rv = rv - 1.0 - else: - rv = rv + 1.0 - return rv - - print(foo.code) - -.. testoutput:: - :hide: - - ... - -A :class:`ScriptModule` with a single ``forward`` method will have an attribute -``code``, which you can use to inspect the :class:`ScriptModule`'s code. -If the :class:`ScriptModule` has more than one method, you will need to access -``.code`` on the method itself and not the module. We can inspect the -code of a method named ``foo`` on a :class:`ScriptModule` by accessing ``.foo.code``. -The example above produces this output: :: - - def foo(len: int) -> Tensor: - rv = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None) - rv0 = rv - for i in range(len): - if torch.lt(i, 10): - rv1 = torch.sub(rv0, 1., 1) - else: - rv1 = torch.add(rv0, 1., 1) - rv0 = rv1 - return rv0 - -This is TorchScript's compilation of the code for the ``forward`` method. -You can use this to ensure TorchScript (tracing or scripting) has captured -your model code correctly. - - -.. _interpreting-graphs: - -Interpreting Graphs -~~~~~~~~~~~~~~~~~~~ -TorchScript also has a representation at a lower level than the code pretty-\ -printer, in the form of IR graphs. - -TorchScript uses a static single assignment (SSA) intermediate representation -(IR) to represent computation. The instructions in this format consist of -ATen (the C++ backend of PyTorch) operators and other primitive operators, -including control flow operators for loops and conditionals. As an example: - -.. testcode:: - - @torch.jit.script - def foo(len): - # type: (int) -> torch.Tensor - rv = torch.zeros(3, 4) - for i in range(len): - if i < 10: - rv = rv - 1.0 - else: - rv = rv + 1.0 - return rv - - print(foo.graph) - -.. testoutput:: - :hide: - - ... - -``graph`` follows the same rules described in the :ref:`inspecting-code` section -with regard to ``forward`` method lookup. - -The example script above produces the graph:: - - graph(%len.1 : int): - %24 : int = prim::Constant[value=1]() - %17 : bool = prim::Constant[value=1]() # test.py:10:5 - %12 : bool? = prim::Constant() - %10 : Device? = prim::Constant() - %6 : int? = prim::Constant() - %1 : int = prim::Constant[value=3]() # test.py:9:22 - %2 : int = prim::Constant[value=4]() # test.py:9:25 - %20 : int = prim::Constant[value=10]() # test.py:11:16 - %23 : float = prim::Constant[value=1]() # test.py:12:23 - %4 : int[] = prim::ListConstruct(%1, %2) - %rv.1 : Tensor = aten::zeros(%4, %6, %6, %10, %12) # test.py:9:10 - %rv : Tensor = prim::Loop(%len.1, %17, %rv.1) # test.py:10:5 - block0(%i.1 : int, %rv.14 : Tensor): - %21 : bool = aten::lt(%i.1, %20) # test.py:11:12 - %rv.13 : Tensor = prim::If(%21) # test.py:11:9 - block0(): - %rv.3 : Tensor = aten::sub(%rv.14, %23, %24) # test.py:12:18 - -> (%rv.3) - block1(): - %rv.6 : Tensor = aten::add(%rv.14, %23, %24) # test.py:14:18 - -> (%rv.6) - -> (%17, %rv.13) - return (%rv) - - -Take the instruction ``%rv.1 : Tensor = aten::zeros(%4, %6, %6, %10, %12) # test.py:9:10`` for -example. - -* ``%rv.1 : Tensor`` means we assign the output to a (unique) value named ``rv.1``, that value is of ``Tensor`` type and that we do not know its concrete shape. -* ``aten::zeros`` is the operator (equivalent to ``torch.zeros``) and the input list ``(%4, %6, %6, %10, %12)`` specifies which values in scope should be passed as inputs. The schema for built-in functions like ``aten::zeros`` can be found at `Builtin Functions`_. -* ``# test.py:9:10`` is the location in the original source file that generated this instruction. In this case, it is a file named `test.py`, on line 9, and at character 10. - -Notice that operators can also have associated ``blocks``, namely the -``prim::Loop`` and ``prim::If`` operators. In the graph print-out, these -operators are formatted to reflect their equivalent source code forms -to facilitate easy debugging. - -Graphs can be inspected as shown to confirm that the computation described -by a :class:`ScriptModule` is correct, in both automated and manual fashion, as -described below. - -Tracer -~~~~~~ - - -Tracing Edge Cases -^^^^^^^^^^^^^^^^^^ -There are some edge cases that exist where the trace of a given Python -function/module will not be representative of the underlying code. These -cases can include: - -* Tracing of control flow that is dependent on inputs (e.g. tensor shapes) -* Tracing of in-place operations of tensor views (e.g. indexing on the left-hand side of an assignment) - -Note that these cases may in fact be traceable in the future. - - -Automatic Trace Checking -^^^^^^^^^^^^^^^^^^^^^^^^ -One way to automatically catch many errors in traces is by using ``check_inputs`` -on the ``torch.jit.trace()`` API. ``check_inputs`` takes a list of tuples -of inputs that will be used to re-trace the computation and verify the -results. For example:: - - def loop_in_traced_fn(x): - result = x[0] - for i in range(x.size(0)): - result = result * x[i] - return result - - inputs = (torch.rand(3, 4, 5),) - check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)] - - traced = torch.jit.trace(loop_in_traced_fn, inputs, check_inputs=check_inputs) - -Gives us the following diagnostic information:: - - ERROR: Graphs differed across invocations! - Graph diff: - - graph(%x : Tensor) { - %1 : int = prim::Constant[value=0]() - %2 : int = prim::Constant[value=0]() - %result.1 : Tensor = aten::select(%x, %1, %2) - %4 : int = prim::Constant[value=0]() - %5 : int = prim::Constant[value=0]() - %6 : Tensor = aten::select(%x, %4, %5) - %result.2 : Tensor = aten::mul(%result.1, %6) - %8 : int = prim::Constant[value=0]() - %9 : int = prim::Constant[value=1]() - %10 : Tensor = aten::select(%x, %8, %9) - - %result : Tensor = aten::mul(%result.2, %10) - + %result.3 : Tensor = aten::mul(%result.2, %10) - ? ++ - %12 : int = prim::Constant[value=0]() - %13 : int = prim::Constant[value=2]() - %14 : Tensor = aten::select(%x, %12, %13) - + %result : Tensor = aten::mul(%result.3, %14) - + %16 : int = prim::Constant[value=0]() - + %17 : int = prim::Constant[value=3]() - + %18 : Tensor = aten::select(%x, %16, %17) - - %15 : Tensor = aten::mul(%result, %14) - ? ^ ^ - + %19 : Tensor = aten::mul(%result, %18) - ? ^ ^ - - return (%15); - ? ^ - + return (%19); - ? ^ - } - - -This message indicates to us that the computation differed between when -we first traced it and when we traced it with the ``check_inputs``. Indeed, -the loop within the body of ``loop_in_traced_fn`` depends on the shape -of the input ``x``, and thus when we try another ``x`` with a different -shape, the trace differs. - -In this case, data-dependent control flow like this can be captured using -:func:`torch.jit.script` instead: - -.. testcode:: - - def fn(x): - result = x[0] - for i in range(x.size(0)): - result = result * x[i] - return result - - inputs = (torch.rand(3, 4, 5),) - check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)] - - scripted_fn = torch.jit.script(fn) - print(scripted_fn.graph) - #print(str(scripted_fn.graph).strip()) - - for input_tuple in [inputs] + check_inputs: - torch.testing.assert_close(fn(*input_tuple), scripted_fn(*input_tuple)) - -.. testoutput:: - :hide: - - ... - - -Which produces:: - - graph(%x : Tensor) { - %5 : bool = prim::Constant[value=1]() - %1 : int = prim::Constant[value=0]() - %result.1 : Tensor = aten::select(%x, %1, %1) - %4 : int = aten::size(%x, %1) - %result : Tensor = prim::Loop(%4, %5, %result.1) - block0(%i : int, %7 : Tensor) { - %10 : Tensor = aten::select(%x, %1, %i) - %result.2 : Tensor = aten::mul(%7, %10) - -> (%5, %result.2) - } - return (%result); - } - -Tracer Warnings -^^^^^^^^^^^^^^^ -The tracer produces warnings for several problematic patterns in traced -computation. As an example, take a trace of a function that contains an -in-place assignment on a slice (a view) of a Tensor: - -.. testcode:: - - def fill_row_zero(x): - x[0] = torch.rand(*x.shape[1:2]) - return x - - traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),)) - print(traced.graph) - -.. testoutput:: - :hide: - - ... - -Produces several warnings and a graph which simply returns the input:: - - fill_row_zero.py:4: TracerWarning: There are 2 live references to the data region being modified when tracing in-place operator copy_ (possibly due to an assignment). This might cause the trace to be incorrect, because all other views that also reference this data will not reflect this change in the trace! On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. are outputs of torch.split), this might still be safe. - x[0] = torch.rand(*x.shape[1:2]) - fill_row_zero.py:6: TracerWarning: Output nr 1. of the traced function does not match the corresponding output of the Python function. Detailed error: - Not within tolerance rtol=1e-05 atol=1e-05 at input[0, 1] (0.09115803241729736 vs. 0.6782537698745728) and 3 other locations (33.00%) - traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),)) - graph(%0 : Float(3, 4)) { - return (%0); - } - -We can fix this by modifying the code to not use the in-place update, but -rather build up the result tensor out-of-place with ``torch.cat``: - -.. testcode:: - - def fill_row_zero(x): - x = torch.cat((torch.rand(1, *x.shape[1:2]), x[1:2]), dim=0) - return x - - traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),)) - print(traced.graph) - -.. testoutput:: - :hide: - - ... - -Frequently Asked Questions --------------------------- - -Q: I would like to train a model on GPU and do inference on CPU. What are the -best practices? - - First convert your model from GPU to CPU and then save it, like so: :: - - cpu_model = gpu_model.cpu() - sample_input_cpu = sample_input_gpu.cpu() - traced_cpu = torch.jit.trace(cpu_model, sample_input_cpu) - torch.jit.save(traced_cpu, "cpu.pt") - - traced_gpu = torch.jit.trace(gpu_model, sample_input_gpu) - torch.jit.save(traced_gpu, "gpu.pt") - - # ... later, when using the model: - - if use_gpu: - model = torch.jit.load("gpu.pt") - else: - model = torch.jit.load("cpu.pt") - - model(input) - - This is recommended because the tracer may witness tensor creation on a - specific device, so casting an already-loaded model may have unexpected - effects. Casting the model *before* saving it ensures that the tracer has - the correct device information. - - -Q: How do I store attributes on a :class:`ScriptModule`? - - Say we have a model like: - - .. testcode:: - - import torch - - class Model(torch.nn.Module): - def __init__(self): - super().__init__() - self.x = 2 - - def forward(self): - return self.x - - m = torch.jit.script(Model()) - - - - If ``Model`` is instantiated it will result in a compilation error - since the compiler doesn't know about ``x``. There are 4 ways to inform the - compiler of attributes on :class:`ScriptModule`: - - 1. ``nn.Parameter`` - Values wrapped in ``nn.Parameter`` will work as they - do on ``nn.Module``\s - - 2. ``register_buffer`` - Values wrapped in ``register_buffer`` will work as - they do on ``nn.Module``\s. This is equivalent to an attribute (see 4) of type - ``Tensor``. - - 3. Constants - Annotating a class member as ``Final`` (or adding it to a list called - ``__constants__`` at the class definition level) will mark the contained names - as constants. Constants are saved directly in the code of the model. See - `builtin-constants` for details. - - 4. Attributes - Values that are a `supported type` can be added as mutable - attributes. Most types can be inferred but some may need to be specified, see - `module attributes` for details. - -Q: I would like to trace module's method but I keep getting this error: - -``RuntimeError: Cannot insert a Tensor that requires grad as a constant. Consider making it a parameter or input, or detaching the gradient`` - - This error usually means that the method you are tracing uses a module's parameters and - you are passing the module's method instead of the module instance (e.g. ``my_module_instance.forward`` vs ``my_module_instance``). - - - Invoking ``trace`` with a module's method captures module parameters (which may require gradients) as **constants**. - - On the other hand, invoking ``trace`` with module's instance (e.g. ``my_module``) creates a new module and correctly copies parameters into the new module, so they can accumulate gradients if required. - - To trace a specific method on a module, see :func:`torch.jit.trace_module ` - -Known Issues ---------------- - -If you're using ``Sequential`` with TorchScript, the inputs of some -of the ``Sequential`` submodules may be falsely inferred to be -``Tensor``, even if they're annotated otherwise. The canonical -solution is to subclass ``nn.Sequential`` and redeclare ``forward`` -with the input typed correctly. - -Appendix --------- - -Migrating to PyTorch 1.2 Recursive Scripting API -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -This section details the changes to TorchScript in PyTorch 1.2. If you are new to TorchScript you can -skip this section. There are two main changes to the TorchScript API with PyTorch 1.2. - -1. :func:`torch.jit.script ` will now attempt to recursively compile functions, -methods, and classes that it encounters. Once you call ``torch.jit.script``, -compilation is "opt-out", rather than "opt-in". - -2. ``torch.jit.script(nn_module_instance)`` is now the preferred way to create -:class:`ScriptModule`\s, instead of inheriting from ``torch.jit.ScriptModule``. -These changes combine to provide a simpler, easier-to-use API for converting -your ``nn.Module``\s into :class:`ScriptModule`\s, ready to be optimized and executed in a -non-Python environment. - -The new usage looks like this: - -.. testcode:: - - import torch - import torch.nn as nn - import torch.nn.functional as F - - class Model(nn.Module): - def __init__(self): - super().__init__() - self.conv1 = nn.Conv2d(1, 20, 5) - self.conv2 = nn.Conv2d(20, 20, 5) - - def forward(self, x): - x = F.relu(self.conv1(x)) - return F.relu(self.conv2(x)) - - my_model = Model() - my_scripted_model = torch.jit.script(my_model) - - -* The module's ``forward`` is compiled by default. Methods called from ``forward`` are lazily compiled in the order they are used in ``forward``. -* To compile a method other than ``forward`` that is not called from ``forward``, add ``@torch.jit.export``. -* To stop the compiler from compiling a method, add :func:`@torch.jit.ignore ` or :func:`@torch.jit.unused `. ``@ignore`` leaves the -* method as a call to python, and ``@unused`` replaces it with an exception. ``@ignored`` cannot be exported; ``@unused`` can. -* Most attribute types can be inferred, so ``torch.jit.Attribute`` is not necessary. For empty container types, annotate their types using `PEP 526-style `_ class annotations. -* Constants can be marked with a ``Final`` class annotation instead of adding the name of the member to ``__constants__``. -* Python 3 type hints can be used in place of ``torch.jit.annotate`` - -As a result of these changes, the following items are considered deprecated and should not appear in new code: - * The ``@torch.jit.script_method`` decorator - * Classes that inherit from ``torch.jit.ScriptModule`` - * The ``torch.jit.Attribute`` wrapper class - * The ``__constants__`` array - * The ``torch.jit.annotate`` function - -Modules -^^^^^^^ -.. warning:: - - The :func:`@torch.jit.ignore ` annotation's behavior changes in - PyTorch 1.2. Before PyTorch 1.2 the @ignore decorator was used to make a function - or method callable from code that is exported. To get this functionality back, - use ``@torch.jit.unused()``. ``@torch.jit.ignore`` is now equivalent - to ``@torch.jit.ignore(drop=False)``. See :func:`@torch.jit.ignore ` - and :func:`@torch.jit.unused` for details. - -When passed to the :func:`torch.jit.script ` function, a ``torch.nn.Module``\'s data is -copied to a :class:`ScriptModule` and the TorchScript compiler compiles the module. -The module's ``forward`` is compiled by default. Methods called from ``forward`` are -lazily compiled in the order they are used in ``forward``, as well as any -``@torch.jit.export`` methods. - -.. autofunction:: export - -Functions -^^^^^^^^^ -Functions don't change much, they can be decorated with :func:`@torch.jit.ignore ` or :func:`torch.jit.unused ` if needed. - -.. testcode:: - - # Same behavior as pre-PyTorch 1.2 - @torch.jit.script - def some_fn(): - return 2 - - # Marks a function as ignored, if nothing - # ever calls it then this has no effect - @torch.jit.ignore - def some_fn2(): - return 2 - - # As with ignore, if nothing calls it then it has no effect. - # If it is called in script it is replaced with an exception. - @torch.jit.unused - def some_fn3(): - import pdb; pdb.set_trace() - return 4 - - # Doesn't do anything, this function is already - # the main entry point - @torch.jit.export - def some_fn4(): - return 2 - -TorchScript Classes -^^^^^^^^^^^^^^^^^^^ - -.. warning:: - - TorchScript class support is experimental. Currently it is best suited - for simple record-like types (think a ``NamedTuple`` with methods - attached). - -Everything in a user defined `TorchScript Class `_ is -exported by default, functions can be decorated with :func:`@torch.jit.ignore -` if needed. - -Attributes -^^^^^^^^^^ -The TorchScript compiler needs to know the types of `module attributes`. Most types -can be inferred from the value of the member. Empty lists and dicts cannot have their -types inferred and must have their types annotated with `PEP 526-style `_ class annotations. -If a type cannot be inferred and is not explicitly annotated, it will not be added as an attribute -to the resulting :class:`ScriptModule` - - -Old API: - -.. testcode:: - - from typing import Dict - import torch - - class MyModule(torch.jit.ScriptModule): - def __init__(self): - super().__init__() - self.my_dict = torch.jit.Attribute({}, Dict[str, int]) - self.my_int = torch.jit.Attribute(20, int) - - m = MyModule() - -New API: - -.. testcode:: - - from typing import Dict - - class MyModule(torch.nn.Module): - my_dict: Dict[str, int] - - def __init__(self): - super().__init__() - # This type cannot be inferred and must be specified - self.my_dict = {} - - # The attribute type here is inferred to be `int` - self.my_int = 20 - - def forward(self): - pass - - m = torch.jit.script(MyModule()) - - -Constants -^^^^^^^^^ -The ``Final`` type constructor can be used to mark members as `constant`. If members are not marked constant, they will be copied to the resulting :class:`ScriptModule` as an attribute. Using ``Final`` opens opportunities for optimization if the value is known to be fixed and gives additional type safety. - -Old API: - -.. testcode:: - - class MyModule(torch.jit.ScriptModule): - __constants__ = ['my_constant'] - - def __init__(self): - super().__init__() - self.my_constant = 2 - - def forward(self): - pass - m = MyModule() - -New API: - -:: - - from typing import Final - - class MyModule(torch.nn.Module): - - my_constant: Final[int] - - def __init__(self): - super().__init__() - self.my_constant = 2 - - def forward(self): - pass - - m = torch.jit.script(MyModule()) - -.. _Python 3 type hints: - -Variables -^^^^^^^^^ -Containers are assumed to have type ``Tensor`` and be non-optional (see -`Default Types` for more information). Previously, ``torch.jit.annotate`` was used to -tell the TorchScript compiler what the type should be. Python 3 style type hints are -now supported. - -.. testcode:: - - import torch - from typing import Dict, Optional - - @torch.jit.script - def make_dict(flag: bool): - x: Dict[str, int] = {} - x['hi'] = 2 - b: Optional[int] = None - if flag: - b = 2 - return x, b - -Fusion Backends -~~~~~~~~~~~~~~~ -There are a couple of fusion backends available to optimize TorchScript execution. The default fuser on CPUs is NNC, which can perform fusions for both CPUs and GPUs. The default fuser on GPUs is NVFuser, which supports a wider range of operators and has demonstrated generated kernels with improved throughput. See the `NVFuser documentation `_ for more details on usage and debugging. - - -References -~~~~~~~~~~ -.. toctree:: - :maxdepth: 1 - - jit_python_reference - jit_unsupported .. This package is missing doc. Adding it here for coverage .. This does not add anything to the rendered page. +.. py:module:: torch.jit.supported_ops +.. py:module:: torch.jit.unsupported_tensor_ops .. py:module:: torch.jit.mobile .. py:module:: torch.jit.annotations .. py:module:: torch.jit.frontend diff --git a/docs/source/jit_builtin_functions.rst b/docs/source/jit_builtin_functions.rst index c08e0739266a9..6fd514f6e6fca 100644 --- a/docs/source/jit_builtin_functions.rst +++ b/docs/source/jit_builtin_functions.rst @@ -6,5 +6,3 @@ TorchScript Builtins .. warning:: TorchScript is deprecated, please use `torch.export `__ instead. - -.. automodule:: torch.jit.supported_ops diff --git a/docs/source/jit_unsupported.md b/docs/source/jit_unsupported.md index be3ddaec12a72..bdb930970f510 100644 --- a/docs/source/jit_unsupported.md +++ b/docs/source/jit_unsupported.md @@ -6,7 +6,3 @@ TorchScript is deprecated, please use [torch.export](https://docs.pytorch.org/docs/stable/export.html) instead. ::: - -```{eval-rst} -.. automodule:: torch.jit.unsupported_tensor_ops -``` \ No newline at end of file diff --git a/torch/jit/_script.py b/torch/jit/_script.py index a9a95cdace452..ccd967d69f4e7 100644 --- a/torch/jit/_script.py +++ b/torch/jit/_script.py @@ -704,10 +704,7 @@ def _reconstruct(self, cpp_module): @property def graph(self): - r"""Return a string representation of the internal graph for the ``forward`` method. - - See :ref:`interpreting-graphs` for details. - """ + r"""Return a string representation of the internal graph for the ``forward`` method.""" return self._c._get_method("forward").graph @property @@ -716,7 +713,6 @@ def inlined_graph(self): Return a string representation of the internal graph for the ``forward`` method. This graph will be preprocessed to inline all function and method calls. - See :ref:`interpreting-graphs` for details. """ return self.forward.inlined_graph # type: ignore[attr-defined] @@ -725,7 +721,6 @@ def code(self): r""" Return a pretty-printed representation (as valid Python syntax) of the internal graph for the ``forward`` method. - See :ref:`inspecting-code` for details. """ return self.forward.code # type: ignore[attr-defined] @@ -740,7 +735,6 @@ def code_with_constants(self): [1] a ConstMap following the CONSTANT.cN format of the output in [0]. The indices in the [0] output are keys to the underlying constant's values. - See :ref:`inspecting-code` for details. """ r = self.forward.code_with_constants # type: ignore[attr-defined] return (r[0], ConstMap(r[1])) From fb731fe371cb1b5bf95de84b19c213590526acb2 Mon Sep 17 00:00:00 2001 From: atalman Date: Wed, 16 Jul 2025 20:11:18 +0000 Subject: [PATCH 136/457] Add warning about removed sm50 and sm60 arches (#158301) Related to https://github.com/pytorch/pytorch/issues/157517 Detect when users are executing torch build with cuda 12.8/12.9 and running on Maxwell or Pascal architectures. We would like to include reference to the issue: https://github.com/pytorch/pytorch/issues/157517 as well as ask people to install CUDA 12.6 builds if they are running on sm50 or sm60 architectures. Test: ``` >>> torch.cuda.get_arch_list() ['sm_70', 'sm_75', 'sm_80', 'sm_86', 'sm_90', 'sm_100', 'sm_120', 'compute_120'] >>> torch.cuda.init() /home/atalman/.conda/envs/py312/lib/python3.12/site-packages/torch/cuda/__init__.py:263: UserWarning: Found which is of cuda capability 5.0. PyTorch no longer supports this GPU because it is too old. The minimum cuda capability supported by this library is 7.0. warnings.warn( /home/atalman/.conda/envs/py312/lib/python3.12/site-packages/torch/cuda/__init__.py:268: UserWarning: Support for Maxwell and Pascal architectures is removed for CUDA 12.8+ builds. Please see https://github.com/pytorch/pytorch/issues/157517 Please install CUDA 12.6 builds if you require Maxwell or Pascal support. ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/158301 Approved by: https://github.com/nWEIdia, https://github.com/albanD --- torch/cuda/__init__.py | 56 +++++++++++++++++++++++++++++++----------- 1 file changed, 41 insertions(+), 15 deletions(-) diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py index 5b85c91d2c208..4602aa1ee172e 100644 --- a/torch/cuda/__init__.py +++ b/torch/cuda/__init__.py @@ -244,21 +244,25 @@ def _extract_arch_version(arch_string: str) -> int: def _check_capability(): - incorrect_binary_warn = """ - Found GPU%d %s which requires CUDA_VERSION >= %d to - work properly, but your PyTorch was compiled - with CUDA_VERSION %d. Please install the correct PyTorch binary - using instructions from https://pytorch.org - """ # noqa: F841 - - old_gpu_warn = """ + incompatible_gpu_warn = """ Found GPU%d %s which is of cuda capability %d.%d. - PyTorch no longer supports this GPU because it is too old. - The minimum cuda capability supported by this library is %d.%d. + Minimum and Maximum cuda capability supported by this version of PyTorch is + (%d.%d) - (%d.%d) """ + matched_cuda_warn = """ + Please install PyTorch with a following CUDA + configurations: {} following instructions at + https://pytorch.org/get-started/locally/ + """ + + # Binary CUDA_ARCHES SUPPORTED by PyTorch + CUDA_ARCHES_SUPPORTED = { + "12.6": {"min": 50, "max": 90}, + "12.8": {"min": 70, "max": 120}, + "12.9": {"min": 70, "max": 120}, + } if torch.version.cuda is not None: # on ROCm we don't want this check - CUDA_VERSION = torch._C._cuda_getCompiledVersion() # noqa: F841 for d in range(device_count()): capability = get_device_capability(d) major = capability[0] @@ -267,13 +271,35 @@ def _check_capability(): current_arch = major * 10 + minor min_arch = min( (_extract_arch_version(arch) for arch in torch.cuda.get_arch_list()), - default=35, + default=50, ) - if current_arch < min_arch: + max_arch = max( + (_extract_arch_version(arch) for arch in torch.cuda.get_arch_list()), + default=50, + ) + if current_arch < min_arch or current_arch > max_arch: warnings.warn( - old_gpu_warn - % (d, name, major, minor, min_arch // 10, min_arch % 10) + incompatible_gpu_warn + % ( + d, + name, + major, + minor, + min_arch // 10, + min_arch % 10, + max_arch // 10, + max_arch % 10, + ) ) + matched_arches = "" + for arch, arch_info in CUDA_ARCHES_SUPPORTED.items(): + if ( + current_arch >= arch_info["min"] + and current_arch <= arch_info["max"] + ): + matched_arches += f" {arch}" + if matched_arches != "": + warnings.warn(matched_cuda_warn.format(matched_arches)) def _check_cubins(): From 473208cb18d543e8f968918a6b3c9defa8a4ae10 Mon Sep 17 00:00:00 2001 From: Catherine Lee Date: Wed, 16 Jul 2025 20:31:07 +0000 Subject: [PATCH 137/457] [ez][lint] Add pr_time_benchmarks to merge conflictless csv linter (#158353) Discovered this when looking at a PR I was trying to revert and was surprised that the PR got rid of the spaces but didn't trigger the linter. Turns out the file was following the rule but wasn't actually being checked Pull Request resolved: https://github.com/pytorch/pytorch/pull/158353 Approved by: https://github.com/seemethere, https://github.com/Camyll --- .lintrunner.toml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index 4da6616e08cff..5638a441b8e58 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1605,7 +1605,10 @@ is_formatter = true # the same line, merge conflicts should not arise in git or hg [[linter]] code = 'MERGE_CONFLICTLESS_CSV' -include_patterns = ['benchmarks/dynamo/ci_expected_accuracy/*.csv'] +include_patterns = [ + 'benchmarks/dynamo/ci_expected_accuracy/*.csv', + 'benchmarks/dynamo/pr_time_benchmarks/expected_results.csv', +] command = [ 'python3', 'tools/linter/adapters/no_merge_conflict_csv_linter.py', From 94c746bb43484787a3f5bbdc2f72bd4fb02f2964 Mon Sep 17 00:00:00 2001 From: Xilun Wu <12968408+XilunWu@users.noreply.github.com> Date: Tue, 15 Jul 2025 23:41:33 -0700 Subject: [PATCH 138/457] [DTensor][BE] add document to ShardingPropagator.register_op_strategy (#158362) **Summary** Add document to `ShardingPropagator.register_op_strategy` on how to draft `strategy_func` and when to use `schema_info`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158362 Approved by: https://github.com/zpcore --- torch/distributed/tensor/_sharding_prop.py | 44 +++++++++++++++++++++- 1 file changed, 43 insertions(+), 1 deletion(-) diff --git a/torch/distributed/tensor/_sharding_prop.py b/torch/distributed/tensor/_sharding_prop.py index 69af19fea26ab..b05c84de01887 100644 --- a/torch/distributed/tensor/_sharding_prop.py +++ b/torch/distributed/tensor/_sharding_prop.py @@ -101,7 +101,49 @@ def register_op_strategy( schema_info: Optional[RuntimeSchemaInfo] = None, ): """ - Register a sharding strategy generator for an operator. + Register a :class:`OpStrategy` generator for an operator. + + During the sharding propagation, DTensor wants to enumerate all + acceptable sharding specs (:class:`OpSpec`) for an operator, + and by "acceptable" we mean that the operator can be executed on + the ``_local_tensor`` of DTensor args/kwargs (with ``OpSpec.input_specs``) + and the output(s) constitute valid DTensor(s) (with ``OpSpec.output_specs``). + + ``strategy_func`` is the function that enumerates such acceptable specs + for the operator ``op_overload``. One general approach to write ``strategy_func`` + is, if the operator has simple arguments structure (e.g. mm, bmm), first enumerating + all sharding specs for the operands, and then filtering out the ones that + are not valid. For example, for ``mm``, the operands are two 2D tensors, and + if both ``input`` and ``mat2`` have sharding placements ``[Shard(0)]``, then this + is not an acceptable ``input_specs``. + + Once we have a way to enumerate all acceptable sharding specs, we can use each + of them to construct a :class:`OpSpec`. The ``OpSpec.input_specs`` directly comes + from the sharding spec, and the ``OpSpec.output_specs`` is therefore determined + (e.g. ``[Shard(1)]`` @ ``[Shard(0)]`` yields ``[Partial()]``). In addition, + :class:`OpSpec` also contains ``redistribute_cost`` which records the redistribution + cost from each :class:`OpSpec` in the source :class:`OpStrategy.strategies` to + the target sharding spec, for each operand. + + The ``strategy_func`` should return a :class:`OpStrategy` which contains a list of + all the :class:`OpSpec`s generated in the above. + + The optional ``schema_info`` tells which non-DTensor args/kwargs could affect the + cache and whether ``pytree`` is needed to flatten the nested args. ``static_argnum`` + marks the starting index of the non-DTensor args that should be hashed into the + sharding propagation hash key, and ``static_kwargkey`` marks the keys of the + non-DTensor kwargs that should be hashed. ``needs_pytree`` should be used when + the input arg has :class:`list` or :class:`dict` structure. + + For example, ``aten.cat.default`` op has a ``List[Tensor]`` argument ``tensors`` + and an ``int`` argument ``dim``. Because ``dim`` affects the sharding propagation + result, we want to pass ``RuntimeSchemaInfo(static_argnum=1)`` because the argument + index of ``dim`` is 1. Besides, we also want to set ``needs_pytree=True`` because + ``tensors`` needs be flattened in sharding propagation. Another example is + ``aten.histc.default``. ``histc`` has 4 arguments (self, bins, min, max) and the + last two would affect sharding propagation along with the :class:`DTensor` argument + ``self``. Since the argument index of ``min`` is 2, the `schema_info` should be + `RuntimeSchemaInfo(static_argnum=2)`. """ self.op_strategy_funcs[op_overload] = strategy_func if schema_info is not None: From 9df0176408518b30ac172837bd697c9d19b19a98 Mon Sep 17 00:00:00 2001 From: Sam Larsen Date: Wed, 16 Jul 2025 16:09:22 +0000 Subject: [PATCH 139/457] [BE][testing] Disable test_static_cuda_launcher:test_floats internally (#158296) Summary: it seems the check for 'Offd' vs. 'Offf' doesn't work Pull Request resolved: https://github.com/pytorch/pytorch/pull/158296 Approved by: https://github.com/davidberard98 --- test/inductor/test_static_cuda_launcher.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/inductor/test_static_cuda_launcher.py b/test/inductor/test_static_cuda_launcher.py index 477d5ac2e6c20..c1af125eb6bc6 100644 --- a/test/inductor/test_static_cuda_launcher.py +++ b/test/inductor/test_static_cuda_launcher.py @@ -2,6 +2,7 @@ import os import random import tempfile +import unittest from unittest import mock import torch @@ -12,7 +13,7 @@ from torch._inductor.runtime.triton_compat import CompiledKernel, tl, triton from torch._inductor.runtime.triton_helpers import libdevice from torch._inductor.test_case import TestCase -from torch.testing._internal.common_utils import skipIfRocm +from torch.testing._internal.common_utils import IS_FBCODE, skipIfRocm from torch.testing._internal.triton_utils import requires_cuda from torch.torch_version import TorchVersion @@ -144,6 +145,7 @@ def signed_integers( # despite type annotations. # There's also not really a good way for me to make a float16 in python... @skipIfRocm + @unittest.skipIf(IS_FBCODE, "Not working in fbcode") def test_floats(self): @triton.jit def floats(arg0, arg1: tl.float16, arg2: tl.float32, arg3: tl.float64): From ada44e5ba78be9377814678d1986556af2d6e570 Mon Sep 17 00:00:00 2001 From: Lucas Kabela Date: Wed, 16 Jul 2025 21:50:51 +0000 Subject: [PATCH 140/457] [Dynamo][Better Engineering] Add typing to bytecode analysis and transform (#158293) As part of better engineering week, we would like to improve out type support to improve dev experience in dynamo This PR adds strict typing support to a critical tracing point for dynamo, `bytecode_transformation.py` and by extension, `bytecode_analysis.py` Running ``` mypy torch/_dynamo/bytecode_transformation.py torch/_dynamo/bytecode_analysis.py --linecount-report /tmp/coverage_log ``` | -------- | Lines Unannotated | Lines Total | % lines covered | Funcs Unannotated | Funcs Total | % funcs covered | | -------- | ------- | -------- | ------- | ------- | ------- | ------- | | Main | 1422 | 1920 | 74.06% | 73 | 93 | 78.49% | | This PR | 1968 | 1968 | 100.00% | 93 | 93 | 100.00% | | Delta | +546 | +48 | +25.94% | 20 | 0 | +21.51% | Pull Request resolved: https://github.com/pytorch/pytorch/pull/158293 Approved by: https://github.com/StrongerXi, https://github.com/Skylion007 --- torch/_dynamo/bytecode_analysis.py | 47 +++--- torch/_dynamo/bytecode_transformation.py | 174 ++++++++++++++--------- torch/_dynamo/resume_execution.py | 2 +- torch/_dynamo/symbolic_convert.py | 1 + 4 files changed, 136 insertions(+), 88 deletions(-) diff --git a/torch/_dynamo/bytecode_analysis.py b/torch/_dynamo/bytecode_analysis.py index 2de74ee5bf8d2..8bdf155e00603 100644 --- a/torch/_dynamo/bytecode_analysis.py +++ b/torch/_dynamo/bytecode_analysis.py @@ -1,5 +1,3 @@ -# mypy: allow-untyped-defs - """ This module provides utilities for analyzing and optimizing Python bytecode. Key functionality includes: @@ -18,8 +16,13 @@ import dataclasses import dis import sys -from typing import Any, Union +from typing import Any, TYPE_CHECKING, Union + +if TYPE_CHECKING: + # TODO(lucaskabela): consider moving Instruction into this file + # and refactoring in callsite; that way we don't have to guard this import + from .bytecode_transformation import Instruction TERMINAL_OPCODES = { dis.opmap["RETURN_VALUE"], @@ -45,7 +48,7 @@ stack_effect = dis.stack_effect -def get_indexof(insts): +def get_indexof(insts: list["Instruction"]) -> dict["Instruction", int]: """ Get a mapping from instruction memory address to index in instruction list. Additionally checks that each instruction only appears once in the list. @@ -57,12 +60,12 @@ def get_indexof(insts): return indexof -def remove_dead_code(instructions): +def remove_dead_code(instructions: list["Instruction"]) -> list["Instruction"]: """Dead code elimination""" indexof = get_indexof(instructions) live_code = set() - def find_live_code(start): + def find_live_code(start: int) -> None: for i in range(start, len(instructions)): if i in live_code: return @@ -71,6 +74,7 @@ def find_live_code(start): if inst.exn_tab_entry: find_live_code(indexof[inst.exn_tab_entry.target]) if inst.opcode in JUMP_OPCODES: + assert inst.target is not None find_live_code(indexof[inst.target]) if inst.opcode in TERMINAL_OPCODES: return @@ -102,7 +106,7 @@ def find_live_code(start): return [inst for i, inst in enumerate(instructions) if i in live_code] -def remove_pointless_jumps(instructions): +def remove_pointless_jumps(instructions: list["Instruction"]) -> list["Instruction"]: """Eliminate jumps to the next instruction""" pointless_jumps = { id(a) @@ -112,11 +116,11 @@ def remove_pointless_jumps(instructions): return [inst for inst in instructions if id(inst) not in pointless_jumps] -def propagate_line_nums(instructions): +def propagate_line_nums(instructions: list["Instruction"]) -> None: """Ensure every instruction has line number set in case some are removed""" cur_line_no = None - def populate_line_num(inst): + def populate_line_num(inst: "Instruction") -> None: nonlocal cur_line_no if inst.starts_line: cur_line_no = inst.starts_line @@ -127,12 +131,12 @@ def populate_line_num(inst): populate_line_num(inst) -def remove_extra_line_nums(instructions): +def remove_extra_line_nums(instructions: list["Instruction"]) -> None: """Remove extra starts line properties before packing bytecode""" cur_line_no = None - def remove_line_num(inst): + def remove_line_num(inst: "Instruction") -> None: nonlocal cur_line_no if inst.starts_line is None: return @@ -152,12 +156,14 @@ class ReadsWrites: visited: set[Any] -def livevars_analysis(instructions, instruction): +def livevars_analysis( + instructions: list["Instruction"], instruction: "Instruction" +) -> set[Any]: indexof = get_indexof(instructions) must = ReadsWrites(set(), set(), set()) may = ReadsWrites(set(), set(), set()) - def walk(state, start): + def walk(state: ReadsWrites, start: int) -> None: if start in state.visited: return state.visited.add(start) @@ -177,6 +183,7 @@ def walk(state, start): if inst.exn_tab_entry: walk(may, indexof[inst.exn_tab_entry.target]) if inst.opcode in JUMP_OPCODES: + assert inst.target is not None walk(may, indexof[inst.target]) state = may if inst.opcode in TERMINAL_OPCODES: @@ -197,19 +204,19 @@ class StackSize: high: Union[int, float] fixed_point: FixedPointBox - def zero(self): + def zero(self) -> None: self.low = 0 self.high = 0 self.fixed_point.value = False - def offset_of(self, other, n): + def offset_of(self, other: "StackSize", n: int) -> None: prior = (self.low, self.high) self.low = min(self.low, other.low + n) self.high = max(self.high, other.high + n) if (self.low, self.high) != prior: self.fixed_point.value = False - def exn_tab_jump(self, depth): + def exn_tab_jump(self, depth: int) -> None: prior = (self.low, self.high) self.low = min(self.low, depth) self.high = max(self.high, depth) @@ -217,7 +224,7 @@ def exn_tab_jump(self, depth): self.fixed_point.value = False -def stacksize_analysis(instructions) -> Union[int, float]: +def stacksize_analysis(instructions: list["Instruction"]) -> Union[int, float]: assert instructions fixed_point = FixedPointBox() stack_sizes = { @@ -238,6 +245,7 @@ def stacksize_analysis(instructions) -> Union[int, float]: eff = stack_effect(inst.opcode, inst.arg, jump=False) stack_sizes[next_inst].offset_of(stack_size, eff) if inst.opcode in JUMP_OPCODES: + assert inst.target is not None, f"missing target: {inst}" stack_sizes[inst.target].offset_of( stack_size, stack_effect(inst.opcode, inst.arg, jump=True) ) @@ -247,11 +255,6 @@ def stacksize_analysis(instructions) -> Union[int, float]: depth = inst.exn_tab_entry.depth + int(inst.exn_tab_entry.lasti) + 1 stack_sizes[inst.exn_tab_entry.target].exn_tab_jump(depth) - if False: - for inst in instructions: - stack_size = stack_sizes[inst] - print(stack_size.low, stack_size.high, inst) - low = min(x.low for x in stack_sizes.values()) high = max(x.high for x in stack_sizes.values()) diff --git a/torch/_dynamo/bytecode_transformation.py b/torch/_dynamo/bytecode_transformation.py index 9226a61577d87..165182d93d233 100644 --- a/torch/_dynamo/bytecode_transformation.py +++ b/torch/_dynamo/bytecode_transformation.py @@ -1,5 +1,3 @@ -# mypy: allow-untyped-defs - """ This module provides utilities for analyzing, transforming and manipulating Python bytecode. It includes functionality for: @@ -23,7 +21,7 @@ import sys import types import uuid -from collections.abc import Iterator, Sequence +from collections.abc import Iterable, Iterator, Mapping, Sequence from typing import Any, Callable, cast, Optional, Union from ..utils._backport_slots import dataclass_slots @@ -53,7 +51,9 @@ def __repr__(self) -> str: f"depth={self.depth}, lasti={self.lasti})" ) - def __eq__(self, o) -> bool: + def __eq__(self, o: object) -> bool: + if not isinstance(o, InstructionExnTabEntry): + return False return ( self.start is o.start and self.end is o.end @@ -84,7 +84,7 @@ class Instruction: def __hash__(self) -> int: return id(self) - def __eq__(self, other) -> bool: + def __eq__(self, other: object) -> bool: return id(self) == id(other) def short_inst_repr(self) -> str: @@ -145,22 +145,26 @@ def __repr__(self) -> str: if sys.version_info >= (3, 12): - def inst_has_op_bits(name): + def inst_has_op_bits(name: str) -> bool: return name in ("LOAD_ATTR", "LOAD_GLOBAL", "LOAD_SUPER_ATTR") elif sys.version_info >= (3, 11): - def inst_has_op_bits(name): + def inst_has_op_bits(name: str) -> bool: return name == "LOAD_GLOBAL" else: - def inst_has_op_bits(name): + def inst_has_op_bits(name: str): return False def create_instruction( - name, *, arg=None, argval=_NotProvided, target=None + name: str, + *, + arg: Optional[int] = None, + argval: Optional[Any] = _NotProvided, + target: Optional[Instruction] = None, ) -> Instruction: """ At most one of `arg`, `argval`, and `target` can be not None/_NotProvided. @@ -198,12 +202,12 @@ def create_instruction( # Python 3.11 remaps -def create_jump_absolute(target) -> Instruction: +def create_jump_absolute(target: Instruction) -> Instruction: inst = "JUMP_FORWARD" if sys.version_info >= (3, 11) else "JUMP_ABSOLUTE" return create_instruction(inst, target=target) -def create_load_const(val, checked=True) -> Instruction: +def create_load_const(val: Any, checked: bool = True) -> Instruction: """ In general we should only create `LOAD_CONST` for immutable objects, but sometimes it's convenient _and safe_ for Dynamo create `LOAD_CONST` for @@ -220,7 +224,7 @@ def create_dup_top() -> Instruction: return create_instruction("DUP_TOP") -def create_rot_n(n) -> list[Instruction]: +def create_rot_n(n: int) -> list[Instruction]: """ Returns a "simple" sequence of instructions that rotates TOS to the n-th position in the stack. For Python < 3.11, returns a single ROT_* @@ -265,17 +269,18 @@ def add_push_null( In this case, instructions WILL be modified. """ if isinstance(inst_or_insts, Instruction): - insts = [inst_or_insts] + insts: list[Instruction] = [inst_or_insts] else: + assert isinstance(inst_or_insts, list) insts = inst_or_insts - def inst_has_bit_set(idx): + def inst_has_bit_set(idx: int) -> bool: assert insts[idx].arg is not None - return insts[idx].arg & 1 == 1 + return insts[idx].arg & 1 == 1 # type: ignore[operator] - def set_inst_bit(idx): + def set_inst_bit(idx: int) -> None: assert insts[idx].arg is not None - insts[idx].arg |= 1 + insts[idx].arg |= 1 # type: ignore[operator] if sys.version_info >= (3, 13): # In 3.13, NULL follows the callable @@ -312,8 +317,9 @@ def add_push_null_call_function_ex( is not set, due to an expected CALL_FUNCTION_EX instruction. """ if isinstance(inst_or_insts, Instruction): - insts = [inst_or_insts] + insts: list[Instruction] = [inst_or_insts] else: + assert isinstance(inst_or_insts, list) insts = inst_or_insts if sys.version_info < (3, 11): @@ -334,7 +340,7 @@ def add_push_null_call_function_ex( return insts -def create_call_function(nargs, push_null) -> list[Instruction]: +def create_call_function(nargs: int, push_null: bool) -> list[Instruction]: """ Creates a sequence of instructions that makes a function call. @@ -389,7 +395,7 @@ def create_call_function(nargs, push_null) -> list[Instruction]: return [create_instruction("CALL_FUNCTION", arg=nargs)] -def create_call_method(nargs) -> list[Instruction]: +def create_call_method(nargs: int) -> list[Instruction]: if sys.version_info >= (3, 12): return [create_instruction("CALL", arg=nargs)] if sys.version_info >= (3, 11): @@ -400,19 +406,19 @@ def create_call_method(nargs) -> list[Instruction]: return [create_instruction("CALL_METHOD", arg=nargs)] -def create_load_method(name) -> Instruction: +def create_load_method(name: str) -> Instruction: if sys.version_info >= (3, 12): # in 3.12, create a LOAD_ATTR instruction with the low bit set return create_instruction("LOAD_ATTR", arg=1, argval=name) return create_instruction("LOAD_METHOD", argval=name) -def create_setup_with(target) -> Instruction: +def create_setup_with(target: Instruction) -> Instruction: opname = "BEFORE_WITH" if sys.version_info >= (3, 11) else "SETUP_WITH" return create_instruction(opname, target=target) -def create_swap(n) -> list[Instruction]: +def create_swap(n: int) -> list[Instruction]: if sys.version_info >= (3, 11): return [create_instruction("SWAP", arg=n)] # in Python < 3.11, SWAP is a macro that expands to multiple instructions @@ -465,7 +471,7 @@ def lnotab_writer( assert sys.version_info < (3, 10) lnotab: list[int] = [] - def update(lineno_new, byteno_new): + def update(lineno_new: int, byteno_new: int) -> None: nonlocal byteno, lineno while byteno_new != byteno or lineno_new != lineno: byte_offset = max(0, min(byteno_new - byteno, 255)) @@ -478,7 +484,9 @@ def update(lineno_new, byteno_new): return lnotab, update -def linetable_310_writer(first_lineno): +def linetable_310_writer( + first_lineno: int, +) -> tuple[list[int], Callable[[int, int], None], Callable[[int], None]]: """ Used to create typing.CodeType.co_linetable See https://github.com/python/cpython/blob/main/Objects/lnotab_notes.txt @@ -490,7 +498,7 @@ def linetable_310_writer(first_lineno): lineno_delta = 0 byteno = 0 - def _update(byteno_delta, lineno_delta): + def _update(byteno_delta: int, lineno_delta: int) -> None: while byteno_delta != 0 or lineno_delta != 0: byte_offset = max(0, min(byteno_delta, 254)) line_offset = max(-127, min(lineno_delta, 127)) @@ -499,7 +507,7 @@ def _update(byteno_delta, lineno_delta): lineno_delta -= line_offset linetable.extend((byte_offset, line_offset & 0xFF)) - def update(lineno_new, byteno_new): + def update(lineno_new: int, byteno_new: int) -> None: nonlocal lineno, lineno_delta, byteno byteno_delta = byteno_new - byteno byteno = byteno_new @@ -507,7 +515,7 @@ def update(lineno_new, byteno_new): lineno_delta = lineno_new - lineno lineno = lineno_new - def end(total_bytes): + def end(total_bytes: int) -> None: _update(total_bytes - byteno, lineno_delta) return linetable, update, end @@ -528,7 +536,9 @@ def encode_varint(n: int) -> list[int]: return b -def linetable_311_writer(first_lineno: int): +def linetable_311_writer( + first_lineno: int, +) -> tuple[list[int], Callable[[Optional["dis.Positions"], int], None]]: """ Used to create typing.CodeType.co_linetable See https://github.com/python/cpython/blob/3.11/Objects/locations.md @@ -538,11 +548,11 @@ def linetable_311_writer(first_lineno: int): linetable = [] lineno = first_lineno - def update(positions: "dis.Positions", inst_size): + def update(positions: Optional["dis.Positions"], inst_size: int) -> None: nonlocal lineno lineno_new = positions.lineno if positions else None - def _update(delta, size): + def _update(delta: int, size: int) -> None: assert 0 < size <= 8 # first byte - use 13 (no column info) is positions is # malformed, otherwise use 14 (long form) @@ -721,7 +731,9 @@ def assemble(instructions: list[Instruction], firstlineno: int) -> tuple[bytes, return bytes(code), bytes(lnotab) -def _get_instruction_by_offset(offset_to_inst: dict[int, Instruction], offset: int): +def _get_instruction_by_offset( + offset_to_inst: dict[int, Instruction], offset: int +) -> Optional[Instruction]: """ Get the instruction located at a given offset, accounting for EXTENDED_ARGs """ @@ -731,9 +743,11 @@ def _get_instruction_by_offset(offset_to_inst: dict[int, Instruction], offset: i return None -def virtualize_jumps(instructions) -> None: +def virtualize_jumps(instructions: Iterable[Instruction]) -> None: """Replace jump targets with pointers to make editing easier""" - jump_targets = {inst.offset: inst for inst in instructions} + jump_targets = { + inst.offset: inst for inst in instructions if inst.offset is not None + } for inst in instructions: if inst.opcode in dis.hasjabs or inst.opcode in dis.hasjrel: @@ -756,7 +770,7 @@ def flip_jump_direction(instruction: Instruction) -> None: assert instruction.opcode in _REL_JUMPS -def _get_instruction_front(instructions: list[Instruction], idx: int): +def _get_instruction_front(instructions: list[Instruction], idx: int) -> Instruction: """ i.e. get the first EXTENDED_ARG instruction (if any) when targeting instructions[idx] with a jump. @@ -770,7 +784,7 @@ def _get_instruction_front(instructions: list[Instruction], idx: int): return target -def devirtualize_jumps(instructions): +def devirtualize_jumps(instructions: list[Instruction]) -> None: """Fill in args for virtualized jump target after instructions may have moved""" jumps = set(dis.hasjabs).union(set(dis.hasjrel)) @@ -778,6 +792,11 @@ def devirtualize_jumps(instructions): for inst in instructions: if inst.opcode in jumps: if inst.opcode not in dis.hasjabs: + assert ( + inst.target is not None + and inst.target.offset is not None + and inst.offset is not None + ) if inst.target.offset < inst.offset: if sys.version_info < (3, 11): raise RuntimeError("Got negative jump offset for Python < 3.11") @@ -796,6 +815,7 @@ def devirtualize_jumps(instructions): # compute jump instruction arg for inst in instructions: if inst.opcode in jumps: + assert inst.target is not None target = _get_instruction_front(instructions, indexof[inst.target]) if inst.opcode in dis.hasjabs: if sys.version_info < (3, 10): @@ -808,6 +828,7 @@ def devirtualize_jumps(instructions): raise RuntimeError("Python 3.11+ should not have absolute jumps") else: # relative jump # byte offset between target and next instruction + assert target.offset is not None and inst.offset is not None inst.arg = abs( int(target.offset - inst.offset - instruction_size(inst)) ) @@ -818,7 +839,9 @@ def devirtualize_jumps(instructions): inst.argrepr = f"to {target.offset}" -def virtualize_exception_table(exn_tab_bytes: bytes, instructions: list[Instruction]): +def virtualize_exception_table( + exn_tab_bytes: bytes, instructions: list[Instruction] +) -> None: """Replace exception table entries with pointers to make editing easier""" exn_tab = parse_exception_table(exn_tab_bytes) offset_to_inst = {cast(int, inst.offset): inst for inst in instructions} @@ -827,7 +850,7 @@ def virtualize_exception_table(exn_tab_bytes: bytes, instructions: list[Instruct exn_tab_iter = iter(exn_tab) try: - def step(): + def step() -> tuple[ExceptionTableEntry, InstructionExnTabEntry]: nonlocal end_offset_idx entry = next(exn_tab_iter) # find rightmost offset <= entry.end, since entry.end may not be @@ -841,9 +864,9 @@ def step(): assert end_offset_idx > 0 end_offset = offsets[end_offset_idx - 1] inst_entry = InstructionExnTabEntry( - _get_instruction_by_offset(offset_to_inst, entry.start), - _get_instruction_by_offset(offset_to_inst, end_offset), - _get_instruction_by_offset(offset_to_inst, entry.target), + _get_instruction_by_offset(offset_to_inst, entry.start), # type: ignore[arg-type] + _get_instruction_by_offset(offset_to_inst, end_offset), # type: ignore[arg-type] + _get_instruction_by_offset(offset_to_inst, entry.target), # type: ignore[arg-type] entry.depth, entry.lasti, ) @@ -851,6 +874,7 @@ def step(): entry, inst_entry = step() for inst in instructions: + assert inst.offset is not None while inst.offset > entry.end: entry, inst_entry = step() if inst.offset >= entry.start: @@ -872,15 +896,18 @@ def compute_exception_table( start = _get_instruction_front( instructions, indexof[inst.exn_tab_entry.start] ).offset + assert start is not None # point to the last 2 bytes of the end instruction end = ( cast(int, inst.exn_tab_entry.end.offset) + instruction_size(inst.exn_tab_entry.end) - 2 ) + assert end is not None target = _get_instruction_front( instructions, indexof[inst.exn_tab_entry.target] ).offset + assert target is not None key = (start, end) val = (target, inst.exn_tab_entry.depth, inst.exn_tab_entry.lasti) if key in exn_dict: @@ -900,7 +927,7 @@ def compute_exception_table( key_stack: list[tuple[int, int]] = [] exn_tab: list[ExceptionTableEntry] = [] - def pop(): + def pop() -> None: """ Pop the key_stack and append an exception table entry if possible. """ @@ -934,7 +961,7 @@ def pop(): def check_inst_exn_tab_entries_nested( - tab: list[InstructionExnTabEntry], indexof + tab: list[InstructionExnTabEntry], indexof: dict[Instruction, int] ) -> None: """ Checks `tab` is a properly sorted list of nested InstructionExnTabEntry's, @@ -979,7 +1006,7 @@ def propagate_inst_exn_table_entries(instructions: list[Instruction]) -> None: instructions[i].exn_tab_entry = copy.copy(entry) -def check_inst_exn_tab_entries_valid(instructions: list[Instruction]): +def check_inst_exn_tab_entries_valid(instructions: list[Instruction]) -> None: """ Checks that exn_tab_entries of instructions are valid. An entry's start, end, and target must be in instructions. @@ -1012,7 +1039,9 @@ def strip_extended_args(instructions: list[Instruction]) -> None: # instruction, exception table entries, and positions. # Returns the modified sequence of instructions (including the modified # old instruction!) that can be manipulated elsewhere. -def overwrite_instruction(old_inst, new_insts): +def overwrite_instruction( + old_inst: Instruction, new_insts: list[Instruction] +) -> list[Instruction]: # update old_inst.exnt_tab_entry.end if necessary if ( old_inst.exn_tab_entry @@ -1161,7 +1190,7 @@ def fix_extended_args(instructions: list[Instruction]) -> int: """Fill in correct argvals for EXTENDED_ARG ops""" output: list[Instruction] = [] - def maybe_pop_n(n): + def maybe_pop_n(n: int) -> None: for _ in range(n): if output and output[-1].opcode == dis.EXTENDED_ARG: output.pop() @@ -1190,7 +1219,7 @@ def maybe_pop_n(n): return added -def instruction_size(inst) -> int: +def instruction_size(inst: Instruction) -> int: import torch if sys.version_info >= (3, 11): @@ -1198,21 +1227,21 @@ def instruction_size(inst) -> int: return 2 -def check_offsets(instructions) -> None: +def check_offsets(instructions: Sequence[Instruction]) -> None: offset = 0 for inst in instructions: assert inst.offset == offset offset += instruction_size(inst) -def update_offsets(instructions) -> None: +def update_offsets(instructions: Sequence[Instruction]) -> None: offset = 0 for inst in instructions: inst.offset = offset offset += instruction_size(inst) -def debug_bytes(*args) -> str: +def debug_bytes(*args: bytes) -> str: index = range(max(map(len, args))) result = [ " ".join(f"{x:03}" for x in arg) @@ -1224,7 +1253,7 @@ def debug_bytes(*args) -> str: return "bytes mismatch\n" + "\n".join(result) -def debug_checks(code): +def debug_checks(code: types.CodeType) -> None: """Make sure our assembler produces same bytes as we start with""" dode = transform_code_object(code, lambda x, y: None, safe=True) assert code.co_code == dode.co_code, debug_bytes(code.co_code, dode.co_code) @@ -1237,7 +1266,7 @@ def debug_checks(code): HAS_CONST = set(dis.hasconst) -def get_const_index(code_options, val) -> int: +def get_const_index(code_options: dict[str, Any], val: Any) -> int: for i, v in enumerate(code_options["co_consts"]): # NOTE: stronger comparison is required, since we have # examples where two values compare equal but have @@ -1249,11 +1278,15 @@ def get_const_index(code_options, val) -> int: return len(code_options["co_consts"]) - 1 -def fix_vars(instructions: list[Instruction], code_options, varname_from_oparg=None): +def fix_vars( + instructions: list[Instruction], + code_options: dict[str, Any], + varname_from_oparg: Optional[Callable[..., Any]] = None, +) -> None: # compute instruction arg from argval if arg is not provided names = {name: idx for idx, name in enumerate(code_options["co_names"])} - def get_name_index(name) -> int: + def get_name_index(name: str) -> int: try: idx = names[name] except KeyError: @@ -1288,7 +1321,7 @@ def get_name_index(name) -> int: } for i in range(len(instructions)): - def should_compute_arg(): + def should_compute_arg() -> bool: # argval is prioritized over arg return instructions[i].argval is not _NotProvided @@ -1356,7 +1389,7 @@ def should_compute_arg(): instructions[i].arg = idx -def clear_instruction_args(instructions): +def clear_instruction_args(instructions: list[Instruction]) -> None: # Clear the instruction arg for instructions that have argvals. # Useful for using dis'd bytecode within generated bytecode. for inst in instructions: @@ -1413,7 +1446,11 @@ def get_code_keys() -> list[str]: return keys -def transform_code_object(code, transformations, safe=False) -> types.CodeType: +def transform_code_object( + code: types.CodeType, + transformations: Callable[[list[Instruction], dict[str, Any]], Any], + safe: bool = False, +) -> types.CodeType: keys = get_code_keys() code_options = {k: getattr(code, k) for k in keys} assert len(code_options["co_varnames"]) == code_options["co_nlocals"] @@ -1466,7 +1503,7 @@ def clean_and_assemble_instructions( return instructions, types.CodeType(*[code_options[k] for k in keys]) -def populate_kw_names_argval(instructions, consts): +def populate_kw_names_argval(instructions: Sequence[Instruction], consts: Any) -> None: for inst in instructions: if inst.opname == "KW_NAMES": inst.argval = consts[inst.arg] @@ -1474,7 +1511,7 @@ def populate_kw_names_argval(instructions, consts): # If safe=True, we do not make any bytecode modifications. # Mainly used for debugging bytecode_transformation (see debug_checks) -def cleaned_instructions(code, safe=False) -> list[Instruction]: +def cleaned_instructions(code: types.CodeType, safe: bool = False) -> list[Instruction]: instructions = _cached_cleaned_instructions(code, safe) # We have a lot of code that implicitly mutates the instruction array. We # could do better here by making the copies explicit when necessary. @@ -1482,7 +1519,7 @@ def cleaned_instructions(code, safe=False) -> list[Instruction]: # Copy an instructions array, making sure to remap the individual instruction targets. -def _clone_instructions(instructions): +def _clone_instructions(instructions: Sequence[Instruction]) -> list[Instruction]: # This is super hot and this is the fastest way to do this (tried copy.copy # and dataclasses.replace). copied = [ @@ -1504,10 +1541,10 @@ def _clone_instructions(instructions): remap = dict(zip(instructions, copied)) # Handle `None` in the remapper so we don't need an extra `if`. - remap[None] = None + remap[None] = None # type: ignore[index, assignment] for i in copied: - i.target = remap[i.target] + i.target = remap[i.target] # type: ignore[index] if entry := i.exn_tab_entry: i.exn_tab_entry = InstructionExnTabEntry( remap[entry.start], @@ -1520,7 +1557,9 @@ def _clone_instructions(instructions): @functools.lru_cache -def _cached_cleaned_instructions(code, safe=False) -> Sequence[Instruction]: +def _cached_cleaned_instructions( + code: types.CodeType, safe: bool = False +) -> Sequence[Instruction]: instructions = list(map(convert_instruction, dis.get_instructions(code))) check_offsets(instructions) if sys.version_info >= (3, 11): @@ -1548,7 +1587,7 @@ def _cached_cleaned_instructions(code, safe=False) -> Sequence[Instruction]: _unique_id_counter = itertools.count() -def unique_id(name, with_uuid=False) -> str: +def unique_id(name: str, with_uuid: bool = False) -> str: ret = f"{name}_{next(_unique_id_counter)}" if with_uuid: ret += f"_{uuid.uuid4()}".replace("-", "_") @@ -1560,7 +1599,12 @@ def is_generator(code: types.CodeType) -> bool: return (code.co_flags & co_generator) > 0 -def bytecode_from_template(fn, varname_map=None, noreturn=True, noprefix=True): +def bytecode_from_template( + fn: Callable[..., Any], + varname_map: Optional[Mapping[Any, Any]] = None, + noreturn: bool = True, + noprefix: bool = True, +) -> list[Instruction]: """Generates bytecode from a template function `fn` for use in dynamo bytecode generation. diff --git a/torch/_dynamo/resume_execution.py b/torch/_dynamo/resume_execution.py index beaaa77671e1c..28f63c715fe52 100644 --- a/torch/_dynamo/resume_execution.py +++ b/torch/_dynamo/resume_execution.py @@ -78,7 +78,7 @@ def _bytecode_from_template_with_split(template, stack_index, varname_map=None): ), (None, None), ) - assert dummy_idx is not None + assert dummy_idx is not None and dummy_inst is not None # replace LOAD_FAST dummy with first NOP marking exception area overwrite_instruction(dummy_inst, [create_instruction("NOP")]) diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 769fb510fdf18..b7ba37b08d35f 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -943,6 +943,7 @@ def handle_graph_break( self.output.add_output_instructions( [create_instruction("KW_NAMES", argval=kw_names)] ) + assert inst.arg is not None call_insts = create_call_function(inst.arg, False) call_insts[-1].copy_positions(inst) self.output.add_output_instructions(call_insts) From 5951fcd50acc51bb91beae8488758f35219da849 Mon Sep 17 00:00:00 2001 From: Lucas Kabela Date: Wed, 16 Jul 2025 22:08:57 +0000 Subject: [PATCH 141/457] [Dynamo][Better Engineering] Support typing in codegen.py (#158386) As part of better engineering week, we would like to improve out type support to improve dev experience in dynamo This PR adds strict typing support to a critical tracing point for dynamo, primarily for `codegen.py` but also `config.py` Running ``` mypy torch/_dynamo/codegen.py torch/_dynamo/config.py --linecount-report /tmp/coverage_log ``` | -------- | Lines Unannotated | Lines Total | % lines covered | Funcs Unannotated | Funcs Total | % funcs covered | | -------- | ------- | -------- | ------- | ------- | ------- | ------- | | Main | 347 | 1330 | 26.09% | 24 | 50 | 48.00% | | This PR | 1334 | 1334 | 100.00% | 50 | 50 | 100.00% | | Delta | +987 | +4 | +73.91.% | +26 | 0 | +52.00% | Pull Request resolved: https://github.com/pytorch/pytorch/pull/158386 Approved by: https://github.com/StrongerXi --- torch/_dynamo/codegen.py | 118 ++++++++++++++++++-------------- torch/_dynamo/config.py | 6 +- torch/_dynamo/external_utils.py | 2 - torch/_dynamo/side_effects.py | 1 + 4 files changed, 69 insertions(+), 58 deletions(-) diff --git a/torch/_dynamo/codegen.py b/torch/_dynamo/codegen.py index 946ad280570ab..f64ef6e5231af 100644 --- a/torch/_dynamo/codegen.py +++ b/torch/_dynamo/codegen.py @@ -1,5 +1,3 @@ -# mypy: allow-untyped-defs - """ This module provides utilities for generating Python bytecode in PyTorch's Dynamo system. It includes functionality for: @@ -18,7 +16,8 @@ import sys import types from collections import Counter -from typing import Optional, TYPE_CHECKING, Union +from collections.abc import Iterable +from typing import Any, Callable, Optional, TYPE_CHECKING, Union import torch.nn from torch.utils._ordered_set import OrderedSet @@ -55,6 +54,8 @@ if TYPE_CHECKING: + from torch._dynamo.variables.builder import GraphArg + from .symbolic_convert import InstructionTranslatorBase @@ -74,8 +75,8 @@ def __init__( tx: "InstructionTranslatorBase", root: Optional[torch.nn.Module] = None, graph_output_var: Optional[str] = None, - tempvars=None, - overridden_sources=None, + tempvars: Optional[dict[Union[VariableTracker, Source], Any]] = None, + overridden_sources: Optional[dict[Source, Source]] = None, ) -> None: self.root = root self.top_of_stack: Optional[Union[VariableTracker, Source]] = None @@ -86,7 +87,7 @@ def __init__( # locals, and maps the VariableTracker/Source to the local variable # name. Note that it could map to None initially, in which case we'll # overwrite it to map to real temporary names via `add_cache`. - self.tempvars = tempvars or {} + self.tempvars: dict[Union[VariableTracker, Source], Any] = tempvars or {} self.tx = tx self.graph_output_var = graph_output_var self.code_options = self.tx.output.code_options @@ -98,7 +99,9 @@ def __init__( # without affecting other components, e.g., guards. self.overridden_sources: dict[Source, Source] = overridden_sources or {} - def restore_stack(self, stack_values, *, value_from_source=True): + def restore_stack( + self, stack_values: list[Any], *, value_from_source: bool = True + ) -> None: prev = self.value_from_source self.value_from_source &= value_from_source try: @@ -106,14 +109,18 @@ def restore_stack(self, stack_values, *, value_from_source=True): finally: self.value_from_source = prev - def graph_output_vars(self): + def graph_output_vars(self) -> list[VariableTracker]: return [x.variable for x in self.graph_outputs.values()] - def call_reconstruct(self, value): + def call_reconstruct( + self, value: Union[VariableTracker, Source, "GraphArg"] + ) -> None: res = value.reconstruct(self) assert res is None, f"reconstruct!=None {value}" - def add_push_null(self, gen_fn, call_function_ex=False): + def add_push_null( + self, gen_fn: Callable[[], None], call_function_ex: bool = False + ) -> None: """ `gen_fn` generates instructions via PyCodegen methods that push a single callable to the stack. @@ -142,7 +149,9 @@ def add_push_null(self, gen_fn, call_function_ex=False): # NULL will be at top of stack self.clear_tos() - def __call__(self, value, allow_cache=True): + def __call__( + self, value: Union[VariableTracker, Source], allow_cache: bool = True + ) -> None: """ Generate code such that top-of-stack (TOS) is set to value. @@ -297,7 +306,7 @@ def __call__(self, value, allow_cache=True): value.as_tensor(self.tx, torch.float64) ) - def gen_fn(): + def gen_fn() -> None: self.load_graph_output(graph_outputs[graph_outputs_key].index) output.append(self.create_load_attr("item")) @@ -322,7 +331,7 @@ def gen_fn(): output.extend(create_call_function(1, False)) elif isinstance(value, UnspecializedPythonVariable) and value.need_unwrap: - def gen_fn(): + def gen_fn() -> None: self.load_graph_output(graph_outputs[graph_outputs_key].index) output.append(self.create_load_attr("item")) @@ -363,7 +372,7 @@ def gen_fn(): self.top_of_stack = value - def add_graph_output(self, value): + def add_graph_output(self, value: VariableTracker) -> int: graph_outputs_key = id(value.as_proxy()) if graph_outputs_key not in self.graph_outputs: self.graph_outputs[graph_outputs_key] = GraphOutputEntry( @@ -371,25 +380,26 @@ def add_graph_output(self, value): ) return graph_outputs_key - def load_graph_output(self, index): + def load_graph_output(self, index: int) -> None: output = self._output + assert self.graph_output_var is not None output.append(self.create_load(self.graph_output_var)) output.append(self.create_load_const(index)) output.append(self.create_binary_subscr()) - def add_cache(self, value): + def add_cache(self, value: Union[VariableTracker, Source]) -> None: var = self.new_var() self.tempvars[value] = var self._output.append(self.create_store(var)) - def foreach(self, items): + def foreach(self, items: Iterable[Union[VariableTracker, Source]]) -> None: for i in items: self(i) def create_binary_subscr(self) -> Instruction: return create_instruction("BINARY_SUBSCR") - def setup_globally_cached(self, name, value): + def setup_globally_cached(self, name: str, value: Any) -> list[Instruction]: """Store value in a new global""" name = re.sub(r"[^a-zA-Z0-9_]+", "_", name) f_globals = self.tx.f_globals @@ -399,15 +409,15 @@ def setup_globally_cached(self, name, value): f_globals[name] = value return [self.create_load_global(name, add=True)] - def clear_tos(self): + def clear_tos(self) -> None: self.top_of_stack = None - def append_output(self, inst): + def append_output(self, inst: Instruction) -> None: assert isinstance(inst, Instruction) self._output.append(inst) self.clear_tos() - def extend_output(self, insts): + def extend_output(self, insts: list[Instruction]) -> None: assert all(isinstance(x, Instruction) for x in insts) self._output.extend(insts) self.clear_tos() @@ -415,66 +425,68 @@ def extend_output(self, insts): def get_instructions(self) -> list[Instruction]: return self._output - def create_load(self, name) -> Instruction: + def create_load(self, name: str) -> Instruction: assert name in self.code_options["co_varnames"], f"{name} missing" return create_instruction("LOAD_FAST", argval=name) - def create_load_closure(self, name) -> Instruction: + def create_load_closure(self, name: str) -> Instruction: assert name in self.cell_and_freevars() inst_name = "LOAD_FAST" if sys.version_info >= (3, 13) else "LOAD_CLOSURE" return create_instruction(inst_name, argval=name) - def create_load_deref(self, name) -> Instruction: + def create_load_deref(self, name: str) -> Instruction: assert name in self.cell_and_freevars() return create_instruction("LOAD_DEREF", argval=name) - def create_store(self, name) -> Instruction: + def create_store(self, name: str) -> Instruction: assert name in self.code_options["co_varnames"], f"{name} missing" return create_instruction("STORE_FAST", argval=name) - def create_store_deref(self, name) -> Instruction: + def create_store_deref(self, name: str) -> Instruction: assert name in self.cell_and_freevars() return create_instruction("STORE_DEREF", argval=name) - def create_load_global(self, name, add=False) -> Instruction: + def create_load_global(self, name: str, add: bool = False) -> Instruction: if add: self.tx.output.update_co_names(name) assert name in self.code_options["co_names"], f"{name} not in co_names" return create_instruction("LOAD_GLOBAL", argval=name) - def create_load_const(self, value) -> Instruction: + def create_load_const(self, value: Any) -> Instruction: return create_load_const(value) - def create_load_const_unchecked(self, value) -> Instruction: + def create_load_const_unchecked(self, value: Any) -> Instruction: return create_load_const(value, checked=False) - def load_method(self, name): + def load_method(self, name: str) -> None: self.tx.output.update_co_names(name) self.append_output(create_load_method(name)) - def call_method(self, nargs): + def call_method(self, nargs: int) -> None: self.extend_output(create_call_method(nargs)) - def create_load_attr(self, name) -> Instruction: + def create_load_attr(self, name: str) -> Instruction: if name not in self.code_options["co_names"]: self.code_options["co_names"] += (name,) return create_instruction("LOAD_ATTR", argval=name) - def load_attr(self, name): + def load_attr(self, name: str) -> None: self.append_output(self.create_load_attr(name)) - def create_load_attrs(self, names): + def create_load_attrs(self, names: str) -> list[Instruction]: return [self.create_load_attr(name) for name in names.split(".")] - def create_store_attr(self, name) -> Instruction: + def create_store_attr(self, name: str) -> Instruction: if name not in self.code_options["co_names"]: self.code_options["co_names"] += (name,) return create_instruction("STORE_ATTR", argval=name) - def store_attr(self, name): + def store_attr(self, name: str) -> None: self.append_output(self.create_store_attr(name)) - def load_function_name(self, fn_name, push_null, num_on_stack=0): + def load_function_name( + self, fn_name: str, push_null: bool, num_on_stack: int = 0 + ) -> list[Instruction]: """Load the global fn_name on the stack num_on_stack down""" output = [] if push_null and sys.version_info >= (3, 11): @@ -495,7 +507,7 @@ def load_function_name(self, fn_name, push_null, num_on_stack=0): ) return output - def rot_n(self, n): + def rot_n(self, n: int) -> list[Instruction]: try: return create_rot_n(n) except AttributeError: @@ -508,29 +520,29 @@ def rot_n(self, n): create_instruction("UNPACK_SEQUENCE", arg=n), ] - def pop_top(self): + def pop_top(self) -> None: self.append_output(create_instruction("POP_TOP")) - def call_function(self, nargs: int, push_null: bool): + def call_function(self, nargs: int, push_null: bool) -> None: self.extend_output(create_call_function(nargs, push_null=push_null)) - def dup_top(self): + def dup_top(self) -> None: self.append_output(create_dup_top()) - def store(self, varname): + def store(self, varname: str) -> None: self.append_output(self.create_store(varname)) - def load_deref(self, varname): + def load_deref(self, varname: str) -> None: self.append_output(self.create_load_deref(varname)) def make_function_with_closure( - self, fn_name: str, code: types.CodeType, push_null: bool, num_on_stack=0 - ): + self, fn_name: str, code: types.CodeType, push_null: bool, num_on_stack: int = 0 + ) -> None: freevars = code.co_freevars assert freevars output = self._output - def gen_fn(): + def gen_fn() -> None: # Emitting `LOAD_FAST/LOAD_CLOSURE` with names in `co_freevars` # requires that in the generated bytecode, these cells would keep # their original local names, which we ensure via @@ -561,7 +573,7 @@ def gen_fn(): output.extend(self.rot_n(num_on_stack + 1)) self.clear_tos() - def create_load_python_module(self, mod) -> Instruction: + def create_load_python_module(self, mod: types.ModuleType) -> Instruction: """ Generate a LOAD_GLOBAL instruction to fetch a given python module. """ @@ -589,7 +601,7 @@ def make_call_generated_code(self, fn_name: str) -> None: seen_sources: OrderedSet[Source] = OrderedSet() - def collect_temp_source(source): + def collect_temp_source(source: Source) -> None: if source in seen_sources: # This source is used at least twice, so it can be reused self.mark_source_temp(source) @@ -655,10 +667,10 @@ def collect_temp_source(source): self.extend_output(create_call_function(len(graphargs), False)) - def create_import_name(self, module_name) -> Instruction: + def create_import_name(self, module_name: str) -> Instruction: return create_instruction("IMPORT_NAME", argval=module_name) - def load_import_from(self, module_name, object_name) -> None: + def load_import_from(self, module_name: str, object_name: str) -> None: source = AttrSource(self.tx.import_source(module_name), object_name) # Note: This approach is somewhat aggressive because typically, a source is marked # as a tempvar only when it is used more than once. In this case, we're marking it @@ -667,7 +679,9 @@ def load_import_from(self, module_name, object_name) -> None: self.mark_source_temp(source) self(source) - def create_call_function_kw(self, nargs, kw_names, push_null) -> list[Instruction]: + def create_call_function_kw( + self, nargs: int, kw_names: Iterable[str], push_null: bool + ) -> list[Instruction]: if sys.version_info >= (3, 13): output = create_call_function(nargs, push_null) assert output[-1].opname == "CALL" @@ -691,5 +705,5 @@ def create_call_function_kw(self, nargs, kw_names, push_null) -> list[Instructio create_instruction("CALL_FUNCTION_KW", arg=nargs), ] - def create_delete(self, value) -> Instruction: + def create_delete(self, value: object) -> Instruction: return create_instruction("DELETE_FAST", argval=value) diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index 21598f71bced5..c7f0fb4adeb1f 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -1,5 +1,3 @@ -# mypy: allow-untyped-defs - """ Configuration module for TorchDynamo compiler and optimization settings. @@ -450,7 +448,7 @@ record_compile_time_instruction_count = False -def default_debug_dir_root(): +def default_debug_dir_root() -> str: # [@compile_ignored: debug] DEBUG_DIR_VAR_NAME = "TORCH_COMPILE_DEBUG_DIR" if DEBUG_DIR_VAR_NAME in os.environ: @@ -629,7 +627,7 @@ def default_debug_dir_root(): if TYPE_CHECKING: from torch.utils._config_typing import * # noqa: F401, F403 - def _make_closure_patcher(**changes): ... + def _make_closure_patcher(**changes: Any) -> Any: ... install_config_module(sys.modules[__name__]) diff --git a/torch/_dynamo/external_utils.py b/torch/_dynamo/external_utils.py index c4fbc62ea5db2..f48c14862ac04 100644 --- a/torch/_dynamo/external_utils.py +++ b/torch/_dynamo/external_utils.py @@ -1,5 +1,3 @@ -# This module contains functions that *will be allowed* by dynamo - """ This module contains utility functions that are explicitly allowed to be called during TorchDynamo compilation. These functions are carefully vetted to ensure they work diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index ab7c7561a88c8..8e3c4cd30145c 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -702,6 +702,7 @@ def load_new_method(): cg.add_push_null( lambda: cg.load_import_from(utils.__name__, "object_new") ) + assert var.mutation_type.cls_source is not None cg(var.mutation_type.cls_source) # Generate the args to the __new__ method From 3cb11877aa30c04be7ffa9b4ca1722f1270a5828 Mon Sep 17 00:00:00 2001 From: angelayi Date: Wed, 16 Jul 2025 10:00:29 -0700 Subject: [PATCH 142/457] [aoti][mps] Enable test_aot_inductor.py tests (#155598) Pull Request resolved: https://github.com/pytorch/pytorch/pull/155598 Approved by: https://github.com/yushangdi --- test/inductor/test_aot_inductor.py | 276 ++++++++++++++++++----- test/inductor/test_aot_inductor_utils.py | 2 + test/run_test.py | 1 + 3 files changed, 228 insertions(+), 51 deletions(-) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 2ffcc0ead4954..9521f1defa0bd 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -55,9 +55,12 @@ IS_FBCODE, IS_MACOS, IS_WINDOWS, + MACOS_VERSION, parametrize, + skipIfMPS, skipIfRocm, skipIfXpu, + TEST_MPS, TEST_WITH_ROCM, ) from torch.testing._internal.custom_tensor import CustomTensorPlainOut @@ -173,7 +176,9 @@ def forward(self, x, y): _, code = run_and_get_cpp_code( AOTIRunnerUtil.compile, model, example_inputs ) - if self.device == GPU_TYPE: + if self.device == "mps": + FileCheck().check("getKernelFunction(").run(code) + elif self.device == GPU_TYPE: FileCheck().check("launchKernel(").run(code) if config.aot_inductor.embed_kernel_binary: # Not expect to see launchKernel("CUBIN_FILE_NAME" @@ -188,6 +193,7 @@ def forward(self, x, y): IS_FBCODE, "toolchain doesn't support ptx to fatbin", ) + @skipIfMPS @skipIfRocm # Skip embed_kernel_binary == True for now as it shows random # failure on CI @@ -432,6 +438,10 @@ def forward(self, y): ep, inductor_configs={"aot_inductor.use_runtime_constant_folding": True} ) + @unittest.skipIf( + TEST_MPS and MACOS_VERSION < 14.0, + "Compilation error", + ) def test_aot_inductor_consts_cpp_build(self): class Model(torch.nn.Module): def __init__(self, device) -> None: @@ -788,6 +798,10 @@ def forward(self, a, b): inp = (torch.ones(3, device=self.device), torch.ones(3, device=self.device)) self.check_model(M(), inp) + @unittest.skipIf( + TEST_MPS and MACOS_VERSION < 14.0, + "MPS BFloat16 is only supported on MacOS 14+", + ) def test_empty_cat_dtype_promotion(self): class Foo(torch.nn.Module): def forward(self, x, y): @@ -1511,6 +1525,10 @@ def forward(self, x, y): ) self.check_model(Repro(), example_inputs) + @unittest.skipIf( + TEST_MPS and MACOS_VERSION < 14.0, + "bfloat16 is only supported on MacOS 14+", + ) def test_size_with_unbacked_add_expr(self): # Tests AOTI autotuning to make sure the correct input tensor sizes # are generated for sizes that include an expr such as s0 + u0. @@ -1766,7 +1784,7 @@ def forward(self, x): Foo(user_float_feature_idx, self.device), example_inputs, strict=False ).run_decompositions() gm = ep.module() - self.check_model(gm, example_inputs) + self.check_model(gm.to(self.device), example_inputs) def test_large_grid(self): if self.device != GPU_TYPE: @@ -2434,6 +2452,7 @@ def forward(self, x): self.check_model(converted_model, example_inputs) + @skipIfMPS def test_fallback_mem_leak_fix(self): if self.device != GPU_TYPE: raise unittest.SkipTest("requires GPU") @@ -2478,6 +2497,7 @@ def forward(self, x, y, idx): torch.testing.assert_close(actual, expected) @requires_multigpu() + @skipIfMPS def test_replicate_on_devices(self): if self.device != GPU_TYPE: raise unittest.SkipTest("requires GPU") @@ -2517,6 +2537,7 @@ def forward(self, x, y): self.assertTrue(same(result_cpu, result_gpu.cpu())) @requires_multigpu() + @skipIfMPS def test_on_gpu_device1(self): if self.device != GPU_TYPE: raise unittest.SkipTest("requires GPU") @@ -2666,7 +2687,11 @@ def forward(self, x, y): model, example_inputs, atol=1e-4, rtol=1e-4 ) # 1e-4 is the tol value used in pytorch/torch/_dynamo/utils.py - if self.device == GPU_TYPE: + if self.device == "mps": + self.code_check_count( + model, example_inputs, '.getKernelFunction("generated_kernel")', 1 + ) + elif self.device == GPU_TYPE: self.code_check_count( model, example_inputs, "triton_poi_fused_sin_0 = loadKernel(", 1 ) @@ -3065,10 +3090,9 @@ def forward(self, x): # Call eval() here so that batch_norm won't update the running stats # Use float64 to avoid numeric difference failure - model = Model().to(device=self.device, dtype=torch.float64).eval() - example_inputs = ( - torch.randn(4, 3, 64, 64, device=self.device, dtype=torch.float64), - ) + dtype = torch.float32 if self.device == "mps" else torch.float64 + model = Model().to(device=self.device, dtype=dtype).eval() + example_inputs = (torch.randn(4, 3, 64, 64, device=self.device, dtype=dtype),) self.check_model(model, example_inputs) def test_triton_next_power_of_2(self): @@ -3129,6 +3153,7 @@ def forward(self, a, b, ranks): torch._dynamo.mark_dynamic(example_inputs[1], 0) self.check_model(Model(), example_inputs) + @skipIfMPS @common_utils.parametrize("grid_type", [1, 2, 3]) @common_utils.parametrize("num_dims", [1, 2]) @common_utils.parametrize("dynamic", [False, True]) @@ -4160,6 +4185,7 @@ def forward(self, x, y): expected = Model()(*example_inputs) torch.testing.assert_close(actual, expected) + @skipIfMPS @torch._dynamo.config.patch(capture_scalar_outputs=True) @common_utils.parametrize("dynamic", [False, True]) @common_utils.parametrize("autotuning", [False, True]) @@ -4336,24 +4362,13 @@ def forward(self, x, i1, i2, y): @patch.dict(os.environ, {"AOTI_RUNTIME_CHECK_INPUTS": "1"}) def test_runtime_checks(self): class Model(torch.nn.Module): - def __init__(self) -> None: - super().__init__() + def forward(self, inputs): + return list(inputs.values()) - if SM80OrLater: - - def forward(self, x0, x1, x2, x3, x4, x5, x6, x7, x8, x9): - return (x0, x1, x2, x3, x4, x5, x6, x7, x8, x9) - - else: - - def forward(self, x0, x1, x2, x4, x5, x6, x7, x8, x9): - return (x0, x1, x2, x4, x5, x6, x7, x8, x9) - - inputs = [] + inputs = {} dtypes = [ torch.float16, torch.float32, - torch.float64, torch.bool, torch.int8, torch.int16, @@ -4361,60 +4376,75 @@ def forward(self, x0, x1, x2, x4, x5, x6, x7, x8, x9): torch.int64, torch.uint8, ] + + if not TEST_MPS: + dtypes.append(torch.float64) if SM80OrLater: dtypes.append(torch.bfloat16) + for dtype in dtypes: - inputs.append(torch.ones(4, 8, 10, dtype=dtype, device=self.device)) + inputs[f"x_{str(dtype)}"] = torch.ones( + 4, 8, 10, dtype=dtype, device=self.device + ) dim0 = Dim("s0", min=2, max=1024) dim1 = Dim("s1", min=2, max=512) dim2 = Dim("s2", min=2, max=128) dynamic_shapes = { - "x0": {0: dim0}, - "x1": {0: dim0}, - "x2": {0: dim0}, - "x4": {1: dim1}, - "x5": {1: dim1}, - "x6": {}, - "x7": {2: dim2}, - "x8": {2: dim2}, - "x9": {2: dim2}, + "x_torch.float16": {0: dim0}, + "x_torch.float32": {0: dim0}, + "x_torch.bool": {1: dim1}, + "x_torch.int8": {1: dim1}, + "x_torch.int16": {}, + "x_torch.int32": {2: dim2}, + "x_torch.int64": {2: dim2}, + "x_torch.uint8": {2: dim2}, } + if not TEST_MPS: + dynamic_shapes["x_torch.float64"] = {0: dim0} if SM80OrLater: - dynamic_shapes["x3"] = {1: dim1} + dynamic_shapes["x_torch.bfloat16"] = {1: dim1} m = Model() - inputs = tuple(inputs) + inputs = (inputs,) + dynamic_shapes = (dynamic_shapes,) with torch.no_grad(): so_path = AOTIRunnerUtil.legacy_compile( m, inputs, dynamic_shapes=dynamic_shapes ) + + # Expected results for the following checks: + # ("unmatched dtype", "unmatched dim value at", "dim value is too", "unmatched stride value at") + if SM80OrLater: + # 10 dynamic dims + expected_results = (10, 21, 18, 21) + elif TEST_MPS: + # 8 dynamic dims + expected_results = (8, 17, 14, 16) + else: + # 9 dynamic dims + expected_results = (9, 19, 16, 19) + with open(os.path.splitext(so_path)[0] + ".cpp") as cpp: src_code = cpp.read() FileCheck().check_count( "unmatched dtype", - 10 if SM80OrLater else 9, + expected_results[0], exactly=True, ).run(src_code) FileCheck().check_count( "unmatched dim value at", - 21 - if SM80OrLater - else 19, # we have 9 dynamic dims for which we generate different checks + expected_results[1], exactly=True, ).run(src_code) FileCheck().check_count( "dim value is too", - 18 - if SM80OrLater - else 16, # we have 9 dynamic dims for which we generate two checks + expected_results[2], exactly=True, ).run(src_code) FileCheck().check_count( "unmatched stride value at", - 21 - if SM80OrLater - else 19, # we have 9 symbolic strides for which we don't generate checks + expected_results[3], exactly=True, ).run(src_code) @@ -4678,6 +4708,10 @@ def forward(self, w, i, o): ) self.check_model(Model(), example_inputs) + @unittest.skipIf( + TEST_MPS and MACOS_VERSION < 14.0, + "FFT operations are only supported on MacOS 14+", + ) def test_fft_c2c(self): class Model(torch.nn.Module): def forward(self, x): @@ -4844,16 +4878,15 @@ def forward(self, a): a = torch.randn(batch, M, K, device=self.device) example_inputs = (a,) - kernel_calls = ( - [ + if self.device == "mps": + kernel_calls = [("aoti_torch_mps_addmm_out", 2)] + elif self.device == GPU_TYPE: + kernel_calls = [ ("triton_poi_fused_0", 1), (f"aoti_torch_{GPU_TYPE}_addmm_out", 2), ] - if self.device == GPU_TYPE - else [ - ("aoti_torch_cpu_addmm_out", 2), - ] - ) + else: + kernel_calls = [("aoti_torch_cpu_addmm_out", 2)] # test default debug printing all tensor values codegen with config.patch({"aot_inductor.debug_intermediate_value_printer": "2"}): @@ -6014,13 +6047,17 @@ def forward(self, x, y): ) @unittest.skipIf(IS_FBCODE, "Not runnable in fbcode") + @unittest.skipIf( + TEST_MPS and MACOS_VERSION < 14.0, + "FFT operations are only supported on MacOS 14+", + ) def test_stft(self): N_FFT = 400 HOP_LENGTH = 160 class Model(torch.nn.Module): def forward(self, x): - window = torch.hann_window(N_FFT).to(x.device) + window = torch.hann_window(N_FFT, device=x.device) stft = torch.stft( x, N_FFT, HOP_LENGTH, window=window, return_complex=True ) @@ -6104,6 +6141,7 @@ def forward(self, x, y): ) self.check_model(Model(), example_inputs) + @skipIfMPS @skipIfXpu( msg="aten::convert_weight_to_int4pack is not currently implemented for XPU" ) @@ -6643,6 +6681,13 @@ def fail_cpu(is_skip=False): ) +def fail_mps(is_skip=False): + return TestFailure( + ("mps",), + is_skip=is_skip, + ) + + def fail_gpu(suffixes: tuple[str, ...], is_skip=False): return TestFailure( suffixes, @@ -6667,6 +6712,115 @@ def fail_gpu(suffixes: tuple[str, ...], is_skip=False): "test_fft_c2c": fail_gpu(("xpu",), is_skip=True), } +MPS_TEST_FAILURES = { + # Expected supportedFloatingType(scalar_type) || scalar_type == kInt || scalar_type == kBool + "test_index_put_fallback": fail_mps(), + # aten::_embedding_bag is not currently implemented for the MPS device. + "test_embedding_bag": fail_mps(), + # aten::_embedding_bag is not currently implemented for the MPS device. + "test_misc_1_max_autotune_False": fail_mps(), + "test_misc_1_max_autotune_True": fail_mps(), + # aten::_scaled_dot_product_efficient_attention is not currently implemented for the MPS device. + "test_scaled_dot_product_efficient_attention": fail_mps(), + # aten::_int_mm is not implemented for MPS backend + "test__int_mm": fail_mps(), + # MPS doesn't support float64 + "test_while_loop_with_conv_dynamic_True": fail_mps(), + "test_while_loop_with_conv_dynamic_False": fail_mps(), + # MPS doesn't support float8 + "test_fp8": fail_mps(), + "test_fp8_view_of_param": fail_mps(), + # Compilation Error + "test_fallback_kernel_with_symexpr_output": fail_mps(), + "test_while_loop_with_mixed_device": fail_mps(), + "test_while_loop_nested": fail_mps(), + "test_assert_async": fail_mps(), + "test_index_put_with_none_index": fail_mps(), + "test_size_from_multi_ouptut": fail_mps(), + "test_simple_embed_kernel_binary_False": fail_mps(), + "test_while_loop_with_mixed_device_dynamic_False": fail_mps(), + "test_while_loop_with_mixed_device_dynamic_True": fail_mps(), + "test_simple_embed_cubin_False": fail_mps(is_skip=True), + "test_simple_embed_cubin_True": fail_mps(is_skip=True), + "test_simple_embed_kernel_binary_True": fail_mps(), + "test_missing_cubin": fail_mps(), + # Dynamism + "test_shifted_constraint_ranges": fail_mps(), + "test_while_loop_with_sym_expr_cond_dynamic_True": fail_mps(), + "test_while_loop_with_unbacked_symint_closure_dynamic_True": fail_mps(), + "test_cond_mismatched_branch_output_dynamic_True": fail_mps(), + "test_cond_unbacked_symint_closure_dynamic_True": fail_mps(), + "test_cond_non_tensor_predicates_dynamic_True": fail_mps(), + "test_zero_grid_with_unbacked_symbols": fail_mps(), + "test_reuse_kernel_dynamic": fail_mps(is_skip=True), + "test_while_loop_with_parameters": fail_mps(is_skip=True), + "test_cond_with_parameters": fail_mps(is_skip=True), + "test_cond_share_predicte": fail_mps(is_skip=True), + # SetStorage incorrect + "test_small_constant": fail_mps(is_skip=True), + "test_free_inactive_buffer": fail_mps(is_skip=True), + "test_extract_constants_map": fail_mps(is_skip=True), + "test_linear_freezing": fail_mps(is_skip=True), + "test_model_modified_weights": fail_mps(is_skip=True), + # Error device may not be nil + "test_zero_size_weight": fail_mps(is_skip=True), + # Constants update (segfault) + "test_update_inactive_constant_buffer": fail_mps(is_skip=True), + "test_update_constant_buffer": fail_mps(is_skip=True), + "test_so_without_weight": fail_mps(is_skip=True), + "test_constant_folding_with_update": fail_mps(is_skip=True), + "test_nested_tensor_from_jagged": fail_mps(is_skip=True), + "test_issue_140766": fail_mps(is_skip=True), + "test_buffer_mutation_and_force_mmap_weights": fail_mps(is_skip=True), + "test_aoti_constant_tensor_name_collision": fail_mps(is_skip=True), + "test_large_mmaped_weights": fail_mps(is_skip=True), + "test_subclasses": fail_mps(is_skip=True), + "test_autotune_with_constant_folding": fail_mps(is_skip=True), + # MPS doesn't support triton + "test_autotuning_args_reuse": fail_mps(), + "test_triton_autotuning": fail_mps(), + "test_triton_dynamic_launcher_grid": fail_mps(), + "test_triton_dynamic_launcher_grid_infer_from_tensor": fail_mps(), + "test_triton_kernel_on_device_tma_dynamic_False_tma_version_new": fail_mps(), + "test_triton_kernel_on_device_tma_dynamic_False_tma_version_old": fail_mps(), + "test_triton_kernel_on_device_tma_dynamic_True_tma_version_new": fail_mps(), + "test_triton_kernel_on_device_tma_dynamic_True_tma_version_old": fail_mps(), + "test_size_with_unbacked_add_expr_transitive": fail_mps(), + "test_size_with_unbacked_add_and_mul_expr": fail_mps(), + "test_triton_next_power_of_2": fail_mps(), + "test_sympy_cpp_printer_min_max_minmax0": fail_mps(), + "test_sympy_cpp_printer_min_max_minmax1": fail_mps(), + "test_triton_kernel_dynamic_shape_with_div": fail_mps(), + "test_triton_kernel_reinterpret_view": fail_mps(), + "test_triton_kernel_tma_descriptor_1d_dynamic_False_tma_version_new_mps": fail_mps(), + "test_triton_kernel_tma_descriptor_1d_dynamic_False_tma_version_old_mps": fail_mps(), + "test_triton_kernel_tma_descriptor_1d_dynamic_True_tma_version_new_mps": fail_mps(), + "test_triton_kernel_tma_descriptor_1d_dynamic_True_tma_version_old_mps": fail_mps(), + "test_triton_kernel_tma_descriptor_2d_dynamic_False_tma_version_new_mps": fail_mps(), + "test_triton_kernel_tma_descriptor_2d_dynamic_False_tma_version_old_mps": fail_mps(), + "test_triton_kernel_tma_descriptor_2d_dynamic_True_tma_version_new_mps": fail_mps(), + "test_triton_kernel_tma_descriptor_2d_dynamic_True_tma_version_old_mps": fail_mps(), + "test_triton_kernel_sympy_expr_arg": fail_mps(), + "test_triton_kernel_sympy_fn_like_arg": fail_mps(), + "test_triton_kernel_with_none_input": fail_mps(), + "test_triton_kernel_equal_to_1_arg": fail_mps(), + "test_triton_kernel_with_none_inputs_and_equal_to_1_arg": fail_mps(), + "test_triton_kernel_equal_to_1_float_arg_dynamic_True": fail_mps(), + "test_triton_kernel_equal_to_1_float_arg_dynamic_False": fail_mps(), + "test_triton_kernel_weird_param_order": fail_mps(), + "test_triton_kernel_dynamic_grid": fail_mps(), + "test_repeated_user_defined_triton_kernel_embed_kernel_binary_False": fail_mps(), + "test_repeated_user_defined_triton_kernel_embed_kernel_binary_True": fail_mps(), + "test_triton_kernel_extern_kernel_arg": fail_mps(), + "test_triton_kernel_multi_output_arg": fail_mps(), + "test_triton_kernel_reinterpret_view_mem_leak": fail_mps(), + "test_triton_mutated_autotuning": fail_mps(), + "test_sym_i64_input_codegen": fail_mps(), + "test_none_args_aot_codegen": fail_mps(), + "test_aoti_debug_printer_sym_inputs": fail_mps(), + "test_aoti_debug_printer_user_defined_triton_kernel": fail_mps(), +} + class AOTInductorTestABICompatibleCpu(TestCase): device = "cpu" @@ -6704,6 +6858,26 @@ class AOTInductorTestABICompatibleGpu(TestCase): GPU_TEST_FAILURES, ) + +@unittest.skipIf(not torch.backends.mps.is_available(), "No MPS backend available") +class AOTInductorTestABICompatibleMps(TestCase): + device = "mps" + device_type = "mps" + check_model = check_model + check_model_with_multiple_inputs = check_model_with_multiple_inputs + code_check_count = code_check_count + allow_stack_allocation = False + use_minimal_arrayref_interface = False + + +copy_tests( + AOTInductorTestsTemplate, + AOTInductorTestABICompatibleMps, + "mps", + MPS_TEST_FAILURES, +) + + if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/test/inductor/test_aot_inductor_utils.py b/test/inductor/test_aot_inductor_utils.py index 9d25aa4756018..a2706933d6156 100644 --- a/test/inductor/test_aot_inductor_utils.py +++ b/test/inductor/test_aot_inductor_utils.py @@ -102,6 +102,8 @@ def legacy_load_runner(device, so_path: str) -> "AOTIModelContainerRunner": return torch._C._aoti.AOTIModelContainerRunnerCpu(so_path, 1) elif device == "xpu": return torch._C._aoti.AOTIModelContainerRunnerXpu(so_path, 1, device) + elif device == "mps": + return torch._C._aoti.AOTIModelContainerRunnerMps(so_path, 1) else: return torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1, device) diff --git a/test/run_test.py b/test/run_test.py index 26b10ac4ac61c..64d6067edc94a 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -1592,6 +1592,7 @@ def get_selected_tests(options) -> list[str]: "test_nn", "inductor/test_mps_basic", "inductor/test_torchinductor", + "inductor/test_aot_inductor", ] else: # Exclude all mps tests otherwise From e311886e3d57c83a88b97a084dd0b95d6d1537a8 Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Wed, 16 Jul 2025 08:39:30 -0700 Subject: [PATCH 143/457] Add transpose to torch/csrc/stable (#158160) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158160 Approved by: https://github.com/janeyx99 --- .../libtorch_agnostic/csrc/kernel.cpp | 19 +++++++++++++++++++ .../libtorch_agnostic/ops.py | 12 ++++++++++++ .../test/test_libtorch_agnostic.py | 10 ++++++++++ torch/csrc/stable/ops.h | 17 +++++++++++++++++ 4 files changed, 58 insertions(+) create mode 100644 torch/csrc/stable/ops.h diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp index 554203752479b..6125c21f0bedc 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp +++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include @@ -254,3 +255,21 @@ STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { m.impl("is_contiguous", &boxed_is_contiguous); } + +Tensor my_transpose(Tensor t, int64_t dim0, int64_t dim1) { + return transpose(t, dim0, dim1); +} + +void boxed_my_transpose(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { + auto res = my_transpose(to(stack[0]), to(stack[1]), to(stack[2])); + + stack[0] = from(res); +} + +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { + m.def("my_transpose(Tensor t, int dim0, int dim1) -> Tensor"); +} + +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { + m.impl("my_transpose", &boxed_my_transpose); +} diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py index 2b4fbd40eb1a2..4a193cc73593a 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py +++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py @@ -116,3 +116,15 @@ def is_contiguous(t) -> bool: Returns: is_contiguous(t) """ return torch.ops.libtorch_agnostic.is_contiguous.default(t) + + +def my_transpose(t, dim0, dim1) -> Tensor: + """ + Returns t.transpose(dim0, dim1) + + Args: + t: Tensor + + Returns: my_transpose(t, dim0, dim1) + """ + return torch.ops.libtorch_agnostic.my_transpose.default(t, dim0, dim1) diff --git a/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py b/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py index ba1d6411b0984..3d9e1ae929289 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py +++ b/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py @@ -173,6 +173,16 @@ def _make_cuda_tensors(prior_mem): curr_mem = torch.cuda.memory_allocated(device) self.assertEqual(curr_mem, init_mem) + def test_my_transpose(self, device): + import libtorch_agnostic + + t = torch.rand(2, 7, device=device) + out = libtorch_agnostic.ops.my_transpose(t, 0, 1) + self.assertEqual(out, torch.transpose(t, 0, 1)) + + with self.assertRaisesRegex(RuntimeError, "API call failed"): + libtorch_agnostic.ops.my_transpose(t, 1, 2) + instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None) if __name__ == "__main__": diff --git a/torch/csrc/stable/ops.h b/torch/csrc/stable/ops.h new file mode 100644 index 0000000000000..4105339e569c1 --- /dev/null +++ b/torch/csrc/stable/ops.h @@ -0,0 +1,17 @@ +#pragma once + +#include +#include +#include + +using torch::stable::Tensor; + +// We expect this to be the stable version of the transpose op with identical +// semantics to the existing transpose.int op. +inline Tensor transpose(const Tensor& self, int64_t dim0, int64_t dim1) { + const auto num_args = 3; + std::array stack{from(self), from(dim0), from(dim1)}; + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_call_dispatcher("aten::transpose", "int", stack.data())); + return to(stack[0]); +} From a9f902add02383ca1b0386eb865767641975fede Mon Sep 17 00:00:00 2001 From: Frank Lin Date: Wed, 16 Jul 2025 23:14:36 +0000 Subject: [PATCH 144/457] [CUDA] Use runtime driver API for cuStreamWriteValue32 (#158295) Reopen https://github.com/pytorch/pytorch/pull/156097 Fixes https://github.com/pytorch/pytorch/issues/154073 Reference: https://github.com/NVIDIA/Fuser/pull/4197 See PR https://github.com/pytorch/pytorch/pull/156097 and https://github.com/pytorch/pytorch/pull/154097 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158295 Approved by: https://github.com/Skylion007, https://github.com/ngimel, https://github.com/eqy, https://github.com/huydhn Co-authored-by: Wei Wang --- c10/cuda/driver_api.cpp | 55 ++++++++++++++++----- c10/cuda/driver_api.h | 60 ++++++++++++++--------- test/distributed/test_symmetric_memory.py | 4 -- 3 files changed, 81 insertions(+), 38 deletions(-) diff --git a/c10/cuda/driver_api.cpp b/c10/cuda/driver_api.cpp index bb201b5c0397f..f4b62e53fcc00 100644 --- a/c10/cuda/driver_api.cpp +++ b/c10/cuda/driver_api.cpp @@ -1,30 +1,35 @@ #if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) +#include #include #include #include +#include +#include #include namespace c10::cuda { namespace { +void* get_symbol(const char* name, int version); + DriverAPI create_driver_api() { - void* handle_0 = dlopen("libcuda.so.1", RTLD_LAZY | RTLD_NOLOAD); - TORCH_CHECK(handle_0, "Can't open libcuda.so.1: ", dlerror()); void* handle_1 = DriverAPI::get_nvml_handle(); DriverAPI r{}; -#define LOOKUP_LIBCUDA_ENTRY(name) \ - r.name##_ = ((decltype(&name))dlsym(handle_0, #name)); \ - TORCH_INTERNAL_ASSERT(r.name##_, "Can't find ", #name, ": ", dlerror()) - C10_LIBCUDA_DRIVER_API(LOOKUP_LIBCUDA_ENTRY) -#undef LOOKUP_LIBCUDA_ENTRY +#define LOOKUP_LIBCUDA_ENTRY_WITH_VERSION_REQUIRED(name, version) \ + r.name##_ = reinterpret_cast(get_symbol(#name, version)); \ + TORCH_INTERNAL_ASSERT(r.name##_, "Can't find ", #name); + C10_LIBCUDA_DRIVER_API_REQUIRED(LOOKUP_LIBCUDA_ENTRY_WITH_VERSION_REQUIRED) +#undef LOOKUP_LIBCUDA_ENTRY_WITH_VERSION_REQUIRED -#define LOOKUP_LIBCUDA_ENTRY(name) \ - r.name##_ = ((decltype(&name))dlsym(handle_0, #name)); \ - dlerror(); - C10_LIBCUDA_DRIVER_API_12030(LOOKUP_LIBCUDA_ENTRY) -#undef LOOKUP_LIBCUDA_ENTRY +// Users running drivers between 12.0 and 12.3 will not have these symbols, +// they would be resolved into nullptr, but we guard their usage at runtime +// to ensure safe fallback behavior. +#define LOOKUP_LIBCUDA_ENTRY_WITH_VERSION_OPTIONAL(name, version) \ + r.name##_ = reinterpret_cast(get_symbol(#name, version)); + C10_LIBCUDA_DRIVER_API_OPTIONAL(LOOKUP_LIBCUDA_ENTRY_WITH_VERSION_OPTIONAL) +#undef LOOKUP_LIBCUDA_ENTRY_WITH_VERSION_OPTIONAL if (handle_1) { #define LOOKUP_NVML_ENTRY(name) \ @@ -35,6 +40,32 @@ DriverAPI create_driver_api() { } return r; } + +void* get_symbol(const char* name, int version) { + void* out = nullptr; + cudaDriverEntryPointQueryResult qres{}; + + // CUDA 12.5+ supports version-based lookup +#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12050) + if (auto st = cudaGetDriverEntryPointByVersion( + name, &out, version, cudaEnableDefault, &qres); + st == cudaSuccess && qres == cudaDriverEntryPointSuccess && out) { + return out; + } +#endif + + // This fallback to the old API to try getting the symbol again. + if (auto st = cudaGetDriverEntryPoint(name, &out, cudaEnableDefault, &qres); + st == cudaSuccess && qres == cudaDriverEntryPointSuccess && out) { + return out; + } + + // If the symbol cannot be resolved, report and return nullptr; + // the caller is responsible for checking the pointer. + LOG(INFO) << "Failed to resolve symbol " << name; + return nullptr; +} + } // namespace void* DriverAPI::get_nvml_handle() { diff --git a/c10/cuda/driver_api.h b/c10/cuda/driver_api.h index a8ded9de68d72..9800809d1e535 100644 --- a/c10/cuda/driver_api.h +++ b/c10/cuda/driver_api.h @@ -20,29 +20,42 @@ } \ } while (0) -#define C10_LIBCUDA_DRIVER_API(_) \ - _(cuDeviceGetAttribute) \ - _(cuMemAddressReserve) \ - _(cuMemRelease) \ - _(cuMemMap) \ - _(cuMemAddressFree) \ - _(cuMemSetAccess) \ - _(cuMemUnmap) \ - _(cuMemCreate) \ - _(cuMemGetAllocationGranularity) \ - _(cuMemExportToShareableHandle) \ - _(cuMemImportFromShareableHandle) \ - _(cuMemsetD32Async) \ - _(cuStreamWriteValue32) \ - _(cuGetErrorString) +// The integer in the second column specifies the requested CUDA Driver API +// version. The dynamic loader will accept a driver with a newer version, but it +// ensures that the requested symbol exists in *at least* the specified version +// or earlier. + +// Keep these requested versions as low as possible to maximize compatibility +// across different driver versions. + +// Why do we pin to an older version instead of using the latest? +// If a user installs a newer driver, blindly resolving the symbol may bind to a +// newer version of the function with different behavior, potentially breaking +// PyTorch. + +#define C10_LIBCUDA_DRIVER_API_REQUIRED(_) \ + _(cuDeviceGetAttribute, 12000) \ + _(cuMemAddressReserve, 12000) \ + _(cuMemRelease, 12000) \ + _(cuMemMap, 12000) \ + _(cuMemAddressFree, 12000) \ + _(cuMemSetAccess, 12000) \ + _(cuMemUnmap, 12000) \ + _(cuMemCreate, 12000) \ + _(cuMemGetAllocationGranularity, 12000) \ + _(cuMemExportToShareableHandle, 12000) \ + _(cuMemImportFromShareableHandle, 12000) \ + _(cuMemsetD32Async, 12000) \ + _(cuStreamWriteValue32, 12000) \ + _(cuGetErrorString, 12000) #if defined(CUDA_VERSION) && (CUDA_VERSION >= 12030) -#define C10_LIBCUDA_DRIVER_API_12030(_) \ - _(cuMulticastAddDevice) \ - _(cuMulticastBindMem) \ - _(cuMulticastCreate) +#define C10_LIBCUDA_DRIVER_API_OPTIONAL(_) \ + _(cuMulticastAddDevice, 12030) \ + _(cuMulticastBindMem, 12030) \ + _(cuMulticastCreate, 12030) #else -#define C10_LIBCUDA_DRIVER_API_12030(_) +#define C10_LIBCUDA_DRIVER_API_OPTIONAL(_) #endif #define C10_NVML_DRIVER_API(_) \ @@ -56,11 +69,14 @@ namespace c10::cuda { struct DriverAPI { +#define CREATE_MEMBER_VERSIONED(name, version) decltype(&name) name##_; #define CREATE_MEMBER(name) decltype(&name) name##_; - C10_LIBCUDA_DRIVER_API(CREATE_MEMBER) - C10_LIBCUDA_DRIVER_API_12030(CREATE_MEMBER) + C10_LIBCUDA_DRIVER_API_REQUIRED(CREATE_MEMBER_VERSIONED) + C10_LIBCUDA_DRIVER_API_OPTIONAL(CREATE_MEMBER_VERSIONED) C10_NVML_DRIVER_API(CREATE_MEMBER) +#undef CREATE_MEMBER_VERSIONED #undef CREATE_MEMBER + static DriverAPI* get(); static void* get_nvml_handle(); }; diff --git a/test/distributed/test_symmetric_memory.py b/test/distributed/test_symmetric_memory.py index f6f7fcfc38854..ed39107a0676f 100644 --- a/test/distributed/test_symmetric_memory.py +++ b/test/distributed/test_symmetric_memory.py @@ -1065,10 +1065,6 @@ class SymmMemSingleProcTest(TestCase): not TEST_WITH_ROCM and _get_torch_cuda_version() < (12, 0), "stream_write_value32 currently only supports cuda version>=12.0", ) - @skipIf( - _get_torch_cuda_version() >= (12, 6), - "https://github.com/pytorch/pytorch/issues/154073", - ) @runOnRocmArch(MI300_ARCH) def test_stream_write_value32(self): tensor = torch.zeros(4, dtype=torch.uint32, device="cuda") From a4d753295ee5662056bdfd1b00fa242071ac7125 Mon Sep 17 00:00:00 2001 From: Lucas Kabela Date: Wed, 16 Jul 2025 23:31:06 +0000 Subject: [PATCH 145/457] [Dynamo][Better Engineering] Add enhanced typing support to `_dynamo/eval_frame.py` (#158276) As part of better engineering week, we would like to improve out type support to improve dev experience in dynamo This PR adds strict typing support to the main entrypoint for dynamo, `eval_frame.py` Running ``` mypy torch/_dynamo/eval_frame.py --linecount-report /tmp/coverage_log ``` | -------- | Lines Unannotated | Lines Total | % lines covered | Funcs Unannotated | Funcs Total | % funcs covered | | -------- | ------- | -------- | ------- | ------- | ------- | ------- | | Main | 623 | 2232 | 27.91% | 19 | 68 | 27.94% | | This PR | 2285 | 2285 | 100.00% | 68 | 68 | 100.00% | | Delta | +1662 | +63 | +72.09% | +49 | 0 | +72.06% | Pull Request resolved: https://github.com/pytorch/pytorch/pull/158276 Approved by: https://github.com/williamwen42 Co-authored-by: William Wen --- torch/_C/_dynamo/eval_frame.pyi | 2 +- torch/_dynamo/decorators.py | 2 +- torch/_dynamo/eval_frame.py | 391 +++++++++++++++----------- torch/_dynamo/package.py | 2 +- torch/_dynamo/repro/after_dynamo.py | 8 +- torch/compiler/__init__.py | 2 +- torch/export/experimental/__init__.py | 3 +- 7 files changed, 233 insertions(+), 177 deletions(-) diff --git a/torch/_C/_dynamo/eval_frame.pyi b/torch/_C/_dynamo/eval_frame.pyi index 129984e6c10d3..6261679dcdef4 100644 --- a/torch/_C/_dynamo/eval_frame.pyi +++ b/torch/_C/_dynamo/eval_frame.pyi @@ -58,7 +58,7 @@ class _PyInterpreterFrame: f_globals: dict[str, object] f_builtins: dict[str, object] f_lasti: int - f_lineo: int + f_lineno: int f_back: types.FrameType # A tuple containing cell objects captured by this frame. closure: tuple[types.CellType] diff --git a/torch/_dynamo/decorators.py b/torch/_dynamo/decorators.py index d49f3c435e56d..5e2e2cb4106c3 100644 --- a/torch/_dynamo/decorators.py +++ b/torch/_dynamo/decorators.py @@ -116,7 +116,7 @@ def skip(fn=None): fn = innermost_fn(fn) assert callable(fn) skip_code(fn.__code__) - fn._torchdynamo_disable = True + fn._torchdynamo_disable = True # type: ignore[attr-defined] return fn diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 858aa402ca72d..e621d7082fe3f 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -1,4 +1,3 @@ -# mypy: allow-untyped-defs # mypy: disable-error-code="method-assign" """ @@ -112,9 +111,20 @@ if TYPE_CHECKING: - from torch._subclasses import fake_tensor + from collections.abc import Iterable, Sequence - from .types import CacheEntry, DynamoCallback + from torch._dynamo.package import CompilePackage + from torch._dynamo.repro.after_dynamo import WrapBackendDebug + from torch._subclasses import fake_tensor + from torch.fx.node import Argument, Node, Target + + from .types import ( + CacheEntry, + DynamoCallback, + DynamoFrameType, + GuardFail, + GuardFilterEntry, + ) log = logging.getLogger(__name__) @@ -134,7 +144,7 @@ class Unset(Enum): unset = Unset.token -def _maybe_set_eval_frame(callback: DynamoCallback): +def _maybe_set_eval_frame(callback: DynamoCallback) -> DynamoCallback: # A wrapper on set_eval_frame that is guarded by a Justknob. # Users can disable torchDynamo by setting the JK to False. if not justknobs_check("pytorch/compiler:enable_compiler_set_eval_frame"): @@ -176,7 +186,7 @@ def _set_stance(stance: DynamoStance) -> DynamoStance: _EXAMPLE_INPUTS: Optional[dict[str, list[Any]]] = None -def get_example_inputs(key) -> list[Any]: +def get_example_inputs(key: str) -> list[Any]: global _EXAMPLE_INPUTS if _EXAMPLE_INPUTS is None: _EXAMPLE_INPUTS = {} @@ -187,7 +197,7 @@ def get_example_inputs(key) -> list[Any]: return _EXAMPLE_INPUTS[key] -def _callback_from_stance(callback): +def _callback_from_stance(callback: DynamoCallback) -> DynamoCallback: if _stance.stance == "default": # force_backend if _stance.backend is not None and callback not in (False, None): @@ -212,7 +222,9 @@ def _callback_from_stance(callback): if callback in (False, None): return callback - def fail_callback(frame, *args, **kwargs): + def fail_callback( + frame: DynamoFrameType, *args: Any, **kwargs: Any + ) -> ConvertFrameReturn: if trace_rules.check(frame.f_code): return ConvertFrameReturn() @@ -239,7 +251,9 @@ def fail_callback(frame, *args, **kwargs): raise RuntimeError(f"invalid torch.compile stance '{_stance}'") -def _create_wrapped_callback(compiler_fn): +def _create_wrapped_callback( + compiler_fn: CompilerFn, +) -> convert_frame.CatchErrorsWrapper: hooks = Hooks() return convert_frame.catch_errors_wrapper( convert_frame.convert_frame( # type: ignore[arg-type] @@ -250,7 +264,7 @@ def _create_wrapped_callback(compiler_fn): ) -def _get_or_add_example_inputs(frame): +def _get_or_add_example_inputs(frame: DynamoFrameType) -> list[Any]: key = frame.f_code.co_filename + str(frame.f_code.co_firstlineno) example_inputs = get_example_inputs(key) @@ -260,8 +274,10 @@ def _get_or_add_example_inputs(frame): return example_inputs -def _create_delayed_compile_callback(callback, stance): - def callback_fn(*args, **kwargs): +def _create_delayed_compile_callback( + callback: DynamoCallback, stance: str +) -> Callable[..., Any]: + def callback_fn(*args: Any, **kwargs: Any) -> convert_frame.ConvertFrameReturn: frame = args[0] example_inputs = _get_or_add_example_inputs(frame) @@ -278,7 +294,7 @@ def callback_fn(*args, **kwargs): dynamism = track_dynamism_across_examples(example_inputs) code_context.get_context(frame.f_code)["dynamism"] = dynamism - compiler_fn = callback._torchdynamo_orig_backend._torchdynamo_orig_backend + compiler_fn = callback._torchdynamo_orig_backend._torchdynamo_orig_backend # type: ignore[union-attr] return _create_wrapped_callback(compiler_fn)(*args, **kwargs) # to prevent cache miss due to different backend @@ -287,11 +303,11 @@ def callback_fn(*args, **kwargs): return callback_fn -def _is_skip_guard_eval_unsafe_stance(): +def _is_skip_guard_eval_unsafe_stance() -> bool: return _stance.skip_guard_eval_unsafe -def _reset_guarded_backend_cache(): +def _reset_guarded_backend_cache() -> None: global cached_backends for backend in cached_backends.values(): if hasattr(backend, "reset"): @@ -339,7 +355,7 @@ class OptimizedModule(torch.nn.Module): "_super_module_initialized", } - def __init__(self, mod: torch.nn.Module, dynamo_ctx) -> None: + def __init__(self, mod: torch.nn.Module, dynamo_ctx: _TorchDynamoContext) -> None: # NOTE: this must go first, because attribute reads/writes of `self` # uses `_orig_mod`, and sometimes users override `Module.__init__` to # do attribute reads/writes on `self`. @@ -357,7 +373,7 @@ def __init__(self, mod: torch.nn.Module, dynamo_ctx) -> None: self._initialize() self.training = self._orig_mod.training - def _initialize(self): + def _initialize(self) -> None: # Do this stuff in constructor to lower overhead slightly if isinstance(self.dynamo_ctx, DisableContext): # No need to check trace rules @@ -381,7 +397,7 @@ def _initialize(self): self._forward = self.forward self.forward = self._call_lazy_check - def __call__(self, *args, **kwargs): + def __call__(self, *args: Any, **kwargs: Any) -> Any: if torch.nn.modules.module._has_any_global_hook(): warnings.warn( "Using `torch.compile(module)` when there are global hooks on " @@ -394,37 +410,39 @@ def __call__(self, *args, **kwargs): ) return super().__call__(*args, **kwargs) - def __reduce__(self): + def __reduce__( + self, + ) -> tuple[type[OptimizedModule], tuple[torch.nn.Module, _TorchDynamoContext]]: return (self.__class__, (self._orig_mod, self.dynamo_ctx)) - def __getstate__(self): + def __getstate__(self) -> dict[str, Any]: state = dict(self.__dict__) state.pop("forward", None) state.pop("__call__", None) return state - def __setstate__(self, state): + def __setstate__(self, state: dict[str, Any]) -> None: self.__dict__ = state self._initialize() @property - def training(self): + def training(self) -> bool: return self._orig_mod.training @training.setter - def training(self, value): + def training(self, value: bool) -> None: # Ignore the `training` mutation in `super().__init__()`, since that's # setting the default on `nn.Module`, but we are mirroring the # `training` attr in `self._orig_mod`. if self._super_module_initialized: self._orig_mod.training = value - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: if name == "_orig_mod": return self._modules["_orig_mod"] return getattr(self._orig_mod, name) - def __setattr__(self, name, val) -> None: + def __setattr__(self, name: str, val: Any) -> None: # Allow patching over class attributes if hasattr(type(self), name): return super().__setattr__(name, val) @@ -433,7 +451,7 @@ def __setattr__(self, name, val) -> None: return super().__setattr__(name, val) return setattr(self._orig_mod, name, val) - def __delattr__(self, name): + def __delattr__(self, name: str) -> None: # This mirrors `__setattr__` if hasattr(type(self), name): return super().__delattr__(name) @@ -442,7 +460,7 @@ def __delattr__(self, name): return super().__delattr__(name) return delattr(self._orig_mod, name) - def _call_lazy_check(self, *args, **kwargs): + def _call_lazy_check(self, *args: Any, **kwargs: Any) -> Any: if ( hasattr(self._orig_mod, "_initialize_hook") and hasattr(self._orig_mod, "_infer_parameters") @@ -455,14 +473,14 @@ def _call_lazy_check(self, *args, **kwargs): self._orig_mod._infer_parameters(self._orig_mod, args, kwargs) return self._forward(*args, **kwargs) - def __dir__(self): + def __dir__(self) -> list[str]: orig_mod_attrs = self._orig_mod.__dir__() return orig_mod_attrs + [ attr for attr in super().__dir__() if attr not in orig_mod_attrs ] -def remove_from_cache(f): +def remove_from_cache(f: Any) -> None: """ Make sure f.__code__ is not cached to force a recompile """ @@ -479,15 +497,17 @@ def remove_from_cache(f): log.warning("could not determine __code__ for %s", f) -def nothing(): +def nothing() -> None: pass -def always_false(): +def always_false() -> bool: return False -def innermost_fn(fn, unaltered_fn_attr="_torchdynamo_orig_callable"): +def innermost_fn( + fn: Callable[..., Any], unaltered_fn_attr: str = "_torchdynamo_orig_callable" +) -> Callable[..., Any]: """ In case of nesting of _TorchDynamoContext calls, find the innermost function. TorchDynamo caches on fn.__code__ object, so its necessary to find @@ -502,7 +522,7 @@ def innermost_fn(fn, unaltered_fn_attr="_torchdynamo_orig_callable"): return unaltered_fn -def make_set_enable_dynamic(enable: bool): +def make_set_enable_dynamic(enable: bool) -> Any: assert isinstance(enable, bool) if enable: # Assume everything is dynamic by default @@ -524,12 +544,12 @@ class DynamoTLS(threading.local): dynamo_tls = DynamoTLS() -def clear_dynamo_tls(): +def clear_dynamo_tls() -> None: dynamo_tls.traced_frame_infos.clear() @atexit.register -def _log_traced_frames(): +def _log_traced_frames() -> None: """ At program exit, log all of the frames Dynamo has attempted to trace from, excluding the continuation frames generated by Dynamo. @@ -540,7 +560,7 @@ def _log_traced_frames(): log.info(msg) -def guard_collectives_hook(guard_eval_result): +def guard_collectives_hook(guard_eval_result: bool) -> bool: import torch.distributed as dist from torch._dynamo.utils import dynamo_timed @@ -568,16 +588,18 @@ class _TorchDynamoContext: def __init__( self, callback: DynamoCallback, - on_enter=nothing, - backend_ctx_ctor=null_context, - patch_fn=nothing, - first_ctx=False, + on_enter: Callable[[], Any] = nothing, + backend_ctx_ctor: Callable[ + [], contextlib.AbstractContextManager[Any] + ] = null_context, + patch_fn: Callable[[], Any] = nothing, + first_ctx: bool = False, *, - error_on_graph_break=False, - export=False, - dynamic=None, - compiler_config=None, - package=None, + error_on_graph_break: bool = False, + export: bool = False, + dynamic: Optional[bool] = None, + compiler_config: Optional[Any] = None, + package: Optional[CompilePackage] = None, ) -> None: super().__init__() assert callable(callback) or callback is False or callback is None @@ -595,15 +617,15 @@ def __init__( patch_fn() # Save the backends so that we can reset them during torch._dynamo.reset - backend = innermost_fn(callback, unaltered_fn_attr="_torchdynamo_orig_backend") - cached_backends.setdefault(id(backend), backend) + backend = innermost_fn(callback, unaltered_fn_attr="_torchdynamo_orig_backend") # type: ignore[arg-type] + cached_backends.setdefault(id(backend), backend) # type: ignore[arg-type] if dynamic is not None: self.enter_exit_hooks.append(make_set_enable_dynamic(dynamic)) if on_enter is not nothing: # this case is not common - def call_on_enter(): + def call_on_enter() -> Callable[[], None]: on_enter() return nothing @@ -611,14 +633,14 @@ def call_on_enter(): if backend_ctx_ctor is not contextlib.nullcontext: # this case is not common - def call_backend_ctx(): + def call_backend_ctx() -> functools.partial[Optional[bool]]: ctx = backend_ctx_ctor() ctx.__enter__() return functools.partial(ctx.__exit__, None, None, None) self.enter_exit_hooks.append(call_backend_ctx) - def __enter__(self): + def __enter__(self) -> None: if config.raise_on_ctx_manager_usage: raise RuntimeError( "torch._dynamo.optimize(...) is used with a context manager. " @@ -632,7 +654,12 @@ def __enter__(self): ) _maybe_set_eval_frame(_callback_from_stance(self.callback)) - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[types.TracebackType], + ) -> Optional[bool]: assert self.prior is not unset set_eval_frame(None) set_skip_guard_eval_unsafe(self.prior_skip_guard_eval_unsafe) @@ -641,10 +668,11 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.cleanup_fns.clear() _maybe_set_eval_frame(_callback_from_stance(self.prior)) self.prior = unset + return None - def __call__(self, fn): + def __call__(self, fn: Any) -> Any: # public api for compiler config/options - def get_compiler_config(): + def get_compiler_config() -> Any: return self.compiler_config from .package import DynamoCache @@ -721,19 +749,18 @@ def get_compiler_config(): # call to a builtin without a frame for us to capture fn = external_utils.wrap_inline(fn) - def do_nothing(*arg, **kwargs): + def do_nothing(*arg: Any, **kwargs: Any) -> None: pass + callback: Callable[..., Any] = do_nothing if hasattr(self, "callback"): - callback = self.callback - else: - callback = do_nothing + callback = self.callback # type: ignore[assignment] is_jit_tracing = torch._C._is_tracing is_fx_tracing = torch.fx._symbolic_trace.is_fx_tracing @functools.wraps(fn) - def compile_wrapper(*args, **kwargs): + def compile_wrapper(*args: Any, **kwargs: Any) -> Any: prior = set_eval_frame(None) try: if is_fx_tracing(): @@ -861,20 +888,20 @@ def compile_wrapper(*args, **kwargs): class OptimizeContext(_TorchDynamoContext): def __init__( self, - callback, - backend_ctx_ctor, - first_ctx=False, + callback: DynamoCallback, + backend_ctx_ctor: Callable[[], contextlib.AbstractContextManager[Any]], + first_ctx: bool = False, *, - error_on_graph_break=False, - export=False, - dynamic=None, - compiler_config=None, + error_on_graph_break: bool = False, + export: bool = False, + dynamic: Optional[bool] = None, + compiler_config: Optional[Any] = None, rebuild_ctx: Optional[ Callable[[], Union[OptimizeContext, _NullDecorator]] ] = None, - package=None, + package: Optional[CompilePackage] = None, ) -> None: - def on_enter(): + def on_enter() -> None: install_generation_tagging_init() super().__init__( @@ -895,7 +922,7 @@ def on_enter(): if _dynamic is None: _dynamic = not torch._dynamo.config.assume_static_by_default - def call_compiled_autograd(): + def call_compiled_autograd() -> functools.partial[Optional[bool]]: assert rebuild_ctx is not None compiler_fn = rebuild_ctx() ctx = torch._dynamo.compiled_autograd._enable( @@ -906,7 +933,9 @@ def call_compiled_autograd(): self.enter_exit_hooks.append(call_compiled_autograd) - def __reduce__(self): + def __reduce__( + self, + ) -> tuple[type[OptimizeContext], tuple[Any, ...], dict[str, Any]]: return ( self.__class__, (self.callback, self._backend_ctx_ctor, self.first_ctx), @@ -921,12 +950,12 @@ def __reduce__(self): class RunOnlyContext(_TorchDynamoContext): def __init__(self) -> None: # cudagraph trees relies on generation increment - def on_enter(): + def on_enter() -> None: torch._dynamo.mutation_guard.GenerationTracker.generation += 1 super().__init__(callback=False, on_enter=on_enter) - def __reduce__(self): + def __reduce__(self) -> tuple[type[RunOnlyContext], tuple[Any, ...]]: return (self.__class__, ()) @@ -936,7 +965,7 @@ def __init__(self, msg: Optional[str] = None, wrapping: bool = True) -> None: self.msg = msg self.wrapping = wrapping - def __call__(self, fn): + def __call__(self, fn: Callable[..., Any]) -> Callable[..., Any]: # Earlier this code was in the base class _TorchDynamoContext. But we # moved it here to have better code organization. For disable, we just # want the callback to be None. We don't have to check trace_rules or @@ -967,7 +996,7 @@ def __call__(self, fn): f"A callable function is expected, but {type(fn)} is provided." ) - def _fn(*args, **kwargs): + def _fn(*args: Any, **kwargs: Any) -> Any: prior = set_eval_frame(None) try: _maybe_set_eval_frame(_callback_from_stance(self.callback)) @@ -995,21 +1024,23 @@ def _fn(*args, **kwargs): return _fn - def __reduce__(self): + def __reduce__(self) -> tuple[type[DisableContext], tuple[Any, ...]]: return (self.__class__, ()) def _optimize_catch_errors( - compile_fn, + compile_fn: convert_frame.ConvertFrameProtocol, hooks: Hooks, - backend_ctx_ctor=null_context, - error_on_graph_break=False, - export=False, - dynamic=None, - compiler_config=None, - rebuild_ctx=None, - package=None, -): + backend_ctx_ctor: Callable[ + [], contextlib.AbstractContextManager[Any] + ] = null_context, + error_on_graph_break: bool = False, + export: bool = False, + dynamic: Optional[bool] = None, + compiler_config: Optional[Any] = None, + rebuild_ctx: Optional[Callable[[], Union[OptimizeContext, _NullDecorator]]] = None, + package: Optional[CompilePackage] = None, +) -> OptimizeContext: return OptimizeContext( convert_frame.catch_errors_wrapper(compile_fn, hooks), backend_ctx_ctor=backend_ctx_ctor, @@ -1023,11 +1054,17 @@ def _optimize_catch_errors( ) -def get_compiler_fn(compiler_fn): +def get_compiler_fn( + compiler_fn: Union[str, Callable[..., Any], None], +) -> WrapBackendDebug: from .repro.after_dynamo import wrap_backend_debug - if hasattr(compiler_fn, "compiler_name"): - compiler_str = compiler_fn.compiler_name + if compiler_fn is None: + # Special case None to avoid crashing in hasattr + compiler_str = None + elif hasattr(compiler_fn, "compiler_name"): + compiler_str = compiler_fn.compiler_name # type: ignore[union-attr] + assert isinstance(compiler_str, str) elif isinstance(compiler_fn, str): compiler_str = compiler_fn else: @@ -1037,14 +1074,14 @@ def get_compiler_fn(compiler_fn): class _NullDecorator(contextlib.nullcontext): # type: ignore[type-arg] - def __call__(self, fn): + def __call__(self, fn: Callable[..., Any]) -> Callable[..., Any]: assert callable(fn), ( f"A callable function is expected, but {type(fn)} is provided." ) return fn -def check_if_dynamo_supported(): +def check_if_dynamo_supported() -> None: if sys.version_info >= (3, 14): raise RuntimeError("Python 3.14+ not yet supported for torch.compile") elif sysconfig.get_config_var("Py_GIL_DISABLED") == 1 and sys.version_info < ( @@ -1058,7 +1095,7 @@ def check_if_dynamo_supported(): ) -def is_dynamo_supported(): +def is_dynamo_supported() -> bool: try: check_if_dynamo_supported() return True @@ -1066,11 +1103,11 @@ def is_dynamo_supported(): return False -def check_if_inductor_supported(): +def check_if_inductor_supported() -> None: check_if_dynamo_supported() -def is_inductor_supported(): +def is_inductor_supported() -> bool: try: check_if_inductor_supported() return True @@ -1078,15 +1115,15 @@ def is_inductor_supported(): return False -def check_for_incompatible_configs(): +def check_for_incompatible_configs() -> None: # Some of the configs should be mutually exclusive assert not (config.suppress_errors and config.fail_on_recompile_limit_hit), ( "Dynamo configs suppress_error and fail_on_recompile_limit_hit can not both be active at the same time." ) -def optimize(*args, **kwargs): - def rebuild_ctx(): +def optimize(*args: Any, **kwargs: Any) -> Union[OptimizeContext, _NullDecorator]: + def rebuild_ctx() -> Union[OptimizeContext, _NullDecorator]: ca_kwargs_override = config.compiled_autograd_kwargs_override if ca_kwargs_override: # NOTE: The process of translating other `torch.compile` kwargs to `torch._dynamo.optimize` kwargs @@ -1102,15 +1139,15 @@ def rebuild_ctx(): def _optimize( rebuild_ctx: Callable[[], Union[OptimizeContext, _NullDecorator]], - backend="inductor", + backend: Union[str, Callable[..., Any]] = "inductor", *, - nopython=False, - guard_export_fn=None, - guard_fail_fn=None, - guard_filter_fn=None, - disable=False, - dynamic=None, - package=None, + nopython: bool = False, + guard_export_fn: Optional[Callable[[_guards.GuardsSet], None]] = None, + guard_fail_fn: Optional[Callable[[GuardFail], None]] = None, + guard_filter_fn: Optional[Callable[[list[GuardFilterEntry]], list[bool]]] = None, + disable: bool = False, + dynamic: Optional[bool] = None, + package: Optional[CompilePackage] = None, ) -> Union[OptimizeContext, _NullDecorator]: """ The main entrypoint of TorchDynamo. Do graph capture and call @@ -1195,8 +1232,10 @@ def toy_example(a, b): ... # TODO(voz): Consider making "explain" output alongside a run / part of a run @patch("torch._dynamo.symbolic_convert.explain", True) -def explain(f, *extra_args, **extra_kwargs): - def inner(*args, **kwargs): +def explain(f: Callable[..., Any], *extra_args: Any, **extra_kwargs: Any) -> Any: + from .backends.debugging import ExplainOutput + + def inner(*args: Any, **kwargs: Any) -> ExplainOutput: # TODO(voz): Do we want a decorator for this? from . import reset # type: ignore[attr-defined] @@ -1209,8 +1248,8 @@ def inner(*args, **kwargs): out_guards: list[_guards.Guard] = [] def dynamo_graph_accumulating_compiler( - gm: torch.fx.GraphModule, example_inputs - ): + gm: torch.fx.GraphModule, example_inputs: Any + ) -> Callable[..., Any]: from .backends.debugging import _explain_graph_detail nonlocal graphs @@ -1224,7 +1263,7 @@ def dynamo_graph_accumulating_compiler( return gm.forward - def guard_export_print(guards): + def guard_export_print(guards: Iterable[_guards.Guard]) -> None: nonlocal out_guards out_guards.extend(guards) @@ -1242,7 +1281,6 @@ def guard_export_print(guards): # TODO(voz): Do we want a decorator for this? reset() - from .backends.debugging import ExplainOutput return ExplainOutput( graphs, @@ -1272,9 +1310,9 @@ class FlattenInputOutputSignature(torch.fx.Transformer): def __init__( self, m: torch.fx.GraphModule, - flat_args: tuple[Any], + flat_args: list[Any], matched_input_elements_positions: list[int], - flat_results: list[Any], + flat_results: Sequence[Any], matched_output_elements_positions: list[int], example_fake_inputs: list[torch.Tensor], flat_args_dynamic_dims: list[set[int]], @@ -1322,7 +1360,9 @@ def __init__( self.matched_output_elements_positions = matched_output_elements_positions self.flat_results = flat_results - def placeholder(self, target, args, kwargs): + def placeholder( + self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any] + ) -> Any: arg = next(self.old_args_gen) if "val" in self.current_node.meta: arg.node.meta["val"] = self.current_node.meta["val"] @@ -1337,9 +1377,11 @@ def placeholder(self, target, args, kwargs): ] return arg - def output(self, target, args, kwargs): + def output( + self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any] + ) -> Any: dynamo_result_flat = args[0] - lookup = [*dynamo_result_flat, *self.new_args] + lookup = [*dynamo_result_flat, *self.new_args] # type: ignore[misc] new_results_flat = [] for i in range(len(self.flat_results)): if self.matched_output_elements_positions[i] is not None: @@ -1352,7 +1394,7 @@ def output(self, target, args, kwargs): new_results_flat.append(const_val) return super().output(target, (new_results_flat,), {}) - def run_node(self, n): + def run_node(self, n: Node) -> Any: self.current_node = n result_proxy = super().run_node(n) if "val" in self.current_node.meta: @@ -1372,7 +1414,7 @@ def run_node(self, n): ) return result_proxy - def transform(self): + def transform(self) -> torch.fx.GraphModule: result_gm = super().transform() if "dynamo_flat_name_to_original_fqn" in self.module.meta: # type: ignore[operator] result_gm.meta["dynamo_flat_name_to_original_fqn"] = self.module.meta[ # type: ignore[index] @@ -1391,15 +1433,17 @@ class ExportResult(NamedTuple): # NOTE: this function only supports graphs created by Dynamo's OutputGraph module -def check_signature_rewritable(graph): +def check_signature_rewritable(graph: torch.fx.GraphModule) -> None: input_errors = [] for node in graph.graph.find_nodes(op="placeholder"): # set in OutputGraph._call_user_compiler assert hasattr(node, "_dynamo_source") assert hasattr(graph, "_source_to_user_stacks") - source = node._dynamo_source - user_stacks = graph._source_to_user_stacks.get(source) + # NOTE: We can safely ignore these type warnings if and only if + # the function is made from OutputGraph (checked in the assertions) + source = node._dynamo_source # type: ignore[attr-defined] + user_stacks = graph._source_to_user_stacks.get(source) # type: ignore[operator, union-attr] if user_stacks is None: continue assert len(user_stacks) > 0 @@ -1436,20 +1480,22 @@ def check_signature_rewritable(graph): def rewrite_signature( - f_sig, - graph, - fake_mode, - flat_args, - in_spec, - example_fake_inputs, - graph_captured_input, - graph_captured_output, - dynamo_traced_result, - flat_args_dynamic_dims, -): + f_sig: inspect.Signature, + graph: torch.fx.GraphModule, + fake_mode: Optional[fake_tensor.FakeTensorMode], + flat_args: list[Any], + in_spec: pytree.TreeSpec, + example_fake_inputs: list[Any], + graph_captured_input: Iterable[Any], + graph_captured_output: Optional[Iterable[Any]], + dynamo_traced_result: Any, + flat_args_dynamic_dims: list[set[int]], +) -> torch.fx.GraphModule: orig_args, orig_kwargs = pytree.tree_unflatten(flat_args, in_spec) - def check_user_input_output(flat_values, error_type): + def check_user_input_output( + flat_values: list[Any], error_type: UserErrorType + ) -> None: supported_types = [ torch.Tensor, torch.SymInt, @@ -1459,7 +1505,7 @@ def check_user_input_output(flat_values, error_type): _IntWrapper, ] + list(common_constant_types) - def is_supported_type(val): + def is_supported_type(val: Any) -> bool: return isinstance(val, tuple(supported_types)) value_type = "input" if error_type == UserErrorType.INVALID_INPUT else "output" @@ -1485,7 +1531,7 @@ def is_supported_type(val): flat_results_traced, out_spec_traced = pytree.tree_flatten(dynamo_traced_result) check_user_input_output(flat_results_traced, UserErrorType.INVALID_OUTPUT) - def check_optional_input_and_error(f_sig: inspect.Signature): + def check_optional_input_and_error(f_sig: inspect.Signature) -> None: # Check if function has optional input. for name, param in f_sig.parameters.items(): if param.default is not inspect.Parameter.empty: @@ -1501,7 +1547,9 @@ def check_optional_input_and_error(f_sig: inspect.Signature): case_name="optional_input", ) - def produce_matching(debug_type, sources, candidates): + def produce_matching( + debug_type: str, sources: Iterable[Any], candidates: Iterable[Any] + ) -> list[Optional[int]]: matched_elements_positions: list[Optional[int]] = [] dict_of_source_vals = {} for i, val in enumerate(sources): @@ -1534,17 +1582,19 @@ def produce_matching(debug_type, sources, candidates): new_graph = FlattenInputOutputSignature( graph, flat_args, - matched_input_elements_positions, + matched_input_elements_positions, # type: ignore[arg-type] flat_results_traced, - matched_output_elements_positions, + matched_output_elements_positions, # type: ignore[arg-type] example_fake_inputs, flat_args_dynamic_dims, fake_mode, ).transform() # Make dynamo graph to have same input/output spec as user code - def argument_names(f_sig, args, kwargs) -> list[str]: - def signature_to_fullargspec(sig: inspect.Signature): + def argument_names( + f_sig: inspect.Signature, args: list[Any], kwargs: dict[str, Any] + ) -> list[str]: + def signature_to_fullargspec(sig: inspect.Signature) -> inspect.FullArgSpec: # Get a list of Parameter objects from the Signature object params = list(sig.parameters.values()) # Separate positional arguments, keyword-only arguments and varargs/varkw @@ -1638,7 +1688,7 @@ def signature_to_fullargspec(sig: inspect.Signature): def export( f: Callable[..., Any], - *extra_args, + *extra_args: Any, aten_graph: bool = False, pre_dispatch: bool = False, decomposition_table: Optional[ @@ -1654,8 +1704,8 @@ def export( allow_complex_guards_as_runtime_asserts: bool = False, _log_export_usage: bool = True, constraints: Optional[list[Constraint]] = None, - **extra_kwargs, -) -> Callable[..., ExportResult]: + **extra_kwargs: Any, +) -> Callable[[tuple[Any, Any]], ExportResult]: """ Export an input function f to a format that can be executed outside of PyTorch using the FX graph. @@ -1718,7 +1768,7 @@ def export( _assume_static_by_default = assume_static_by_default _constraints = constraints - def inner(*args, **kwargs): + def inner(*args: Any, **kwargs: Any) -> ExportResult: if not _constraints: combined_args = _combine_args(_f, args, kwargs) constraints = _process_dynamic_shapes(combined_args, dynamic_shapes) @@ -1738,7 +1788,7 @@ def inner(*args, **kwargs): assert aten_graph, "pre_dispatch=True can only be used when aten_graph=True" f = innermost_fn(f) call_to_inspect = f.forward if isinstance(f, torch.nn.Module) else f - original_signature = inspect.signature(call_to_inspect) + original_signature = inspect.signature(call_to_inspect) # type: ignore[arg-type] graph = None out_guards = None graph_captured_input = None @@ -1746,18 +1796,18 @@ def inner(*args, **kwargs): fake_mode = None result_traced = None - def guard_export_print(guards: _guards.GuardsSet): + def guard_export_print(guards: _guards.GuardsSet) -> None: nonlocal out_guards assert out_guards is None, ( "whole graph export entails exactly one guard export" ) out_guards = guards - example_inputs = [] + example_inputs: list[Any] = [] def dynamo_normalization_capturing_compiler( - gm: torch.fx.GraphModule, inner_example_inputs - ): + gm: torch.fx.GraphModule, inner_example_inputs: list[Any] + ) -> Callable[..., Any]: nonlocal graph assert graph is None, ( "Tried to emit a second graph during export. Tracing through 'f' must produce a single graph." @@ -1773,7 +1823,7 @@ def dynamo_normalization_capturing_compiler( fake_mode = _guards.detect_fake_mode() example_inputs = inner_example_inputs - def result_capturing_wrapper(*graph_inputs): + def result_capturing_wrapper(*graph_inputs: Any) -> Any: nonlocal graph_captured_result nonlocal graph_captured_input @@ -1815,7 +1865,14 @@ def result_capturing_wrapper(*graph_inputs): value, static_shapes=True ) - def fakify_with_ambient(path, t): + from torch._export.non_strict_utils import ( + key_path_to_source, + KeyPath, + ) + + def fakify_with_ambient( + path: KeyPath, t: Union[torch.Tensor, _IntWrapper, Any] + ) -> Any: if isinstance(t, torch.Tensor): return ambient_fake_mode.from_tensor(t, static_shapes=True) elif isinstance(t, _IntWrapper): @@ -1828,10 +1885,6 @@ def fakify_with_ambient(path, t): _DimHintType.AUTO, ) ): # type: ignore[union-attr] - from torch._export.non_strict_utils import ( - key_path_to_source, - ) - source = key_path_to_source(path) symint = ambient_fake_mode.shape_env.create_unspecified_symint_and_symbol( # type: ignore[union-attr] t.val, source, DimDynamic.DYNAMIC @@ -1989,7 +2042,7 @@ def fakify_with_ambient(path, t): if aten_graph: # Running graph with interpreter is needed for propagating the stack_trace - def graph_with_interpreter(*args): + def graph_with_interpreter(*args: Any) -> Any: with torch.fx.traceback.preserve_node_meta(): return torch.fx.Interpreter(graph).run(*args) # type: ignore[arg-type] @@ -2039,12 +2092,12 @@ def graph_with_interpreter(*args): flat_args, in_spec, example_fake_inputs, - graph_captured_input, + graph_captured_input, # type: ignore[arg-type] graph_captured_result, result_traced, # type: ignore[possibly-undefined] flat_args_dynamic_dims, ) - return ExportResult(graph, out_guards) # type: ignore[arg-type] + return ExportResult(graph, out_guards) if extra_args or extra_kwargs: warnings.warn( @@ -2054,19 +2107,19 @@ def graph_with_interpreter(*args): FutureWarning, stacklevel=2, ) - return inner(*extra_args, **extra_kwargs) + return inner(*extra_args, **extra_kwargs) # type: ignore[return-value] else: return inner -def optimize_assert(*args, **kwargs): +def optimize_assert(*args: Any, **kwargs: Any) -> OptimizeContext: if "rebuild_ctx" in kwargs and kwargs["rebuild_ctx"] is not None: # called from optimize rebuild_ctx = kwargs["rebuild_ctx"] del kwargs["rebuild_ctx"] else: - def rebuild_ctx(): + def rebuild_ctx() -> OptimizeContext: return optimize_assert(*args, **kwargs) return _optimize_assert(rebuild_ctx, *args, **kwargs) @@ -2074,14 +2127,14 @@ def rebuild_ctx(): def _optimize_assert( rebuild_ctx: Callable[[], OptimizeContext], - backend, + backend: Union[str, Callable[..., Any], None], *, - hooks=Hooks(None, None, None), - export=False, - export_constraints=None, - dynamic=None, - package=None, -): + hooks: Hooks = Hooks(None, None, None), + export: bool = False, + export_constraints: Optional[Any] = None, + dynamic: Optional[bool] = None, + package: Optional[CompilePackage] = None, +) -> OptimizeContext: """ The same as `torch._dynamo.optimize(backend, nopython=True)`, but ignores symbolic_convert.error_on_graph_break setting. @@ -2123,7 +2176,7 @@ def _optimize_assert( class TorchPatcher: @staticmethod @functools.cache - def patch(): + def patch() -> None: # A better way to disable the following would be decorate the source # functions with @torch._disable_dynamo. However, this causes issues # with torch.deploy internally. @@ -2216,8 +2269,10 @@ def patch(): ) @staticmethod - def suppress_torch_distributed_warnings(fn): - def inner_fn(*args, **kwargs): + def suppress_torch_distributed_warnings( + fn: Callable[..., Any], + ) -> Callable[..., Any]: + def inner_fn(*args: Any, **kwargs: Any) -> Any: warnings.filterwarnings( "ignore", category=UserWarning, module="torch.distributed" ) @@ -2226,7 +2281,7 @@ def inner_fn(*args, **kwargs): return inner_fn -def skip_code(code: types.CodeType): +def skip_code(code: types.CodeType) -> None: set_code_exec_strategy( code, FrameExecStrategy(FrameAction.SKIP, FrameAction.DEFAULT) ) diff --git a/torch/_dynamo/package.py b/torch/_dynamo/package.py index 2a33a019b6cca..be750d41a1dc9 100644 --- a/torch/_dynamo/package.py +++ b/torch/_dynamo/package.py @@ -219,7 +219,7 @@ def initialize( assert not self._initialized self._inlined_sources = set() - self._innermost_fn = innermost_fn(fn) + self._innermost_fn = innermost_fn(fn) # type: ignore[assignment] assert self._innermost_fn is not None if dynamo is not None: assert isinstance(dynamo, _DynamoCacheEntry) diff --git a/torch/_dynamo/repro/after_dynamo.py b/torch/_dynamo/repro/after_dynamo.py index 80191f2d6cefc..86a33677eb14d 100644 --- a/torch/_dynamo/repro/after_dynamo.py +++ b/torch/_dynamo/repro/after_dynamo.py @@ -27,7 +27,7 @@ import sys import textwrap from importlib import import_module -from typing import Union +from typing import Optional, Union import torch import torch.fx as fx @@ -79,7 +79,7 @@ def _accuracy_fails(gm, example_inputs, compiler_fn): class WrapBackendDebug: - def __init__(self, unconfigured_compiler_fn, compiler_name: str) -> None: + def __init__(self, unconfigured_compiler_fn, compiler_name: Optional[str]) -> None: functools.wraps(unconfigured_compiler_fn)(self) self._torchdynamo_orig_backend = unconfigured_compiler_fn # type: ignore[attr-defined] self._compiler_name = compiler_name @@ -152,7 +152,7 @@ def add_paths(exc): return compiled_gm -def wrap_backend_debug(unconfigured_compiler_fn, compiler_name: str): +def wrap_backend_debug(unconfigured_compiler_fn, compiler_name: Optional[str]): """ A minifier decorator that wraps the TorchDynamo produced Fx graph modules. As opposed to wrap_compiler_debug, this wrapper intercepts at the @@ -455,7 +455,7 @@ def repro_run(options, mod, load_args): if options.accuracy != "": mod.eval() - opt_mod.eval() + opt_mod.eval() # type: ignore[union-attr] with torch.amp.autocast("cuda", enabled=options.autocast): # TODO: disable clone diff --git a/torch/compiler/__init__.py b/torch/compiler/__init__.py index 578b56e504e20..e92100e87f384 100644 --- a/torch/compiler/__init__.py +++ b/torch/compiler/__init__.py @@ -355,7 +355,7 @@ def set_enable_guard_collectives(enabled: bool): from torch._dynamo.eval_frame import guard_collectives_hook if enabled: - return set_guard_complete_hook(guard_collectives_hook) is not None + return set_guard_complete_hook(guard_collectives_hook) is not None # type: ignore[arg-type] else: return set_guard_complete_hook(None) is not None diff --git a/torch/export/experimental/__init__.py b/torch/export/experimental/__init__.py index b1c86abc69ce7..b34bef61b508b 100644 --- a/torch/export/experimental/__init__.py +++ b/torch/export/experimental/__init__.py @@ -316,7 +316,8 @@ def _exporter_context(*args, **kwargs): # type: ignore[no-untyped-def] if isinstance(fn, torch.nn.Module): _exporter_context = torch._dynamo.eval_frame.OptimizedModule( # type: ignore[assignment] # noqa: F811 - fn, lambda _: _exporter_context + fn, + lambda _: _exporter_context, # type: ignore[arg-type] ) def _define_overload( From 1d584761622ff6e5519c5e3dbbb62a21b89ffe8a Mon Sep 17 00:00:00 2001 From: Howard Huang Date: Tue, 15 Jul 2025 11:42:08 -0700 Subject: [PATCH 146/457] [PP] Add eval() API to schedule (#157795) These change add an `eval()` API to PP schedules ## Context Currently, you can run "Forward only" for a schedule in two ways: 1. Use a custom schedule `_ScheduleForwardOnly` 2. Do not pass in `loss_fn` in schedule constructor, and no backward computations will be executed. However, this is still limiting because we may want to run forward through the pipeline / calculate the loss, but without backward, e.g. during validation. These changes allow for this. ```python if self.rank == 0: schedule.eval(x) elif self.rank == self.world_size - 1: losses = [] schedule.eval(target=target, losses=losses) else: schedule.eval() ``` TODO: - in later PRs, we will deprecate the `_ScheduleForwardOnly` Pull Request resolved: https://github.com/pytorch/pytorch/pull/157795 Approved by: https://github.com/wconstab --- .../pipelining/test_schedule_multiproc.py | 132 +++++++++++++++++- torch/distributed/pipelining/schedules.py | 59 ++++++-- torch/distributed/pipelining/stage.py | 7 +- 3 files changed, 181 insertions(+), 17 deletions(-) diff --git a/test/distributed/pipelining/test_schedule_multiproc.py b/test/distributed/pipelining/test_schedule_multiproc.py index 50aa9ff21ba08..41967a0e58249 100644 --- a/test/distributed/pipelining/test_schedule_multiproc.py +++ b/test/distributed/pipelining/test_schedule_multiproc.py @@ -110,6 +110,137 @@ def test_forward_only(self, ScheduleClass): torch.testing.assert_close(x_clone, out) + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @parametrize( + "ScheduleClass", + [ScheduleGPipe, Schedule1F1B, ScheduleInterleaved1F1B, ScheduleLoopedBFS], + ) + def test_eval_inference_mode(self, ScheduleClass): + if ScheduleClass in [ScheduleInterleaved1F1B, ScheduleLoopedBFS]: + # Multi-stage schedules + stages_per_rank = 2 + n_stages = stages_per_rank * self.world_size + mod = MultiMLP(d_hid, n_layers=n_stages) + mod.to(self.device) + + x = torch.randn(batch_size, d_hid, device=self.device) + target = torch.randn(batch_size, d_hid, device=self.device) + loss_fn = torch.nn.MSELoss(reduction="sum") + + chunks = 4 + stage_indices = [ + self.rank + i * self.world_size for i in range(stages_per_rank) + ] + submod_names = [f"layers.{i}" for i in stage_indices] + stage_modules = [ + mod.get_submodule(submod_name) for submod_name in submod_names + ] + stages = [ + PipelineStage( + stage_module, + stage_idx, + n_stages, + self.device, + ) + for stage_module, stage_idx in zip(stage_modules, stage_indices) + ] + + # Test with eval() method for inference + schedule = ScheduleClass(stages, chunks, loss_fn=loss_fn, scale_grads=False) + + # Clear gradients + for stage_module in stage_modules: + stage_module.zero_grad() + + if self.rank == 0: + schedule.eval(x) + elif self.rank == self.world_size - 1: + losses = [] + schedule.eval(target=target, losses=losses) + else: + schedule.eval() + + # Check that gradients were NOT computed during eval + grad_computed_eval = False + for stage_module in stage_modules: + for param in stage_module.parameters(): + if param.grad is not None: + grad_computed_eval = True + break + if grad_computed_eval: + break + + # Verify that gradients were not computed during eval + self.assertFalse( + grad_computed_eval, + "Gradients should not be computed during eval()", + ) + + # Verify that losses are still computed during eval + if self.rank == self.world_size - 1: + self.assertTrue( + len(losses) > 0, "Losses should be computed during eval()" + ) + else: + # Single-stage schedules + mod = MultiMLP(d_hid, n_layers=self.world_size) + mod.to(self.device) + + x = torch.randn(batch_size, d_hid, device=self.device) + target = torch.randn(batch_size, d_hid, device=self.device) + loss_fn = torch.nn.MSELoss(reduction="sum") + + chunks = 4 + x_mb = x.chunk(chunks)[0] + + # Create a pipeline + split_spec = mod.split_spec if hasattr(mod, "split_spec") else None + pipe = pipeline( + mod, + mb_args=(x_mb,), + split_spec=split_spec, + ) + + stage = pipe.build_stage( + self.rank, + self.device, + ) + + # Test with eval() method for inference + schedule = ScheduleClass(stage, chunks, loss_fn=loss_fn, scale_grads=False) + + # Get stage module for gradient checking + stage_module = pipe.get_stage_module(self.rank) + stage_module.zero_grad() + + if self.rank == 0: + schedule.eval(x) + elif self.rank == self.world_size - 1: + losses = [] + schedule.eval(target=target, losses=losses) + else: + schedule.eval() + + # Check that gradients were NOT computed during eval + grad_computed_eval = False + for param in stage_module.parameters(): + if param.grad is not None: + grad_computed_eval = True + break + + # Verify that gradients were not computed during eval + self.assertFalse( + grad_computed_eval, + "Gradients should not be computed during eval()", + ) + + # Verify that losses are still computed during eval + if self.rank == self.world_size - 1: + self.assertTrue( + len(losses) > 0, "Losses should be computed during eval()" + ) + @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B]) @@ -1048,6 +1179,5 @@ def test_zero_bubble_with_model_kwargs(self, ScheduleClass): instantiate_parametrized_tests(ScheduleTest) - if __name__ == "__main__": run_tests() diff --git a/torch/distributed/pipelining/schedules.py b/torch/distributed/pipelining/schedules.py index d60e9bd307813..f4ded5a1f0bcf 100644 --- a/torch/distributed/pipelining/schedules.py +++ b/torch/distributed/pipelining/schedules.py @@ -248,13 +248,13 @@ def __init__( logger.info("Using %s", self.__class__.__name__) def _maybe_compute_loss(self, stage, output, target_mbs, mb_index): - if stage.is_last and self._has_backward: + if stage.is_last and self._loss_fn is not None: loss = self._compute_loss(output, target_mbs[mb_index]) # type: ignore[index] self._internal_losses.append(loss) def _maybe_get_loss(self, stage, mb_index): valid_index = 0 <= mb_index < len(self._internal_losses) - if stage.is_last and self._has_backward and valid_index: + if stage.is_last and self._loss_fn is not None and valid_index: return self._internal_losses[mb_index] elif len(self._internal_losses) != 0 and not valid_index: raise RuntimeError( @@ -319,6 +319,26 @@ def step(self, *args, target=None, losses: Optional[list] = None, **kwargs): """ raise NotImplementedError + def eval(self, *args, target=None, losses: Optional[list] = None, **kwargs): + """ + Run one iteration of the pipeline schedule with *whole-batch* input. + Will chunk the input into microbatches automatically, and go through the + microbatches, calling forward only. + + args: positional arguments to the model (as in non-pipeline case). + kwargs: keyword arguments to the model (as in non-pipeline case). + target: target values for the loss function. + losses: a list to store the losses for each microbatch. + """ + # Save the original has_backward state + original_has_backward = self._has_backward + try: + self._has_backward = False + return self.step(*args, target=target, losses=losses, **kwargs) + finally: + # Restore the original state + self._has_backward = original_has_backward + def _check_inputs( self, arg_mbs: Optional[list] = None, @@ -475,8 +495,6 @@ def __init__( # Self attributes self._stage = stage self._num_stages = stage.num_stages - # Set the same has_backward flag for stage object - self._stage.has_backward = self._has_backward self._stage_initialized = False if n_microbatches < self._num_stages: @@ -506,6 +524,15 @@ def step(self, *args, target=None, losses: Optional[list] = None, **kwargs): target: target for the loss function. losses: a list to store the losses for each microbatch. """ + if not torch.is_grad_enabled(): + raise RuntimeError( + "step() requires gradients to be enabled for backward computation; " + "it should not be used under torch.no_grad() context. " + "Please call eval() instead." + ) + + # Set the same has_backward flag for stage object + self._stage.has_backward = self._has_backward # Clean per iteration self._stage.clear_runtime_states() @@ -650,10 +677,6 @@ def _step_microbatches( for work in fwd_sends_to_wait: _wait_batch_p2p(work) - # No loss function, no need to run backward - if not self._has_backward: - return - # Run backward # Delay send waits bwd_sends_to_wait: list[list[dist.Work]] = [] @@ -681,13 +704,13 @@ def _step_microbatches( grad_scale_factor=self._n_microbatches if self.scale_grads else 1 ) - # Return losses if there is a container passed in - self._update_losses(self._stage, losses) - # Wait for all backward sends to finish for work in bwd_sends_to_wait: _wait_batch_p2p(work) + # Update losses if there is a container passed in + self._update_losses(self._stage, losses) + def _get_pipeline_order(self) -> Optional[dict[int, list[Optional[_Action]]]]: """ Returns the pipeline order for GPipe schedule. @@ -1264,9 +1287,6 @@ def __init__( for stage in self._stages: stage.stage_index_to_group_rank = self.stage_index_to_group_rank - # Set the same has_backward flag for stage object - for stage in self._stages: - stage.has_backward = self._has_backward self._stages_initialized = False # avoid putting a reference to 'self' inside the lambda, it creates a ref cycle @@ -1349,6 +1369,17 @@ def step(self, *args, target=None, losses: Optional[list] = None, **kwargs): target: target for the loss function. losses: a list to store the losses for each microbatch. """ + if not torch.is_grad_enabled(): + raise RuntimeError( + "step() requires gradients to be enabled for backward computation; " + "it should not be used under torch.no_grad() context. " + "Please call eval() instead." + ) + + # Set the same has_backward flag for stage object + for stage in self._stages: + stage.has_backward = self._has_backward + # Clean per iteration for stage in self._stages: stage.clear_runtime_states() diff --git a/torch/distributed/pipelining/stage.py b/torch/distributed/pipelining/stage.py index df229c9832090..e22799545903e 100644 --- a/torch/distributed/pipelining/stage.py +++ b/torch/distributed/pipelining/stage.py @@ -462,11 +462,10 @@ def get_bwd_send_ops(self, bwd_chunk_id: int) -> list[dist.P2POp]: """ Get the gradient send ops for current stage's backward. """ - self._check_chunk_id(bwd_chunk_id) - if not self.has_backward or self.is_first: return [] + self._check_chunk_id(bwd_chunk_id) # Create bwd send infra lazily if self.grad_send_info is None: # Send info for input grads during backward: @@ -761,6 +760,10 @@ def backward_one_chunk( last_backward is controlled by the schedule and signals synchronization of gradients across DP groups after the last backward. """ + # skip backward computation if backward is not enabled + if not self.has_backward: + return + self._check_chunk_id(bwd_chunk_id) ( From 306dd19216b656467143483395ef582feb5d7d07 Mon Sep 17 00:00:00 2001 From: Laith Sakka Date: Wed, 16 Jul 2025 15:27:13 -0700 Subject: [PATCH 147/457] update expeced results (#158497) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158497 Approved by: https://github.com/xmfan --- .../pr_time_benchmarks/expected_results.csv | 44 +++++++++---------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv index 1b86e02b8afda..24f0b2af088c2 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv +++ b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv @@ -1,32 +1,32 @@ -add_loop_eager,compile_time_instruction_count,2994000000,0.015 +add_loop_eager,compile_time_instruction_count,3051000000,0.015 -add_loop_eager_dynamic,compile_time_instruction_count,4352000000,0.025 +add_loop_eager_dynamic,compile_time_instruction_count,4405000000,0.025 -add_loop_inductor,compile_time_instruction_count,33260000000,0.015 +add_loop_inductor,compile_time_instruction_count,33810000000,0.015 -add_loop_inductor_dynamic_gpu,compile_time_instruction_count,42900000000,0.025 +add_loop_inductor_dynamic_gpu,compile_time_instruction_count,43470000000,0.025 -add_loop_inductor_gpu,compile_time_instruction_count,29880000000,0.015 +add_loop_inductor_gpu,compile_time_instruction_count,30390000000,0.015 -basic_modules_ListOfLinears_eager,compile_time_instruction_count,947600000,0.015 +basic_modules_ListOfLinears_eager,compile_time_instruction_count,965100000,0.015 -basic_modules_ListOfLinears_inductor,compile_time_instruction_count,17940000000,0.015 +basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18300000000,0.015 -basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17210000000,0.015 +basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17630000000,0.015 @@ -34,56 +34,56 @@ basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,10980000 -update_hint_regression,compile_time_instruction_count,1688000000,0.02 +update_hint_regression,compile_time_instruction_count,1717000000,0.02 -sum_floordiv_regression,compile_time_instruction_count,992700000,0.015 +sum_floordiv_regression,compile_time_instruction_count,965000000,0.015 -symint_sum,compile_time_instruction_count,3187000000,0.015 +symint_sum,compile_time_instruction_count,3239000000,0.015 -symint_sum_loop,compile_time_instruction_count,4225000000,0.015 +symint_sum_loop,compile_time_instruction_count,4305000000,0.015 -aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2122000000,0.015 +aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2146000000,0.015 -aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,6040000000,0.015 +aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,6119000000,0.015 -aotdispatcher_partitioner_cpu,compile_time_instruction_count,8894000000,0.015 +aotdispatcher_partitioner_cpu,compile_time_instruction_count,8976000000,0.015 -aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1952000000,0.015 +aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1988000000,0.015 -aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3905000000,0.015 +aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3951000000,0.015 -aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10470000000,0.015 +aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10640000000,0.015 -mm_loop_inductor_gpu,compile_time_instruction_count,4406000000,0.015 +mm_loop_inductor_gpu,compile_time_instruction_count,4468000000,0.015 -mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,8274000000,0.015 +mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,8400000000,0.015 -basic_NestedModule_eager,compile_time_instruction_count,8193000000,0.015 +basic_NestedModule_eager,compile_time_instruction_count,8357000000,0.015 -basic_InlineMod_eager,compile_time_instruction_count,7255000000,0.015 +basic_InlineMod_eager,compile_time_instruction_count,7443000000,0.015 From 82a1ee1135b054d371d10081883b848ac7b7419f Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Thu, 17 Jul 2025 00:23:00 +0000 Subject: [PATCH 148/457] Refactor Provenance Tracking (#158399) Summary: As inductor provenance tracking is getting more use cases, we want to separate the inductor provenance tracking guarding flag from the general `trace.enabled`, so we can enable provenance tracking without all the overhead of `trace.enabled` - change the guard flag from `trace.enabled` to `trace.provenance_tracking`. It is turned on by either `TORCH_COMPILE_DEBUG=1` or `INDUCTOR_PROVENANCE=1`. - Move the provenance tracking logic and variables out of DebugContext, because DebugContext is only enabled with `trace.enabled`. Since the variables are now global variables, added `reset_provenance_globals()` context manager to reset them for each `compile_fx()` call. - Move `set_kernel_post_grad_provenance_tracing` from `util.py` to `debug.py` so now all provenance related logic is in `debug.py`. In the future, if we want to enable it further, we can change the provenance tracking flag to be enabled when `TORCH_TRACE` is set. I think we should do that in a separate PR, so it's easier to revert if this flag change creates any problem. See more motivation in internal Diff Test Plan: ``` buck2 run mode/dev-nosan fbcode//caffe2/test:fx -- -r test_graph_transform_observer buck run mode/dev-nosan fbcode//caffe2/test:fx -- -r graph_provenance buck2 run mode/dev-nosan fbcode//caffe2/test/inductor:provenance_tracing ``` Differential Revision: D78287976 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158399 Approved by: https://github.com/angelayi --- test/dynamo/test_structured_trace.py | 8 +- test/fx/test_fx_xform_observer.py | 4 +- test/inductor/test_provenance_tracing.py | 24 +++- torch/_inductor/codegen/cpp.py | 4 +- torch/_inductor/codegen/simd.py | 8 +- torch/_inductor/codegen/wrapper.py | 4 +- torch/_inductor/compile_fx.py | 46 ++++---- torch/_inductor/config.py | 7 +- torch/_inductor/debug.py | 124 ++++++++++++++++---- torch/_inductor/pattern_matcher.py | 2 +- torch/_inductor/utils.py | 46 +------- torch/fx/experimental/proxy_tensor.py | 2 +- torch/fx/passes/graph_transform_observer.py | 2 +- 13 files changed, 162 insertions(+), 119 deletions(-) diff --git a/test/dynamo/test_structured_trace.py b/test/dynamo/test_structured_trace.py index 69f0203adf06f..cde880df17a62 100644 --- a/test/dynamo/test_structured_trace.py +++ b/test/dynamo/test_structured_trace.py @@ -1040,10 +1040,10 @@ def backward(ctx, gO): '{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 0, "frame_id": 9, "frame_compile_id": 0, "attempt": 0}', '{"dynamo_start": {"stack": "STACK"}, "frame_id": 1, "frame_compile_id": 1, "attempt": 0}', '{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 1, "frame_id": 6, "frame_compile_id": 1, "attempt": 0}', - '{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 1, "frame_id": 7, "frame_compile_id": 1, "attempt": 0}', - '{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 1, "frame_id": 11, "frame_compile_id": 0, "attempt": 0}', - '{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 1, "frame_id": 10, "frame_compile_id": 1, "attempt": 0}', - '{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 1, "frame_id": 11, "frame_compile_id": 1, "attempt": 0}', + '{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 1, "frame_id": 9, "frame_compile_id": 1, "attempt": 0}', + '{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 1, "frame_id": 13, "frame_compile_id": 0, "attempt": 0}', + '{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 1, "frame_id": 12, "frame_compile_id": 1, "attempt": 0}', + '{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 1, "frame_id": 13, "frame_compile_id": 1, "attempt": 0}', ] logs = self.buffer.getvalue() self.assertTrue(all(event in logs for event in expected)) diff --git a/test/fx/test_fx_xform_observer.py b/test/fx/test_fx_xform_observer.py index 2517439d9fe36..10577712196b3 100644 --- a/test/fx/test_fx_xform_observer.py +++ b/test/fx/test_fx_xform_observer.py @@ -55,7 +55,7 @@ def replacement(x): ) ) - @torch._inductor.config.patch("trace.enabled", True) + @torch._inductor.config.patch("trace.provenance_tracking", True) def test_graph_transform_observer_node_tracking(self): class M(torch.nn.Module): def forward(self, x): @@ -156,7 +156,7 @@ def forward(self, x): [NodeSourceAction.REPLACE, NodeSourceAction.CREATE], ) - @torch._inductor.config.patch("trace.enabled", True) + @torch._inductor.config.patch("trace.provenance_tracking", True) def test_graph_transform_observer_deepcopy(self): class SimpleLinearModel(torch.nn.Module): def forward(self, x): diff --git a/test/inductor/test_provenance_tracing.py b/test/inductor/test_provenance_tracing.py index c0efa7416ae1a..1f7cd7a9f2c00 100644 --- a/test/inductor/test_provenance_tracing.py +++ b/test/inductor/test_provenance_tracing.py @@ -56,6 +56,7 @@ def forward(self, a): @config.patch("trace.enabled", True) +@config.patch("trace.provenance_tracking", True) class TestProvenanceTracingArtifact(TestCase): """ This test checks that generated provenance tracing artifact from "post_grad" to @@ -121,6 +122,10 @@ def _test_triton_kernel_to_post_grad_tracing(self, device): "mul_2", ], } + if backend == "aot_inductor": + expected_data["aoti_torch_cuda_mm_out"] = ["mm_default"] + else: + expected_data["extern_kernels.mm"] = ["mm_default"] self._check_provenance_tracing_artifact(filepath, expected_data) expected_mapping = [ ( @@ -171,6 +176,16 @@ def _test_triton_kernel_to_post_grad_tracing(self, device): }, ), ] + if backend == "aot_inductor": + expected_mapping[0][1]["aoti_torch_cuda_mm_out"] = [ + "mm_default" + ] + expected_mapping[1][1]["mm_default"] = [ + "aoti_torch_cuda_mm_out" + ] + else: + expected_mapping[0][1]["extern_kernels.mm"] = ["mm_default"] + expected_mapping[1][1]["mm_default"] = ["extern_kernels.mm"] self._check_provenance_tracking_node_mappings( filepath, expected_mapping ) @@ -180,7 +195,7 @@ def _test_triton_kernel_to_post_grad_tracing(self, device): if backend == "aot_inductor": expected_data = { "cpp_fused_mul_0": ["mul"], - "aoti_torch_cpu_addmm_out": ["addmm", "mul"], + "aoti_torch_cpu_addmm_out": ["addmm"], "cpp_fused_gelu_1": [ "mul_3", "mul_1", @@ -193,7 +208,6 @@ def _test_triton_kernel_to_post_grad_tracing(self, device): # backend == "inductor" expected_data = { "cpp_fused_mul_0": ["mul"], - "aoti_torch_cpu_addmm_out": ["addmm", "mul"], "cpp_fused_gelu_1": [ "mul_3", "mul_1", @@ -201,7 +215,7 @@ def _test_triton_kernel_to_post_grad_tracing(self, device): "erf", "mul_2", ], - "extern_kernels.addmm": ["addmm", "mul"], + "extern_kernels.addmm": ["addmm"], } self._check_provenance_tracing_artifact(filepath, expected_data) @@ -252,14 +266,12 @@ def test_triton_kernel_to_post_grad_tracing_extern_kernel(self): filepath = Path(m.group(1)) if backend == "inductor": expected_data = { - "aoti_torch_cuda_addmm_out": ["addmm", "_tensor_constant1"], - "triton_poi_fused_0": ["_tensor_constant1"], "extern_kernels.addmm": ["addmm"], } else: # backend = aot_inductor expected_data = { - "aoti_torch_cuda_addmm_out": ["addmm", "_tensor_constant1"], + "aoti_torch_cuda_addmm_out": ["addmm"], "triton_poi_fused_0": ["_tensor_constant1"], } self._check_provenance_tracing_artifact(filepath, expected_data) diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 4b15618c12bf0..12584284631b7 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -24,6 +24,7 @@ from ..._dynamo.utils import counters from .. import config, cpp_builder, cpu_vec_isa, ir, metrics +from ..debug import set_kernel_post_grad_provenance_tracing from ..loop_body import LoopBody from ..scheduler import ( BaseSchedulerNode, @@ -43,7 +44,6 @@ is_welford_reduction, parallel_num_threads, Placeholder, - set_kernel_post_grad_provenance_tracing, sympy_index_symbol, sympy_index_symbol_with_prefix, sympy_product, @@ -5191,7 +5191,7 @@ def define_kernel(self, src_code, nodes, kernel_args=None): ) kernel_name = "_".join(["cpp", fused_name, wrapper.next_kernel_suffix()]) # below add provenance tracing info for cpu CppKernel types - if config.trace.enabled: + if config.trace.provenance_tracking: set_kernel_post_grad_provenance_tracing(nodes, kernel_name) kernel_decl_name = kernel_name if V.graph.cpp_wrapper else "kernel" diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py index 90bee26c09249..42c9a9d89eb99 100644 --- a/torch/_inductor/codegen/simd.py +++ b/torch/_inductor/codegen/simd.py @@ -41,6 +41,7 @@ if TYPE_CHECKING: from ..ir import IRNode +from ..debug import set_kernel_post_grad_provenance_tracing from ..optimize_indexing import indexing_dtype_strength_reduction from ..runtime.runtime_utils import green_text, yellow_text from ..scheduler import BaseSchedulerNode, BaseScheduling, WhyNoFuse @@ -51,7 +52,6 @@ IndentedBuffer, Placeholder, prefix_is_reduction, - set_kernel_post_grad_provenance_tracing, sympy_index_symbol, sympy_product, sympy_subs, @@ -1453,7 +1453,7 @@ def codegen_node_schedule(self, kernel_features: SIMDKernelFeatures): with V.set_kernel_handler(kernel): src_code = kernel.codegen_kernel() kernel_name = self.define_kernel(src_code, node_schedule, kernel) - if config.trace.enabled: + if config.trace.provenance_tracking: set_kernel_post_grad_provenance_tracing( node_schedule, # type: ignore[arg-type] kernel_name, @@ -1659,7 +1659,7 @@ def _codegen_single_template( kernel.kernel_name = self.define_kernel(src_code, node_schedule, kernel) - if config.trace.enabled: + if config.trace.provenance_tracking: set_kernel_post_grad_provenance_tracing( node_schedule, kernel.kernel_name ) @@ -1844,7 +1844,7 @@ def codegen_combo_kernel(self, combo_kernel_node): for src_code, kernel, _ in kernel_code_list: kernel_name = self.define_kernel(src_code, [combo_kernel_node], kernel) # dump provenance node info for ComboKernelNode/ForeachKernel type - if config.trace.enabled: + if config.trace.provenance_tracking: set_kernel_post_grad_provenance_tracing( combo_kernel_node.snodes, kernel_name ) diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index e7726263714fa..0b8ba86c3c185 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -40,6 +40,7 @@ from .. import async_compile, config, ir from ..codecache import output_code_log +from ..debug import set_kernel_post_grad_provenance_tracing from ..ir import IRNode, ReinterpretView from ..runtime import triton_heuristics from ..runtime.hints import DeviceProperties @@ -50,7 +51,6 @@ IndentedBuffer, is_codegen_graph_partition_subgraph, LineContext, - set_kernel_post_grad_provenance_tracing, sympy_product, sympy_str, sympy_subs, @@ -479,7 +479,7 @@ def codegen(self, code: IndentedBuffer) -> None: kernel_name = node.get_kernel_name() device = d.type if (d := node.get_device()) else V.graph.device_type # set provenance tracing kernel mapping for ExternKernel types - if config.trace.enabled: + if config.trace.provenance_tracking: set_kernel_post_grad_provenance_tracing(node, kernel_name, is_extern=True) self.wrapper._generate_extern_kernel_out_helper( kernel_name, diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index bfdb9a54e56f6..e20ae1d85ae3b 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -1028,30 +1028,23 @@ def _compile_fx_inner( log.debug("FX codegen and compilation took %.3fs", time.time() - start) - # Dump provenance artifacts for debugging trace - provenance_info = V.debug.log_inductor_triton_kernel_to_post_grad_node_info() - # provenance_info might be None if config.trace.enabled is not set - if provenance_info: - ( - debug_info, - node_mappings, - ) = provenance_info - trace_structured( - "artifact", - metadata_fn=lambda: { - "name": "inductor_generated_kernel_to_post_grad_nodes", - "encoding": "json", - }, - payload_fn=lambda: json.dumps(debug_info), - ) - trace_structured( - "artifact", - metadata_fn=lambda: { - "name": "inductor_provenance_tracking_node_mappings", - "encoding": "json", - }, - payload_fn=lambda: json.dumps(node_mappings), - ) + if config.trace.provenance_tracking: + # Dump provenance artifacts for debugging trace + provenance_info = torch._inductor.debug.dump_inductor_provenance_info() + # provenance_info might be None if trace.provenance_tracking is not set + if provenance_info: + ( + _, + node_mappings, + ) = provenance_info + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "inductor_provenance_tracking_node_mappings", + "encoding": "json", + }, + payload_fn=lambda: json.dumps(node_mappings), + ) # This message is for printing overview information of inductor mm counts, shapes,etc after lowering if log.isEnabledFor(logging.INFO): @@ -1294,7 +1287,7 @@ def codegen_and_compile( }, payload_fn=lambda: inductor_post_grad_graph_str, ) - if config.trace.enabled: + if config.trace.provenance_tracking: provenance_tracking_json = ( torch.fx.traceback.get_graph_provenance_json(gm.graph) ) @@ -2147,7 +2140,8 @@ def compile_fx( with ( _use_lazy_graph_module(dynamo_config.use_lazy_graph_module), enable_python_dispatcher(), - torch.fx.traceback.preserve_node_meta(config.trace.enabled), + torch.fx.traceback.preserve_node_meta(config.trace.provenance_tracking), + torch._inductor.debug.reset_provenance_globals(), ): # Pre-grad passes cannot be run if we weren't given a GraphModule. # Dynamo will always produce a GraphModule, but this handles cases diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 5c7a53683db3b..6e77283aacf2e 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -1773,8 +1773,11 @@ class trace: log_autotuning_results = os.environ.get("LOG_AUTOTUNE_RESULTS", "0") == "1" - # Save mapping info from inductor generated triton kernel to post_grad fx nodes - log_inductor_triton_kernel_to_post_grad_node_info: bool = True + # Save mapping info from inductor generated triton kernel to post_grad fx nodes to pre_grad fx nodes + provenance_tracking = ( + os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1" + or os.environ.get("INDUCTOR_PROVENANCE", "0") == "1" + ) _save_config_ignore: list[str] = [ diff --git a/torch/_inductor/debug.py b/torch/_inductor/debug.py index d3bc89a3d4125..f21e0be24d54d 100644 --- a/torch/_inductor/debug.py +++ b/torch/_inductor/debug.py @@ -13,7 +13,7 @@ import pstats import shutil import traceback -from collections.abc import Iterator +from collections.abc import Iterator, Sequence from typing import Any, Callable, IO, Optional, Union from unittest.mock import patch @@ -31,6 +31,7 @@ from torch.utils._pytree import tree_map from . import config, ir # noqa: F811, this is needed +from .ir import ExternKernelOut from .scheduler import ( BaseSchedulerNode, FusedSchedulerNode, @@ -313,15 +314,44 @@ def enable_aot_logging() -> Iterator[None]: # They are not stored in DebugContext because they are not set in # _inductor_triton_kernel_to_post_grad_node_info's Debug Context _inductor_post_to_pre_grad_nodes: dict[str, Any] = {} +_inductor_triton_kernel_to_post_grad_node_info: dict[str, Any] = {} _pre_grad_graph_id: Optional[int] = None +@contextlib.contextmanager +def reset_provenance_globals() -> Iterator[None]: + """Context manager that resets provenance tracking globals upon entering + and restores their original values when exiting.""" + global _pre_grad_graph_id + global _inductor_post_to_pre_grad_nodes + global _inductor_triton_kernel_to_post_grad_node_info + + # Store original values + original_pre_grad_graph_id = _pre_grad_graph_id + original_post_to_pre_grad_nodes = _inductor_post_to_pre_grad_nodes.copy() + original_triton_kernel_to_post_grad_node_info = ( + _inductor_triton_kernel_to_post_grad_node_info.copy() + ) + + # Reset to default values + _pre_grad_graph_id = -1 + _inductor_post_to_pre_grad_nodes = {} + _inductor_triton_kernel_to_post_grad_node_info = {} + + try: + yield + finally: + # Restore original values + _pre_grad_graph_id = original_pre_grad_graph_id + _inductor_post_to_pre_grad_nodes = original_post_to_pre_grad_nodes + _inductor_triton_kernel_to_post_grad_node_info = ( + original_triton_kernel_to_post_grad_node_info + ) + + class DebugContext: _counter = itertools.count() - # Used for provenance tracking - _inductor_triton_kernel_to_post_grad_node_info: dict[str, list[str]] = {} - @staticmethod def create_debug_dir(folder_name: str) -> Optional[str]: debug_dir = config.trace.debug_dir or get_debug_dir() @@ -557,25 +587,6 @@ def draw_orig_fx_graph( def output_code(self, filename: str, extension: str = "py") -> None: shutil.copy(filename, self.filename(f"output_code.{extension}")) - def log_inductor_triton_kernel_to_post_grad_node_info( - self, filename: str = "inductor_generated_kernel_to_post_grad_nodes.json" - ) -> tuple[dict[str, list[str]], dict[str, Any]]: - debug_info = {} - with self.fopen(filename, "w") as fd: - log.info("Writing provenance tracing debugging info to %s", fd.name) - debug_info = DebugContext._inductor_triton_kernel_to_post_grad_node_info - json.dump(debug_info, fd) - node_mapping = {} - if _pre_grad_graph_id: - with self.fopen( - "inductor_provenance_tracking_node_mappings.json", "w" - ) as fd: - node_mapping = create_node_mapping( - _pre_grad_graph_id, _inductor_post_to_pre_grad_nodes, debug_info - ) - json.dump(node_mapping, fd) - return debug_info, node_mapping - def log_autotuning_results( self, name: str, @@ -808,6 +819,73 @@ def convert_sets_to_lists(d: dict[str, Any]) -> None: return empty_return +def dump_inductor_provenance_info( + filename: str = "inductor_generated_kernel_to_post_grad_nodes.json", +) -> tuple[dict[str, list[str]], dict[str, Any]]: + global _pre_grad_graph_id + global _inductor_post_to_pre_grad_nodes + global _inductor_triton_kernel_to_post_grad_node_info + debug_info = _inductor_triton_kernel_to_post_grad_node_info.copy() + if config.trace.enabled: + with V.debug.fopen(filename, "w") as fd: + log.info("Writing provenance tracing debugging info to %s", fd.name) + json.dump(debug_info, fd) + node_mapping = {} + if _pre_grad_graph_id: + node_mapping = create_node_mapping( + _pre_grad_graph_id, _inductor_post_to_pre_grad_nodes, debug_info + ) + if config.trace.enabled: + with V.debug.fopen( + "inductor_provenance_tracking_node_mappings.json", "w" + ) as fd: + json.dump(node_mapping, fd) + return debug_info, node_mapping + + +def set_kernel_post_grad_provenance_tracing( + node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernelOut], + kernel_name: str, + is_extern: bool = False, +) -> None: + from .codegen.simd_kernel_features import DisableReduction, EnableReduction + + global _inductor_triton_kernel_to_post_grad_node_info + if is_extern: + assert isinstance(node_schedule, ExternKernelOut) + curr_node_info = _inductor_triton_kernel_to_post_grad_node_info.setdefault( + kernel_name, [] + ) + # 'origins' on IR nodes gives what FX IR nodes contributed to any given fused kernel. + # "origin_node" is more precise and says that the contents of this node corresponds + # EXACTLY to the output of a particular FX node, but it's not always available + if node_schedule.origin_node: + origin_node_name = node_schedule.origin_node.name + if origin_node_name not in curr_node_info: + curr_node_info.append(origin_node_name) + else: + curr_node_info.extend( + origin.name + for origin in node_schedule.origins + if origin.name not in curr_node_info + ) + else: + assert isinstance(node_schedule, list) + for snode in node_schedule: + if snode not in (EnableReduction, DisableReduction): + if snode.node is not None: + curr_node_info = ( + _inductor_triton_kernel_to_post_grad_node_info.setdefault( + kernel_name, [] + ) + ) + curr_node_info.extend( + origin.name + for origin in snode.node.origins + if origin.name not in curr_node_info + ) + + def save_args_for_compile_fx_inner(*args: Any, **kwargs: Any) -> None: """ This function is used to save arguments for a compile_fx_inner function call diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py index b13a058324d41..78aa947ea7f6a 100644 --- a/torch/_inductor/pattern_matcher.py +++ b/torch/_inductor/pattern_matcher.py @@ -127,7 +127,7 @@ def _transfer_meta( # transfer metadata after pattern matching occurs. # skip "val" and "tensor_meta" because this info is too specific; it's unlikely # to remain accurate after pattern matching has occurred. - if config.trace.enabled: + if config.trace.provenance_tracking: # We handle "from_node" field of the node meta specially to record that the new node comes from the old_node. new_from_node = new_meta.get("from_node", []).copy() new_from_node.append(NodeSource(old_node, pass_name, NodeSourceAction.REPLACE)) diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 0b82dfda835a9..5f9ce0b814eba 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -81,15 +81,7 @@ from .codegen.common import WorkspaceArg from .codegen.wrapper import PythonWrapperCodegen from .graph import GraphLowering - from .ir import ( - Buffer, - ExternKernel, - ExternKernelOut, - IRNode, - Layout, - Operation, - ReinterpretView, - ) + from .ir import Buffer, ExternKernel, IRNode, Layout, Operation, ReinterpretView from .output_code import CompiledFxGraph from .scheduler import BaseSchedulerNode, SchedulerBuffer @@ -3131,42 +3123,6 @@ def get_donated_idxs() -> Optional[list[int]]: return None -def set_kernel_post_grad_provenance_tracing( - node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernelOut], - kernel_name: str, - is_extern: bool = False, -) -> None: - from .codegen.simd_kernel_features import DisableReduction, EnableReduction - from .ir import ExternKernelOut - from .virtualized import V - - if is_extern: - assert isinstance(node_schedule, ExternKernelOut) - curr_node_info = ( - V.debug._inductor_triton_kernel_to_post_grad_node_info.setdefault( - kernel_name, [] - ) - ) - curr_node_info.extend( - origin.name - for origin in node_schedule.origins - if origin.name not in curr_node_info - ) - else: - assert isinstance(node_schedule, list) - for snode in node_schedule: - if snode not in (EnableReduction, DisableReduction): - if snode.node is not None: - curr_node_info = V.debug._inductor_triton_kernel_to_post_grad_node_info.setdefault( - kernel_name, [] - ) - curr_node_info.extend( - origin.name - for origin in snode.node.origins - if origin.name not in curr_node_info - ) - - class TritonAttrsDescriptorVersion(enum.Enum): V0_NO_TRITON = 0 V1_COMPILER = 1 # triton.compiler.compiler.AttrsDescriptor diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index 6777b1f31cef2..a578723ea1cbb 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -2343,7 +2343,7 @@ def make_fx( record_module_stack, _allow_fake_constant, _error_on_data_dependent_ops, - record_stack_traces=record_stack_traces or config.trace.enabled, + record_stack_traces=record_stack_traces or config.trace.provenance_tracking, ) @functools.wraps(f) diff --git a/torch/fx/passes/graph_transform_observer.py b/torch/fx/passes/graph_transform_observer.py index 8e59fc7ae1793..29afd600b5fd9 100644 --- a/torch/fx/passes/graph_transform_observer.py +++ b/torch/fx/passes/graph_transform_observer.py @@ -42,7 +42,7 @@ def __init__( self.log_url = log_url - self.active = trace.enabled or self.log_url is not None + self.active = trace.provenance_tracking or self.log_url is not None if self.active: self.erased_nodes: set[str] = set() From e78f2ac92b709a060aa323d6e527ec2ecc36fb93 Mon Sep 17 00:00:00 2001 From: clr Date: Wed, 16 Jul 2025 14:49:48 -0700 Subject: [PATCH 149/457] inductor: Fix crash in split_cat when tensors is a Node (#157155) If there is only one node passed to aten::cat, the argument is a single node, rather than a list of nodes with a valid length. Example stack ``` File "/dev/shm/uid-99/be3468a8-seed-nspid4026546656_cgpid14993614-ns-4026546628/torch/_inductor/pattern_matcher.py", line 1115, in apply self.handler(match, *match.args, **match.kwargs) File "/dev/shm/uid-99/be3468a8-seed-nspid4026546656_cgpid14993614-ns-4026546628/torch/_inductor/fx_passes/split_cat.py", line 1786, in merge_split_cat_aten if len(cat_inputs) < threshold_to_cat: torch._inductor.exc.InductorError: TypeError: object of type 'Node' has no len() ``` This has failed about 7 internal jobs in the last week, running pytorch trunk code from 06/15 I've attached a test which reproduces this issue. Pull Request resolved: https://github.com/pytorch/pytorch/pull/157155 Approved by: https://github.com/jansel --- .../inductor/test_split_cat_fx_aten_passes.py | 42 +++++++++++++++++++ torch/_inductor/fx_passes/split_cat.py | 6 ++- 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/test/inductor/test_split_cat_fx_aten_passes.py b/test/inductor/test_split_cat_fx_aten_passes.py index 9300c7d0d1264..354552c497d98 100644 --- a/test/inductor/test_split_cat_fx_aten_passes.py +++ b/test/inductor/test_split_cat_fx_aten_passes.py @@ -49,6 +49,22 @@ def forward(self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor): return torch.ops.aten.cat.default([cat_1, cat_2], 1) +class TestSplitCatSingular(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor): + cat = torch.ops.aten.cat.default([x, y], 1) + split = torch.ops.aten.split.Tensor(cat, 32, 1) + getitem = split[0] + cat_1 = torch.ops.aten.cat.default( + [getitem], + 1, + ) + cat_2 = torch.ops.aten.cat.default([getitem, z], 1) + return torch.ops.aten.cat.default([cat_1, cat_2], 1) + + class TestSplitCatPartial(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -275,6 +291,32 @@ def test_split_cat_post_grad(self): self.compare_parameters(module, traced, rtol=1e-8, atol=1e-8) counters.clear() + @requires_cuda + @torch._inductor.config.patch( + pre_grad_fusion_options={}, + post_grad_fusion_options={ + "normalization_aten_pass": {}, + "split_cat_aten_pass": {"threshold_to_cat": 5}, + }, + ) + def test_split_cat_post_grad_singular(self): + counters.clear() + inputs = [ + torch.randn(1024, 128, device=torch.device(device=GPU_TYPE)), + torch.randn(1024, 128, device=torch.device(device=GPU_TYPE)), + torch.randn(1024, 32, device=torch.device(device=GPU_TYPE)), + ] + module = TestSplitCatSingular() + traced = torch.compile(module) + ref = module(*inputs) + res = traced(*inputs) + self.compare_pred(module, traced, inputs) + self.assertEqual(counters["inductor"]["normalization_aten_pass"], 4) + self.assertEqual(counters["inductor"]["split_cat_aten_pass"], 0) + self.assertEqual(ref, res, rtol=1e-8, atol=1e-8) + self.compare_parameters(module, traced, rtol=1e-8, atol=1e-8) + counters.clear() + @requires_cuda @torch._inductor.config.patch( pre_grad_fusion_options={}, diff --git a/torch/_inductor/fx_passes/split_cat.py b/torch/_inductor/fx_passes/split_cat.py index 098f69fd863e2..327f96ae34ac7 100644 --- a/torch/_inductor/fx_passes/split_cat.py +++ b/torch/_inductor/fx_passes/split_cat.py @@ -1791,7 +1791,11 @@ def merge_split_cat_aten(match: Match, *args, **kwargs): for cat_node in list(getitem_nodes[0].users.keys()): cat_dim = get_arg_value(cat_node, 1, "dim") cat_inputs = get_arg_value(cat_node, 0, "tensors") - if len(cat_inputs) < threshold_to_cat: + try: + cat_input_len = len(cat_inputs) + except TypeError: + continue + if cat_input_len < threshold_to_cat: continue # check split node and cat node has same dim, and all getitem nodes have same parent node parent_to_indices = defaultdict(list) # type: ignore[var-annotated] From e9367a7a4288e626f01fada3912d68756f1ca6d3 Mon Sep 17 00:00:00 2001 From: Eli Uriegas Date: Wed, 16 Jul 2025 17:52:14 -0700 Subject: [PATCH 150/457] ci: Add reusable workflow to get changed files in PRs (#158517) Signed-off-by: Eli Uriegas Pull Request resolved: https://github.com/pytorch/pytorch/pull/158517 Approved by: https://github.com/huydhn --- .github/workflows/_get-changed-files.yml | 43 ++++++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 .github/workflows/_get-changed-files.yml diff --git a/.github/workflows/_get-changed-files.yml b/.github/workflows/_get-changed-files.yml new file mode 100644 index 0000000000000..2d3b800f0757b --- /dev/null +++ b/.github/workflows/_get-changed-files.yml @@ -0,0 +1,43 @@ +name: Get Changed Files + +on: + workflow_call: + outputs: + changed-files: + description: "List of changed files (space-separated) or '*' if not in a PR" + value: ${{ jobs.get-changed-files.outputs.changed-files }} + +jobs: + get-changed-files: + runs-on: ubuntu-latest + outputs: + changed-files: ${{ steps.get-files.outputs.changed-files }} + + steps: + - name: Get changed files + id: get-files + env: + GH_TOKEN: ${{ github.token }} + run: | + # Check if we're in a pull request context + if [ "${{ github.event_name }}" = "pull_request" ] || [ "${{ github.event_name }}" = "pull_request_target" ]; then + echo "Running in PR context" + + # Get the PR number from the github context + PR_NUMBER="${{ github.event.number }}" + + # Use gh CLI to get changed files in the PR with explicit repo + CHANGED_FILES=$(gh pr view "$PR_NUMBER" --repo "${{ github.repository }}" --json files --jq '.files[].path' | tr '\n' ' ' | sed 's/ $//') + + if [ -z "$CHANGED_FILES" ]; then + echo "No changed files found, setting to '*'" + CHANGED_FILES="*" + fi + + echo "Changed files: $CHANGED_FILES" + echo "changed-files=$CHANGED_FILES" >> "$GITHUB_OUTPUT" + + else + echo "Not in PR context, setting changed files to '*'" + echo "changed-files=*" >> "$GITHUB_OUTPUT" + fi \ No newline at end of file From b6454a9058f2e50be9a3c26c128fec843b09c154 Mon Sep 17 00:00:00 2001 From: Xu Han Date: Thu, 17 Jul 2025 00:57:43 +0000 Subject: [PATCH 151/457] [AOT_inductor] model_base.h add Windows include files. (#158477) model_base.h add Windows include files. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158477 Approved by: https://github.com/desertfire, https://github.com/jansel --- torch/csrc/inductor/aoti_runtime/model_base.h | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/torch/csrc/inductor/aoti_runtime/model_base.h b/torch/csrc/inductor/aoti_runtime/model_base.h index 9eac761b7ef89..6e80c90499a0e 100644 --- a/torch/csrc/inductor/aoti_runtime/model_base.h +++ b/torch/csrc/inductor/aoti_runtime/model_base.h @@ -1,9 +1,15 @@ #pragma once +#ifdef _WIN32 +#include +#include // std::function +#else #include -#include #include #include +#endif + +#include #include #include #include From d7e1b8b11d7430c7633dcad6f6596b5df8fa02f7 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Wed, 16 Jul 2025 13:55:50 -0700 Subject: [PATCH 152/457] [dynamo] Constant fold torch.autograd._profiler_enabled (#158482) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158482 Approved by: https://github.com/williamwen42, https://github.com/StrongerXi --- test/dynamo/test_profiler.py | 41 ++++++++++++++++++++++++++++++++ torch/_dynamo/trace_rules.py | 2 +- torch/_dynamo/variables/torch.py | 1 + 3 files changed, 43 insertions(+), 1 deletion(-) diff --git a/test/dynamo/test_profiler.py b/test/dynamo/test_profiler.py index 9a7a892d8b020..860b337e95f75 100644 --- a/test/dynamo/test_profiler.py +++ b/test/dynamo/test_profiler.py @@ -192,6 +192,47 @@ def fn(x, y): ], ) + def test_profiler_enabled(self): + def fn(x): + x = torch.sin(x) + if torch.autograd._profiler_enabled(): + return torch.cos(x) + else: + return torch.sigmoid(x) + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + x = torch.randn(4) + + ref = fn(x) + res = opt_fn(x) + self.assertEqual(ref, res) + + with torch.autograd.profiler.profile(): + ref = fn(x) + res = opt_fn(x) + self.assertEqual(ref, res) + + def test_profiler_record_function_ignore(self): + def fn(x): + x = torch.sin(x) + if torch.autograd._profiler_enabled(): + with torch.autograd.profiler.record_function("dummy"): + return torch.cos(x) + else: + return torch.sigmoid(x) + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + x = torch.randn(4) + + ref = fn(x) + res = opt_fn(x) + self.assertEqual(ref, res) + + with torch.autograd.profiler.profile(): + ref = fn(x) + res = opt_fn(x) + self.assertEqual(ref, res) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 88d67822b0b70..0883525f47282 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -176,7 +176,6 @@ "torch.compiler.is_compiling": TorchInGraphFunctionVariable, "torch.compiler.is_dynamo_compiling": TorchInGraphFunctionVariable, "torch.compiler.is_exporting": TorchInGraphFunctionVariable, - "torch.autograd._profiler_enabled": SkipFunctionVariable, "torch._C._to_dlpack": SkipFunctionVariable, "torch.to_dlpack": SkipFunctionVariable, # We graph break on RNG state setters or getters like @@ -2434,6 +2433,7 @@ "torch.atleast_3d", "torch.autograd._calculate_shape", "torch.autograd._is_checkpoint_valid", + "torch.autograd._profiler_enabled", "torch.autograd._make_grads", "torch.autograd._register_py_tensor_class_for_device", "torch.autograd._tensor_or_tensors_to_tuple", diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index c357e158503c3..72b2e3dc132f4 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -142,6 +142,7 @@ torch.cuda.is_initialized, torch.xpu.current_device, torch.xpu.is_initialized, + torch.autograd._profiler_enabled, ] constant_fold_functions = [ From 2179afd7149c117dace9e552419082094b10a386 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Wed, 16 Jul 2025 14:38:05 -0700 Subject: [PATCH 153/457] [easy][guards] Add developer comment for posterity (#158471) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158471 Approved by: https://github.com/StrongerXi --- torch/csrc/dynamo/guards.cpp | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp index c98119d6adbd3..2b2d09d8b169b 100644 --- a/torch/csrc/dynamo/guards.cpp +++ b/torch/csrc/dynamo/guards.cpp @@ -3879,6 +3879,13 @@ class GetGenericDictGuardAccessor : public GuardAccessor { // check_verbose_nopybind. bool check_nopybind(PyObject* obj, bool matches_dict_tag = false) override { // borrowed ref + // NOTE for future guard optimization developers - We tried saving the dict + // pointer and weakref of the original object to avoid calling + // PyObject_GenericGetDict on a fast path, but this did not lead any + // meaningful speedups because of 2 reasons + // 1) Once __dict__ is generated, accessing it the second time is fast. + // 2) Getting the object from weakref, from 3.13 onwards, requires + // Py_DECREF, which further eats into the benefit. PyObject* x = PyObject_GenericGetDict(obj, nullptr); // new ref if (x == nullptr) { // Attribute absent, clear the exception and return false. From c09eba877f9c16908b3a925ef694604c1c761b85 Mon Sep 17 00:00:00 2001 From: Jiapeng Li Date: Thu, 17 Jul 2025 01:27:41 +0000 Subject: [PATCH 154/457] [Device] Add support for PrivateUse1 device type in parse_type function (#157609) This pull request refactors the `parse_type` function in `c10/core/Device.cpp` to improve the handling of the `PrivateUse1` device type. The main change involves reordering the logic to check for the `PrivateUse1` device type earlier in the function for better clarity and efficiency. This help to migrate existed backend to PrivateUse1 smoothly. Pull Request resolved: https://github.com/pytorch/pytorch/pull/157609 Approved by: https://github.com/jgong5, https://github.com/albanD --- c10/core/Device.cpp | 6 +- ...t_rename_privateuse1_to_existing_device.py | 59 +++++++++++++++++++ 2 files changed, 62 insertions(+), 3 deletions(-) create mode 100644 test/test_rename_privateuse1_to_existing_device.py diff --git a/c10/core/Device.cpp b/c10/core/Device.cpp index ee51154b420b0..68fa6f91979ab 100644 --- a/c10/core/Device.cpp +++ b/c10/core/Device.cpp @@ -41,6 +41,9 @@ DeviceType parse_type(const std::string& device_string) { "'mkldnn' is no longer used as device type. So torch.device('mkldnn') will be " "deprecated and removed in the future. Please use other valid device types instead."); } + if (device_string == get_privateuse1_backend()) { + return DeviceType::PrivateUse1; + } auto device = std::find_if( types.begin(), types.end(), @@ -50,9 +53,6 @@ DeviceType parse_type(const std::string& device_string) { if (device != types.end()) { return device->second; } - if (device_string == get_privateuse1_backend()) { - return DeviceType::PrivateUse1; - } std::vector device_names; for (const auto& it : types) { if (it.first) { diff --git a/test/test_rename_privateuse1_to_existing_device.py b/test/test_rename_privateuse1_to_existing_device.py new file mode 100644 index 0000000000000..539412a322385 --- /dev/null +++ b/test/test_rename_privateuse1_to_existing_device.py @@ -0,0 +1,59 @@ +# Owner(s): ["module: PrivateUse1"] + +import torch +from torch.testing._internal.common_utils import run_tests, TestCase + + +class DummyPrivateUse1Module: + @staticmethod + def is_available(): + return True + + @staticmethod + def is_autocast_enabled(): + return True + + @staticmethod + def get_autocast_dtype(): + return torch.float16 + + @staticmethod + def set_autocast_enabled(enable): + pass + + @staticmethod + def set_autocast_dtype(dtype): + pass + + @staticmethod + def get_amp_supported_dtype(): + return [torch.float16] + + +class TestRenamePrivateuseoneToExistingBackend(TestCase): + def test_external_module_register_with_existing_backend(self): + torch.utils.rename_privateuse1_backend("maia") + with self.assertRaisesRegex(RuntimeError, "has already been set"): + torch.utils.rename_privateuse1_backend("dummmy") + + custom_backend_name = torch._C._get_privateuse1_backend_name() + self.assertEqual(custom_backend_name, "maia") + + with self.assertRaises(AttributeError): + torch.maia.is_available() + + with self.assertRaisesRegex(AssertionError, "Tried to use AMP with the"): + with torch.autocast(device_type=custom_backend_name): + pass + torch._register_device_module("maia", DummyPrivateUse1Module) + + torch.maia.is_available() # type: ignore[attr-defined] + with torch.autocast(device_type=custom_backend_name): + pass + + self.assertEqual(torch._utils._get_device_index("maia:1"), 1) + self.assertEqual(torch._utils._get_device_index(torch.device("maia:2")), 2) + + +if __name__ == "__main__": + run_tests() From f6d138807f138868de0397936e2bee482c1fb987 Mon Sep 17 00:00:00 2001 From: Arsh Zahed Date: Thu, 17 Jul 2025 01:33:47 +0000 Subject: [PATCH 155/457] Always disable ShardingPropagation cache if compiling (#156868) Fixes #151106 Addresses issue (2) in #152963 for the DTensor sharding propagation cache being brittle under compile. The existing `_are_we_tracing` from `distributed._functional_collectives`, which mostly determines if currently tracing based on Fake Tensor dispatch mode, is reused here. **Test Plan**: There are already tests for DTensor + Compile with dynamic shape ([test_dtensor_dynamic](https://github.com/pytorch/pytorch/blob/main/test/distributed/tensor/test_dtensor_compile.py#L260), [test_dynamo_dtensor_from_local_dynamic_shapes](https://github.com/pytorch/pytorch/blob/main/test/distributed/tensor/test_dtensor_compile.py#L402)) that cover the change. Pull Request resolved: https://github.com/pytorch/pytorch/pull/156868 Approved by: https://github.com/xmfan --- .../tensor/test_dtensor_compile.py | 29 +++++++++++++++++-- torch/_dynamo/compiled_autograd.py | 11 ++++++- torch/distributed/_functional_collectives.py | 4 +++ torch/distributed/tensor/_op_schema.py | 9 ------ torch/distributed/tensor/_sharding_prop.py | 6 ++-- 5 files changed, 45 insertions(+), 14 deletions(-) diff --git a/test/distributed/tensor/test_dtensor_compile.py b/test/distributed/tensor/test_dtensor_compile.py index 54ec52ee32d41..86f1e9d8fb479 100644 --- a/test/distributed/tensor/test_dtensor_compile.py +++ b/test/distributed/tensor/test_dtensor_compile.py @@ -166,6 +166,8 @@ def forward(self, b_buffer, x): return (view_as_1,)""", # noqa: B950 ) + # During tracing, sharding propagation cache is skipped, so an extra dry run for + # add is performed in _propagate_tensor_meta_non_cached, hence add_1 instead of add self.assertExpectedInline( str(ep.run_decompositions({}).graph_module.code).strip(), """\ @@ -173,8 +175,8 @@ def forward(self, b_parametrizations_buffer_original0, x): _assert_tensor_metadata = torch.ops.aten._assert_tensor_metadata.default(x, None, None, torch.float64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata = None _to_copy = torch.ops.aten._to_copy.default(x, dtype = torch.float64, layout = torch.strided, device = device(type='cuda', index=0)); x = None view = torch.ops.aten.view.default(_to_copy, [4, 4]); _to_copy = None - add = torch.ops.aten.add.Tensor(b_parametrizations_buffer_original0, view); b_parametrizations_buffer_original0 = view = None - view_1 = torch.ops.aten.view.default(add, [4, 4]); add = None + add_1 = torch.ops.aten.add.Tensor(b_parametrizations_buffer_original0, view); b_parametrizations_buffer_original0 = view = None + view_1 = torch.ops.aten.view.default(add_1, [4, 4]); add_1 = None return (view_1,)""", # noqa: B950 ) @@ -296,6 +298,29 @@ def fn(x): res = opt_fn(x) self.assertEqual(res, ref) + @skipIfHpu + def test_dtensor_dynamic_cat(self): + # RESET COUNTS + + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + + # test passing in tuple of DTensors as + def fn(x, y): + return ( + torch.cat((x, y), dim=0) + .redistribute(device_mesh=x.device_mesh, placements=[Replicate()]) + .to_local()[0] + ) + + x = DTensor.from_local(torch.rand(4, 4), mesh, [Shard(0)], run_check=False) + y = DTensor.from_local(torch.rand(4, 4), mesh, [Shard(0)], run_check=False) + torch._dynamo.mark_dynamic(x, 0) + ref = fn(x, y) + + opt_fn = torch.compile(fn, backend="aot_eager", fullgraph=True) + res = opt_fn(x, y) + self.assertEqual(res, ref) + def test_dtensor_attribute_access_on_intermediate(self): mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) diff --git a/torch/_dynamo/compiled_autograd.py b/torch/_dynamo/compiled_autograd.py index e52fb5026cb98..bda2494e7a9f5 100644 --- a/torch/_dynamo/compiled_autograd.py +++ b/torch/_dynamo/compiled_autograd.py @@ -304,11 +304,13 @@ def begin_capture( accumulate_grad: bool, check_nans: bool, ): + global in_compiled_autograd_initial_trace counters["compiled_autograd"]["captures"] += 1 self.id = next(COMPILE_COUNTER) self.aot_id_counter: dict[int, int] = defaultdict(int) self.compile_context = make_compile_context(self.id) self.compile_context.__enter__() + in_compiled_autograd_initial_trace = True self.nan_checker = NaNChecker(accumulate_grad) if check_nans else None self.start_time_ns = time.time_ns() get_chromium_event_logger().log_event_start( @@ -969,6 +971,8 @@ def create_graph_module(self, id): return GraphModule(self.fx_tracer.root, self.fx_tracer.graph, id) def end_capture(self, outputs): + global in_compiled_autograd_initial_trace + self.fx_tracer.create_proxy( "call_function", FakeCompiledAutogradEngine._exec_final_callbacks_stub, @@ -1085,6 +1089,7 @@ def runtime_wrapper(compiled_fn, inputs, sizes, scalars, hooks, packed_inputs): log_pt2_compile_event=True, ) self.compile_context.__exit__(None, None, None) + in_compiled_autograd_initial_trace = False return runtime_wrapper, self.compiler_fn(graph) @staticmethod @@ -1394,6 +1399,9 @@ def set_node_origin( # global flag to check if compiled autograd is enabled but Dynamo stance is "force_eager" compiled_autograd_enabled_force_eager = False +# global flag to check if we are capturing for compiled autograd +in_compiled_autograd_initial_trace = False + # global flag to check if we are processing graphs produced from a compiled autograd graph in_compiled_autograd_region = False @@ -1498,12 +1506,13 @@ def _disable(): # return to starting state of a new process def reset() -> None: - global compiled_autograd_enabled + global compiled_autograd_enabled, in_compiled_autograd_initial_trace compiled_autograd_enabled = False assert not in_compiled_autograd_region torch._C._dynamo.compiled_autograd.set_autograd_compiler(None, False) torch._C._dynamo.compiled_autograd.set_verbose_logger(None) torch._C._dynamo.compiled_autograd.clear_cache() + in_compiled_autograd_initial_trace = False global COMPILE_COUNTER COMPILE_COUNTER = itertools.count() diff --git a/torch/distributed/_functional_collectives.py b/torch/distributed/_functional_collectives.py index ec51b2b7a1817..0ffae8a9c9fe3 100644 --- a/torch/distributed/_functional_collectives.py +++ b/torch/distributed/_functional_collectives.py @@ -821,6 +821,10 @@ def _are_we_tracing() -> bool: # If fake mode is turned on, we are almost definitely compiling/tracing. if torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE) is not None: return True + + if torch._dynamo.compiled_autograd.in_compiled_autograd_initial_trace: + return True + return get_proxy_mode() is not None diff --git a/torch/distributed/tensor/_op_schema.py b/torch/distributed/tensor/_op_schema.py index 54d85aa1b3abe..c359f28eb3efc 100644 --- a/torch/distributed/tensor/_op_schema.py +++ b/torch/distributed/tensor/_op_schema.py @@ -317,15 +317,6 @@ def __str__(self) -> str: args_schema.append(str(arg)) return f"Op(op={self.op}, args_schema={', '.join(args_schema)} @ mesh: {mesh_shape})" - def __post_init__(self) -> None: - has_symints = False - for a in self.args_schema: - if isinstance(a, DTensorSpec) and a.tensor_meta is not None: - if any(isinstance(s, torch.SymInt) for s in a.tensor_meta.shape): - has_symints = True - break - self.has_symints = has_symints - def arg_type_tensor_or_tensor_list_like(self, arg_idx: int) -> bool: arg = self.args_schema[arg_idx] is_tensor = isinstance(arg, DTensorSpec) diff --git a/torch/distributed/tensor/_sharding_prop.py b/torch/distributed/tensor/_sharding_prop.py index b05c84de01887..32ab4943b5f05 100644 --- a/torch/distributed/tensor/_sharding_prop.py +++ b/torch/distributed/tensor/_sharding_prop.py @@ -8,6 +8,7 @@ import torch from torch._ops import OpOverload from torch._subclasses import FakeTensorMode +from torch.distributed._functional_collectives import _are_we_tracing from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta from torch.distributed.tensor._op_schema import ( OpInfo, @@ -300,8 +301,9 @@ def propagate(self, op_info: OpInfo) -> None: # We cannot use an lru cache if we know that inputs will have dynamic shapes, # because SymInts are not hashable. # This is generally ok because this only happens during tracing in torch.compile, - # and tracing does not need to be as fast as eagermode DTensor usages. - if op_info.schema.has_symints: + # and compile autograd initial tracing, which do not need to be as fast as + # eagermode DTensor usages. + if _are_we_tracing(): output_sharding = self.propagate_op_sharding_non_cached(op_info.schema) else: output_sharding = cast( From 1179e333237b02ed8fe2ba10cb9a23adf98d7d7a Mon Sep 17 00:00:00 2001 From: "Yu, Guangye" Date: Fri, 11 Jul 2025 00:15:04 +0000 Subject: [PATCH 156/457] Add DeviceAllocator as the base device allocator (#138222) # Motivation In line with [RFC] [A device-agnostic Python device memory related API design for stream-based accelerators](https://github.com/pytorch/pytorch/issues/134978), some memory-related APIs are widely used in popular repositories, such as HuggingFace [so many if-else conditional code](https://github.com/search?q=repo%3Ahuggingface%2Faccelerate%20torch.cuda.empty_cache&type=code). We would like to introduce a generic API set under torch.accelerator namespace to generalize these user cases.
Device-specific memory APIs torch.xxx.foo Device-agnostic memory APIs torch.accelerator.foo
```python torch.xxx.empty_cache ``` ```python torch.accelerator.empty_cache ```
```python torch.xxx.reset_peak_memory_stats ``` ```python torch.accelerator.reset_peak_memory_stats ```
```python torch.xxx.reset_accumulated_memory_stats ``` ```python torch.accelerator.reset_accumulated_memory_stats ```
```python torch.xxx.memory_stats ``` ```python torch.accelerator.memory_stats ```
```python torch.xxx.memory_allocated ``` ```python torch.accelerator.memory_allocated ```
```python torch.xxx.max_memory_allocated ``` ```python torch.accelerator.max_memory_allocated ```
```python torch.xxx.memory_reserved ``` ```python torch.accelerator.memory_reserved ```
```python torch.xxx.max_memory_reserved ``` ```python torch.accelerator.max_memory_reserved ```
# Solution This design follows a similar pattern to `HostAllocator`. We're introducing a base class `DeviceAllocator`, from which `CUDAAllocator` and `XPUAllocator` will inherit. This allows us to provide a unified call path like: `torch.accelerator.empty_cache()` -> `GetDeviceAllocator(allocator)->empty_cache()`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/138222 Approved by: https://github.com/albanD, https://github.com/Camyll --- aten/src/ATen/cuda/CUDAGraph.cpp | 1 - aten/src/ATen/cuda/CUDAGraph.h | 1 + .../hip/impl/HIPAllocatorMasqueradingAsCUDA.h | 26 +++++++-- .../HIPCachingAllocatorMasqueradingAsCUDA.cpp | 7 ++- c10/core/CachingDeviceAllocator.cpp | 10 ++++ c10/core/CachingDeviceAllocator.h | 53 +++++++++++++++++++ c10/cuda/CUDACachingAllocator.cpp | 1 + c10/cuda/CUDACachingAllocator.h | 19 ++++--- c10/cuda/CUDAGraphsC10Utils.h | 6 --- c10/xpu/XPUCachingAllocator.cpp | 19 ++++--- 10 files changed, 116 insertions(+), 27 deletions(-) create mode 100644 c10/core/CachingDeviceAllocator.cpp diff --git a/aten/src/ATen/cuda/CUDAGraph.cpp b/aten/src/ATen/cuda/CUDAGraph.cpp index 7fba7c4c7424c..2800e505a9b76 100644 --- a/aten/src/ATen/cuda/CUDAGraph.cpp +++ b/aten/src/ATen/cuda/CUDAGraph.cpp @@ -2,7 +2,6 @@ #include #include #include -#include #include #include diff --git a/aten/src/ATen/cuda/CUDAGraph.h b/aten/src/ATen/cuda/CUDAGraph.h index c8cae16b624fe..4f2aa31dd1c35 100644 --- a/aten/src/ATen/cuda/CUDAGraph.h +++ b/aten/src/ATen/cuda/CUDAGraph.h @@ -2,6 +2,7 @@ #include #include +#include #include #include #include diff --git a/aten/src/ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h b/aten/src/ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h index 39ab441478e8f..c1ecea34db16f 100644 --- a/aten/src/ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h +++ b/aten/src/ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include // Use of c10::hip namespace here makes hipification easier, because @@ -10,10 +10,10 @@ namespace c10::hip { // Takes a valid HIPAllocator (of any sort) and turns it into // an allocator pretending to be a CUDA allocator. See // Note [Masquerading as CUDA] -class HIPAllocatorMasqueradingAsCUDA final : public Allocator { - Allocator* allocator_; +class HIPAllocatorMasqueradingAsCUDA final : public DeviceAllocator { + DeviceAllocator* allocator_; public: - explicit HIPAllocatorMasqueradingAsCUDA(Allocator* allocator) + explicit HIPAllocatorMasqueradingAsCUDA(DeviceAllocator* allocator) : allocator_(allocator) {} DataPtr allocate(size_t size) override { DataPtr r = allocator_->allocate(size); @@ -26,6 +26,24 @@ class HIPAllocatorMasqueradingAsCUDA final : public Allocator { void copy_data(void* dest, const void* src, std::size_t count) const final { allocator_->copy_data(dest, src, count); } + bool initialized() override { + return allocator_->initialized(); + } + void emptyCache(MempoolId_t mempool_id = {0, 0}) { + allocator_->emptyCache(mempool_id); + } + void recordStream(const DataPtr& ptr, c10::Stream stream) { + allocator_->recordStream(ptr, stream); + } + CachingDeviceAllocator::DeviceStats getDeviceStats(c10::DeviceIndex device) { + return allocator_->getDeviceStats(device); + } + void resetAccumulatedStats(c10::DeviceIndex device) { + allocator_->resetAccumulatedStats(device); + } + void resetPeakStats(c10::DeviceIndex device) { + allocator_->resetPeakStats(device); + } }; } // namespace c10::hip diff --git a/aten/src/ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.cpp b/aten/src/ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.cpp index 46f7d247293a1..19bc0a6b34e54 100644 --- a/aten/src/ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.cpp +++ b/aten/src/ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.cpp @@ -4,8 +4,9 @@ namespace c10 { namespace hip { namespace HIPCachingAllocatorMasqueradingAsCUDA { +static HIPAllocatorMasqueradingAsCUDA allocator(HIPCachingAllocator::get()); + Allocator* get() { - static HIPAllocatorMasqueradingAsCUDA allocator(HIPCachingAllocator::get()); return &allocator; } @@ -13,5 +14,9 @@ void recordStreamMasqueradingAsCUDA(const DataPtr& ptr, HIPStreamMasqueradingAsC HIPCachingAllocator::recordStream(ptr, stream.hip_stream()); } +// Register this HIP allocator as CUDA allocator to enable access through both +// c10::GetAllocator(kCUDA) and c10::getDeviceAllocator(kCUDA) APIs +REGISTER_ALLOCATOR(kCUDA, &allocator) + } // namespace HIPCachingAllocatorMasqueradingAsCUDA }} // namespace c10::hip diff --git a/c10/core/CachingDeviceAllocator.cpp b/c10/core/CachingDeviceAllocator.cpp new file mode 100644 index 0000000000000..582efd59cf1b1 --- /dev/null +++ b/c10/core/CachingDeviceAllocator.cpp @@ -0,0 +1,10 @@ +#include + +namespace c10 { + +// Ensures proper DLL export of this pure virtual base class on Windows, +// since it's mainly used in other DLLs outside c10.dll. +DeviceAllocator::DeviceAllocator() = default; +DeviceAllocator::~DeviceAllocator() = default; + +} // namespace c10 diff --git a/c10/core/CachingDeviceAllocator.h b/c10/core/CachingDeviceAllocator.h index b23490de693a8..0bec03ae417fa 100644 --- a/c10/core/CachingDeviceAllocator.h +++ b/c10/core/CachingDeviceAllocator.h @@ -1,6 +1,7 @@ #pragma once #include +#include namespace c10::CachingDeviceAllocator { @@ -59,3 +60,55 @@ struct DeviceStats { }; } // namespace c10::CachingDeviceAllocator + +namespace c10 { + +using CaptureId_t = unsigned long long; + +// first is set if the instance is created by Graph mode capture_begin. +// second is set if the instance is created by Graph mode graph_pool_handle. +using MempoolId_t = std::pair; + +struct C10_API DeviceAllocator : public c10::Allocator { + DeviceAllocator(); + ~DeviceAllocator() override; + + // Returns true if the allocator has been properly initialized and is ready + // for use + virtual bool initialized() = 0; + + // Releases all cached device memory from the specified memory pool back to + // the system + virtual void emptyCache(MempoolId_t mempool_id = {0, 0}) = 0; + + // Associates a memory allocation with a stream to establish dependency + // tracking. Prevents memory reuse until all operations on the specified + // stream complete + virtual void recordStream(const DataPtr& ptr, c10::Stream stream) = 0; + + // Retrieves comprehensive memory statistics for the specified device, + // including allocation patterns, usage metrics + virtual CachingDeviceAllocator::DeviceStats getDeviceStats( + c10::DeviceIndex device) = 0; + + // Resets cumulative allocation statistics for the specified device to zero + virtual void resetAccumulatedStats(c10::DeviceIndex device) = 0; + + // Resets peak memory usage statistics for the specified device + virtual void resetPeakStats(c10::DeviceIndex device) = 0; +}; + +// This function is used to get the DeviceAllocator for a specific device type +// and keep backward compatibility with c10::GetAllocator. +C10_API inline DeviceAllocator* getDeviceAllocator(const DeviceType& t) { + TORCH_CHECK( + t != DeviceType::CPU, + "getDeviceAllocator is not supported for CPU device type."); + auto* allocator = c10::GetAllocator(t); + auto* device_allocator = dynamic_cast(allocator); + TORCH_INTERNAL_ASSERT( + device_allocator, "Allocator for ", t, " is not a DeviceAllocator."); + return device_allocator; +} + +} // namespace c10 diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp index 4d58c11c5c9bc..91ea6d9d9bd4d 100644 --- a/c10/cuda/CUDACachingAllocator.cpp +++ b/c10/cuda/CUDACachingAllocator.cpp @@ -4179,6 +4179,7 @@ struct BackendStaticInitializer { BackendStaticInitializer() { auto r = parseEnvForBackend(); + at::SetAllocator(kCUDA, r, 0); allocator.store(r); } }; diff --git a/c10/cuda/CUDACachingAllocator.h b/c10/cuda/CUDACachingAllocator.h index a6fa61110d675..5e412342b17d0 100644 --- a/c10/cuda/CUDACachingAllocator.h +++ b/c10/cuda/CUDACachingAllocator.h @@ -202,25 +202,24 @@ struct ShareableHandle { std::string handle; }; -class CUDAAllocator : public Allocator { +class CUDAAllocator : public DeviceAllocator { public: virtual void* raw_alloc(size_t nbytes) = 0; virtual void* raw_alloc_with_stream(size_t nbytes, cudaStream_t stream) = 0; virtual void raw_delete(void* ptr) = 0; virtual void init(int device_count) = 0; - virtual bool initialized() = 0; virtual double getMemoryFraction(c10::DeviceIndex device) = 0; virtual void setMemoryFraction(double fraction, c10::DeviceIndex device) = 0; - virtual void emptyCache(MempoolId_t mempool_id = {0, 0}) = 0; virtual void enable(bool value) = 0; virtual bool isEnabled() const = 0; virtual void cacheInfo(c10::DeviceIndex device, size_t* largestBlock) = 0; virtual void* getBaseAllocation(void* ptr, size_t* size) = 0; - virtual void recordStream(const DataPtr&, CUDAStream stream) = 0; - virtual c10::CachingDeviceAllocator::DeviceStats getDeviceStats( - c10::DeviceIndex device) = 0; - virtual void resetAccumulatedStats(c10::DeviceIndex device) = 0; - virtual void resetPeakStats(c10::DeviceIndex device) = 0; + // Keep for BC only + virtual void recordStream(const DataPtr& ptr, CUDAStream stream) = 0; + void recordStream(const DataPtr& ptr, c10::Stream stream) override { + CUDAStream cuda_stream = CUDAStream(stream); + recordStream(ptr, cuda_stream); + } virtual SnapshotInfo snapshot(MempoolId_t mempool_id = {0, 0}) = 0; virtual void beginAllocateToPool( c10::DeviceIndex device, @@ -525,6 +524,10 @@ inline void enablePeerAccess( namespace c10::cuda { +// Keep BC only +using c10::CaptureId_t; +using c10::MempoolId_t; + // MemPool represents a pool of memory in a caching allocator. Currently, // it's just the ID of the pool object maintained in the CUDACachingAllocator. // diff --git a/c10/cuda/CUDAGraphsC10Utils.h b/c10/cuda/CUDAGraphsC10Utils.h index eb29ca8bc9f02..936875fd71d5c 100644 --- a/c10/cuda/CUDAGraphsC10Utils.h +++ b/c10/cuda/CUDAGraphsC10Utils.h @@ -9,12 +9,6 @@ namespace c10::cuda { -using CaptureId_t = unsigned long long; - -// first is set if the instance is created by CUDAGraph::capture_begin. -// second is set if the instance is created by at::cuda::graph_pool_handle. -using MempoolId_t = std::pair; - // RAII guard for "cudaStreamCaptureMode", a thread-local value // that controls the error-checking strictness of a capture. struct C10_CUDA_API CUDAStreamCaptureModeGuard { diff --git a/c10/xpu/XPUCachingAllocator.cpp b/c10/xpu/XPUCachingAllocator.cpp index 543b48f081135..a5e088515ff55 100644 --- a/c10/xpu/XPUCachingAllocator.cpp +++ b/c10/xpu/XPUCachingAllocator.cpp @@ -540,7 +540,7 @@ class DeviceCachingAllocator { static void local_raw_delete(void* ptr); -class XPUAllocator : public Allocator { +class XPUAllocator : public DeviceAllocator { private: std::mutex mutex; ska::flat_hash_map allocated_blocks; @@ -576,6 +576,10 @@ class XPUAllocator : public Allocator { } } + bool initialized() override { + return !device_allocators.empty(); + } + void malloc( void** devPtr, DeviceIndex device, @@ -610,13 +614,13 @@ class XPUAllocator : public Allocator { } } - void emptyCache() { + void emptyCache(MempoolId_t mempool_id [[maybe_unused]] = {0, 0}) override { for (auto& da : device_allocators) { da->emptyCache(); } } - void recordStream(const DataPtr& ptr, XPUStream stream) { + void recordStream(const DataPtr& ptr, c10::Stream stream) override { if (!ptr.get()) { return; } @@ -626,7 +630,8 @@ class XPUAllocator : public Allocator { Block* block = get_allocated_block(ptr.get()); TORCH_CHECK(block, "No allocated block can be found."); - device_allocators[block->device]->recordStream(block, stream); + c10::xpu::XPUStream xpu_stream{stream}; + device_allocators[block->device]->recordStream(block, xpu_stream); } DataPtr allocate(size_t size) override { @@ -679,17 +684,17 @@ class XPUAllocator : public Allocator { ": did you call init?"); } - DeviceStats getDeviceStats(DeviceIndex device) { + DeviceStats getDeviceStats(DeviceIndex device) override { assertValidDevice(device); return device_allocators[device]->getStats(); } - void resetPeakStats(DeviceIndex device) { + void resetPeakStats(DeviceIndex device) override { assertValidDevice(device); device_allocators[device]->resetPeakStats(); } - void resetAccumulatedStats(DeviceIndex device) { + void resetAccumulatedStats(DeviceIndex device) override { assertValidDevice(device); device_allocators[device]->resetAccumulatedStats(); } From 2ad5c25cfc603c3656e6699d6137419dbb009495 Mon Sep 17 00:00:00 2001 From: "Yu, Guangye" Date: Fri, 11 Jul 2025 00:15:05 +0000 Subject: [PATCH 157/457] Add unified memory APIs for torch.accelerator (#152932) # Motivation The following API will be put under torch.accelerator - empty_cache - max_memory_allocated - max_memory_reserved - memory_allocated - memory_reserved - memory_stats - reset_accumulated_memory_stats - reset_peak_memory_stats Pull Request resolved: https://github.com/pytorch/pytorch/pull/152932 Approved by: https://github.com/albanD ghstack dependencies: #138222 --- aten/src/ATen/DeviceAccelerator.h | 22 ++++ docs/source/accelerator.md | 23 ++++ torch/_C/__init__.pyi.in | 5 + torch/accelerator/__init__.py | 18 +++ torch/accelerator/memory.py | 201 ++++++++++++++++++++++++++++++ torch/csrc/DeviceAccelerator.cpp | 64 ++++++++++ torch/cuda/memory.py | 4 +- 7 files changed, 335 insertions(+), 2 deletions(-) create mode 100644 torch/accelerator/memory.py diff --git a/aten/src/ATen/DeviceAccelerator.h b/aten/src/ATen/DeviceAccelerator.h index f37e492c861fe..f23b35047fcc8 100644 --- a/aten/src/ATen/DeviceAccelerator.h +++ b/aten/src/ATen/DeviceAccelerator.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include @@ -72,6 +73,27 @@ TORCH_API c10::DeviceIndex exchangeDevice(c10::DeviceIndex device_index); // original device index that was active before the change. TORCH_API c10::DeviceIndex maybeExchangeDevice(c10::DeviceIndex device_index); +TORCH_API inline void emptyCache() { + const auto device_type = getAccelerator(true).value(); + at::getDeviceAllocator(device_type)->emptyCache(); +} + +TORCH_API inline at::CachingDeviceAllocator::DeviceStats getDeviceStats( + c10::DeviceIndex device_index) { + const auto device_type = getAccelerator(true).value(); + return at::getDeviceAllocator(device_type)->getDeviceStats(device_index); +} + +TORCH_API inline void resetAccumulatedStats(c10::DeviceIndex device_index) { + const auto device_type = getAccelerator(true).value(); + at::getDeviceAllocator(device_type)->resetAccumulatedStats(device_index); +} + +TORCH_API inline void resetPeakStats(c10::DeviceIndex device_index) { + const auto device_type = getAccelerator(true).value(); + at::getDeviceAllocator(device_type)->resetPeakStats(device_index); +} + } // namespace at::accelerator namespace at { diff --git a/docs/source/accelerator.md b/docs/source/accelerator.md index c6f2fb1080400..ce593a9acf518 100644 --- a/docs/source/accelerator.md +++ b/docs/source/accelerator.md @@ -25,3 +25,26 @@ synchronize device_index ``` + +```{eval-rst} +.. automodule:: torch.accelerator.memory +``` +```{eval-rst} +.. currentmodule:: torch.accelerator.memory +``` + +## Memory management +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + empty_cache + max_memory_allocated + max_memory_reserved + memory_allocated + memory_reserved + memory_stats + reset_accumulated_memory_stats + reset_peak_memory_stats +``` diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index a5c4d390ee36d..1a785ef8f237a 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -2420,6 +2420,11 @@ def _accelerator_getStream(device_index: _int) -> Stream: ... def _accelerator_synchronizeDevice(device_index: _int) -> None: ... def _accelerator_exchangeDevice(device_index: _int) -> _int: ... def _accelerator_maybeExchangeDevice(device_index: _int) -> _int: ... +def _accelerator_isAllocatorInitialized() -> _bool: ... +def _accelerator_emptyCache() -> None: ... +def _accelerator_getDeviceStats(device_index: _int) -> dict[str, Any]: ... +def _accelerator_resetAccumulatedStats(device_index: _int) -> None: ... +def _accelerator_resetPeakStats(device_index: _int) -> None: ... # Defined in torch/csrc/jit/python/python_tracer.cpp class TracingState: diff --git a/torch/accelerator/__init__.py b/torch/accelerator/__init__.py index e9e48f1cf3061..4d1a78df1f74c 100644 --- a/torch/accelerator/__init__.py +++ b/torch/accelerator/__init__.py @@ -8,6 +8,16 @@ import torch from ._utils import _device_t, _get_device_index +from .memory import ( + empty_cache, + max_memory_allocated, + max_memory_reserved, + memory_allocated, + memory_reserved, + memory_stats, + reset_accumulated_memory_stats, + reset_peak_memory_stats, +) __all__ = [ @@ -15,9 +25,17 @@ "current_device_idx", # deprecated "current_device_index", "current_stream", + "empty_cache", "device_count", "device_index", "is_available", + "max_memory_allocated", + "max_memory_reserved", + "memory_allocated", + "memory_reserved", + "memory_stats", + "reset_accumulated_memory_stats", + "reset_peak_memory_stats", "set_device_idx", # deprecated "set_device_index", "set_stream", diff --git a/torch/accelerator/memory.py b/torch/accelerator/memory.py new file mode 100644 index 0000000000000..d34a11a3a02e5 --- /dev/null +++ b/torch/accelerator/memory.py @@ -0,0 +1,201 @@ +from collections import OrderedDict +from typing import Any + +import torch + +from ._utils import _device_t, _get_device_index + + +__all__ = [ + "empty_cache", + "max_memory_allocated", + "max_memory_reserved", + "memory_allocated", + "memory_reserved", + "memory_stats", + "reset_accumulated_memory_stats", + "reset_peak_memory_stats", +] + + +def empty_cache() -> None: + r"""Release all unoccupied cached memory currently held by the caching + allocator so that those can be used in other application. + + .. note:: This function is a no-op if the memory allocator for the current + :ref:`accelerator ` has not been initialized. + """ + if not torch._C._accelerator_isAllocatorInitialized(): + return + torch._C._accelerator_emptyCache() + + +def memory_stats(device_index: _device_t = None, /) -> OrderedDict[str, Any]: + r"""Return a dictionary of accelerator device memory allocator statistics for a given device index. + + The return value of this function is a dictionary of statistics, each of + which is a non-negative integer. + + Core statistics: + + - ``"allocated.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + number of allocation requests received by the memory allocator. + - ``"allocated_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + amount of allocated memory. + - ``"segment.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + number of reserved segments from device memory allocation. + - ``"reserved_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + amount of reserved memory. + - ``"active.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + number of active memory blocks. + - ``"active_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + amount of active memory. + - ``"inactive_split.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + number of inactive, non-releasable memory blocks. + - ``"inactive_split_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + amount of inactive, non-releasable memory. + + For these core statistics, values are broken down as follows. + + Pool type: + + - ``all``: combined statistics across all memory pools. + - ``large_pool``: statistics for the large allocation pool + (as of June 2025, for size >= 1MB allocations). + - ``small_pool``: statistics for the small allocation pool + (as of June 2025, for size < 1MB allocations). + + Metric type: + + - ``current``: current value of this metric. + - ``peak``: maximum value of this metric. + - ``allocated``: historical total increase in this metric. + - ``freed``: historical total decrease in this metric. + + In addition to the core statistics, we also provide some simple event + counters: + + - ``"num_alloc_retries"``: number of failed device memory allocation calls that + result in a cache flush and retry. + - ``"num_ooms"``: number of out-of-memory errors thrown. + - ``"num_sync_all_streams"``: number of ``synchronize_and_free_events`` calls. + - ``"num_device_alloc"``: number of device memory allocation calls. + - ``"num_device_free"``: number of device memory free calls. + + Args: + device_index (:class:`torch.device`, str, int, optional): the index of the device to target. + If not given, use :func:`torch.accelerator.current_device_index` by default. + If a :class:`torch.device` or str is provided, its type must match the current + :ref:`accelerator` device type. + """ + if not torch._C._accelerator_isAllocatorInitialized(): + return OrderedDict() + device_index = _get_device_index(device_index, optional=True) + stats = torch._C._accelerator_getDeviceStats(device_index) + flat_stats = [] + + def flatten(prefix: str, value: Any) -> None: + if isinstance(value, dict): + for k, v in value.items(): + nested_prefix = f"{prefix}.{k}" if prefix else k + flatten(nested_prefix, v) + else: + flat_stats.append((prefix, value)) + + flatten("", stats) + flat_stats.sort() + return OrderedDict(flat_stats) + + +def memory_allocated(device_index: _device_t = None, /) -> int: + r"""Return the current :ref:`accelerator` device memory occupied by tensors + in bytes for a given device index. + + Args: + device_index (:class:`torch.device`, str, int, optional): the index of the device to target. + If not given, use :func:`torch.accelerator.current_device_index` by default. + If a :class:`torch.device` or str is provided, its type must match the current + :ref:`accelerator` device type. + """ + return memory_stats(device_index).get("allocated_bytes.all.current", 0) + + +def max_memory_allocated(device_index: _device_t = None, /) -> int: + r"""Return the current :ref:`accelerator` maximum device memory occupied by tensors + in bytes for a given device index. + + By default, this returns the peak allocated memory since the beginning of + this program. :func:`~torch.accelerator.reset_peak_memory_stats` can be used to + reset the starting point in tracking this metric. + + Args: + device_index (:class:`torch.device`, str, int, optional): the index of the device to target. + If not given, use :func:`torch.accelerator.current_device_index` by default. + If a :class:`torch.device` or str is provided, its type must match the current + :ref:`accelerator` device type. + """ + return memory_stats(device_index).get("allocated_bytes.all.peak", 0) + + +def memory_reserved(device_index: _device_t = None, /) -> int: + r"""Return the current :ref:`accelerator` device memory managed by the caching allocator + in bytes for a given device index. + + Args: + device_index (:class:`torch.device`, str, int, optional): the index of the device to target. + If not given, use :func:`torch.accelerator.current_device_index` by default. + If a :class:`torch.device` or str is provided, its type must match the current + :ref:`accelerator` device type. + """ + return memory_stats(device_index).get("reserved_bytes.all.current", 0) + + +def max_memory_reserved(device_index: _device_t = None, /) -> int: + r"""Return the current :ref:`accelerator` maximum device memory managed by the caching allocator + in bytes for a given device index. + + By default, this returns the peak cached memory since the beginning of this + program. :func:`~torch.accelerator.reset_peak_memory_stats` can be used to reset + the starting point in tracking this metric. + + Args: + device_index (:class:`torch.device`, str, int, optional): the index of the device to target. + If not given, use :func:`torch.accelerator.current_device_index` by default. + If a :class:`torch.device` or str is provided, its type must match the current + :ref:`accelerator` device type. + """ + return memory_stats(device_index).get("reserved_bytes.all.peak", 0) + + +def reset_accumulated_memory_stats(device_index: _device_t = None, /) -> None: + r"""Reset the "accumulated" (historical) stats tracked by the current :ref:`accelerator` + memory allocator for a given device index. + + Args: + device_index (:class:`torch.device`, str, int, optional): the index of the device to target. + If not given, use :func:`torch.accelerator.current_device_index` by default. + If a :class:`torch.device` or str is provided, its type must match the current + :ref:`accelerator` device type. + + .. note:: This function is a no-op if the memory allocator for the current + :ref:`accelerator ` has not been initialized. + """ + device_index = _get_device_index(device_index, optional=True) + return torch._C._accelerator_resetAccumulatedStats(device_index) + + +def reset_peak_memory_stats(device_index: _device_t = None, /) -> None: + r"""Reset the "peak" stats tracked by the current :ref:`accelerator` + memory allocator for a given device index. + + Args: + device_index (:class:`torch.device`, str, int, optional): the index of the device to target. + If not given, use :func:`torch.accelerator.current_device_index` by default. + If a :class:`torch.device` or str is provided, its type must match the current + :ref:`accelerator` device type. + + .. note:: This function is a no-op if the memory allocator for the current + :ref:`accelerator ` has not been initialized. + """ + device_index = _get_device_index(device_index, optional=True) + return torch._C._accelerator_resetPeakStats(device_index) diff --git a/torch/csrc/DeviceAccelerator.cpp b/torch/csrc/DeviceAccelerator.cpp index 37fac325d3167..dc3da8881a715 100644 --- a/torch/csrc/DeviceAccelerator.cpp +++ b/torch/csrc/DeviceAccelerator.cpp @@ -72,6 +72,70 @@ void initModule(PyObject* module) { torch::utils::maybe_initialize_device(device_type); return at::accelerator::maybeExchangeDevice(device_index); }); + + m.def("_accelerator_isAllocatorInitialized", []() { + const auto device_type = at::accelerator::getAccelerator(true).value(); + return at::getDeviceAllocator(device_type)->initialized(); + }); + + m.def("_accelerator_emptyCache", []() { at::accelerator::emptyCache(); }); + + m.def("_accelerator_getDeviceStats", [](c10::DeviceIndex device_index) { + using c10::CachingAllocator::Stat; + using c10::CachingAllocator::StatArray; + using c10::CachingAllocator::StatType; + using c10::CachingDeviceAllocator::DeviceStats; + + const auto stats = at::accelerator::getDeviceStats(device_index); + const auto stat_to_dict = [](const Stat& stat) -> py::dict { + py::dict dict; + dict["current"] = stat.current; + dict["peak"] = stat.peak; + dict["allocated"] = stat.allocated; + dict["freed"] = stat.freed; + return dict; + }; + + const auto stat_array_to_dict = [=](const StatArray& stats) -> py::dict { + const std::array(StatType::NUM_TYPES)> + kStatTypeNames = {"all", "small_pool", "large_pool"}; + py::dict dict; + for (const auto i : c10::irange(kStatTypeNames.size())) { + dict[kStatTypeNames[i]] = stat_to_dict(stats[i]); + } + return dict; + }; + + py::dict result; + result["num_alloc_retries"] = stats.num_alloc_retries; + result["num_ooms"] = stats.num_ooms; + result["max_split_size"] = stats.max_split_size; + result["num_sync_all_streams"] = stats.num_sync_all_streams; + result["num_device_alloc"] = stats.num_device_alloc; + result["num_device_free"] = stats.num_device_free; + result["allocated_bytes"] = stat_array_to_dict(stats.allocated_bytes); + result["reserved_bytes"] = stat_array_to_dict(stats.reserved_bytes); + result["active_bytes"] = stat_array_to_dict(stats.active_bytes); + result["requested_bytes"] = stat_array_to_dict(stats.requested_bytes); + result["allocation"] = stat_array_to_dict(stats.allocation); + result["segment"] = stat_array_to_dict(stats.segment); + result["active"] = stat_array_to_dict(stats.active); + result["inactive_split"] = stat_array_to_dict(stats.inactive_split); + result["inactive_split_bytes"] = + stat_array_to_dict(stats.inactive_split_bytes); + result["oversize_allocations"] = stat_to_dict(stats.oversize_allocations); + result["oversize_segments"] = stat_to_dict(stats.oversize_segments); + return result; + }); + + m.def( + "_accelerator_resetAccumulatedStats", [](c10::DeviceIndex device_index) { + at::accelerator::resetAccumulatedStats(device_index); + }); + + m.def("_accelerator_resetPeakStats", [](c10::DeviceIndex device_index) { + at::accelerator::resetPeakStats(device_index); + }); } } // namespace torch::accelerator diff --git a/torch/cuda/memory.py b/torch/cuda/memory.py index 3a2e1bb0f8909..08bfced67532b 100644 --- a/torch/cuda/memory.py +++ b/torch/cuda/memory.py @@ -255,9 +255,9 @@ def memory_stats(device: "Device" = None) -> dict[str, Any]: - ``all``: combined statistics across all memory pools. - ``large_pool``: statistics for the large allocation pool - (as of October 2019, for size >= 1MB allocations). + (as of June 2025, for size >= 1MB allocations). - ``small_pool``: statistics for the small allocation pool - (as of October 2019, for size < 1MB allocations). + (as of June 2025, for size < 1MB allocations). Metric type: From 1839e8d04b81ee6eda0cff6fbfc218a7a600f6f7 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Fri, 11 Jul 2025 12:30:47 -0700 Subject: [PATCH 158/457] [DTensor] Assert DTensorSpec has valid placements (#158133) This helped identify buggy sharding rules during debugging, why not check it in. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158133 Approved by: https://github.com/XilunWu, https://github.com/zpcore ghstack dependencies: #158132 --- test/distributed/tensor/test_dtensor_compile.py | 2 +- torch/distributed/tensor/_dtensor_spec.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/test/distributed/tensor/test_dtensor_compile.py b/test/distributed/tensor/test_dtensor_compile.py index 86f1e9d8fb479..5041a0d6de54d 100644 --- a/test/distributed/tensor/test_dtensor_compile.py +++ b/test/distributed/tensor/test_dtensor_compile.py @@ -343,7 +343,7 @@ def test_dtensor_constructor_w_graph_break(self): x = torch.randn(64, 32, requires_grad=True) spec = DTensorSpec( mesh, - (Replicate(), Shard(0)), + (Replicate(),), tensor_meta=TensorMeta( shape=torch.Size([128, 32]), stride=(32, 1), dtype=x.dtype ), diff --git a/torch/distributed/tensor/_dtensor_spec.py b/torch/distributed/tensor/_dtensor_spec.py index 360f1a0ea0168..48739db536a9b 100644 --- a/torch/distributed/tensor/_dtensor_spec.py +++ b/torch/distributed/tensor/_dtensor_spec.py @@ -32,6 +32,10 @@ class DTensorSpec: def __post_init__(self) -> None: if not isinstance(self.placements, tuple): self.placements = tuple(self.placements) + if not len(self.placements) == self.mesh.ndim: + raise ValueError( + f"DTensorSpec requires one placement per mesh dim (mesh.ndim={self.mesh.ndim}), got {self.placements=}" + ) self._hash: Optional[int] = None def __setattr__(self, attr: str, value: Any) -> None: From 24b49b98810bb77f3cfa4c15baa9a15c9be3db61 Mon Sep 17 00:00:00 2001 From: Raymond Li Date: Thu, 17 Jul 2025 03:36:47 +0000 Subject: [PATCH 159/457] =?UTF-8?q?[Fix]=20Rework=20CUDA=20error=20explana?= =?UTF-8?q?tion=20framework=20to=20be=20less=20destructive=20=E2=80=A6=20(?= =?UTF-8?q?#158484)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …in fbsource Fix-forward for #158395 Added `std::string c10::cuda::get_cuda_error_help(const char* error_string)` to provide a framework for appending clarifying messages to CUDA errors. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158484 Approved by: https://github.com/aorenste --- c10/cuda/CUDAException.cpp | 3 ++- c10/cuda/CUDAMiscFunctions.cpp | 27 +++++++++++++++------------ c10/cuda/CUDAMiscFunctions.h | 3 ++- 3 files changed, 19 insertions(+), 14 deletions(-) diff --git a/c10/cuda/CUDAException.cpp b/c10/cuda/CUDAException.cpp index 40cacff550976..5eb54b2454539 100644 --- a/c10/cuda/CUDAException.cpp +++ b/c10/cuda/CUDAException.cpp @@ -30,7 +30,8 @@ void c10_cuda_check_implementation( check_message.append("CUDA error: "); const char* error_string = cudaGetErrorString(cuda_error); check_message.append(error_string); - check_message.append(c10::cuda::get_cuda_check_suffix(error_string)); + check_message.append(c10::cuda::get_cuda_error_help(error_string)); + check_message.append(c10::cuda::get_cuda_check_suffix()); check_message.append("\n"); if (include_device_assertions) { check_message.append(c10_retrieve_device_side_assertion_info()); diff --git a/c10/cuda/CUDAMiscFunctions.cpp b/c10/cuda/CUDAMiscFunctions.cpp index 5de9996d2eb76..170d53398195f 100644 --- a/c10/cuda/CUDAMiscFunctions.cpp +++ b/c10/cuda/CUDAMiscFunctions.cpp @@ -1,31 +1,34 @@ #include #include #include -#include #include namespace c10::cuda { +// Explain common CUDA errors // NOLINTNEXTLINE(bugprone-exception-escape,-warnings-as-errors) -std::string get_cuda_check_suffix(const char* error_string) noexcept { - std::string suffix; - - // Explain common CUDA errors +std::string get_cuda_error_help(const char* error_string) noexcept { + std::string help_text; if (strstr(error_string, "invalid device ordinal")) { - suffix.append("\nGPU device may be out of range, do you have enough GPUs?"); + help_text.append( + "\nGPU device may be out of range, do you have enough GPUs?"); } + return help_text; +} +// NOLINTNEXTLINE(bugprone-exception-escape,-warnings-as-errors) +const char* get_cuda_check_suffix() noexcept { static auto device_blocking_flag = c10::utils::check_env("CUDA_LAUNCH_BLOCKING"); static bool blocking_enabled = (device_blocking_flag.has_value() && device_blocking_flag.value()); - if (!blocking_enabled) { - suffix.append( - "\nCUDA kernel errors might be asynchronously reported at some" - " other API call, so the stacktrace below might be incorrect." - "\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1"); + if (blocking_enabled) { + return ""; + } else { + return "\nCUDA kernel errors might be asynchronously reported at some" + " other API call, so the stacktrace below might be incorrect." + "\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1"; } - return suffix; } std::mutex* getFreeMutex() { static std::mutex cuda_free_mutex; diff --git a/c10/cuda/CUDAMiscFunctions.h b/c10/cuda/CUDAMiscFunctions.h index c79a22bea231d..26a15d85a61e2 100644 --- a/c10/cuda/CUDAMiscFunctions.h +++ b/c10/cuda/CUDAMiscFunctions.h @@ -8,6 +8,7 @@ #include namespace c10::cuda { -C10_CUDA_API std::string get_cuda_check_suffix(const char*) noexcept; +C10_CUDA_API std::string get_cuda_error_help(const char*) noexcept; +C10_CUDA_API const char* get_cuda_check_suffix() noexcept; C10_CUDA_API std::mutex* getFreeMutex(); } // namespace c10::cuda From ebf83b8b7772632c0558db9a88281ee10ff2df38 Mon Sep 17 00:00:00 2001 From: PyTorch UpdateBot Date: Thu, 17 Jul 2025 04:19:02 +0000 Subject: [PATCH 160/457] [audio hash update] update the pinned audio hash (#158402) This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml). Update the pinned audio hash. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158402 Approved by: https://github.com/pytorchbot --- .github/ci_commit_pins/audio.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ci_commit_pins/audio.txt b/.github/ci_commit_pins/audio.txt index 7a71e6f2a5e43..a2d5ddd38cec7 100644 --- a/.github/ci_commit_pins/audio.txt +++ b/.github/ci_commit_pins/audio.txt @@ -1 +1 @@ -6c57850358f34c47802db216b0746e4e9d08a95a +00b0c91db92c51a11356249262577b9fa26c18c5 From 8eaa9f2701277f328d9d6aea1bfe7cba20792f7c Mon Sep 17 00:00:00 2001 From: Natalia Gimelshein Date: Thu, 17 Jul 2025 04:21:43 +0000 Subject: [PATCH 161/457] Fix mask construction when dispatching index_put to masked_fill (#158472) Fixes #158413 Previously trailing Nones in the index were incorrectly handled as implicit broadcasting dims in the mask, whereas they should just be ignored. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158472 Approved by: https://github.com/ezyang --- aten/src/ATen/native/TensorAdvancedIndexingUtils.h | 4 +++- test/test_indexing.py | 3 +++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/native/TensorAdvancedIndexingUtils.h b/aten/src/ATen/native/TensorAdvancedIndexingUtils.h index 0a200f157d511..bc6c2533eac5c 100644 --- a/aten/src/ATen/native/TensorAdvancedIndexingUtils.h +++ b/aten/src/ATen/native/TensorAdvancedIndexingUtils.h @@ -35,7 +35,9 @@ inline std::tuple canDispatchToMaskedFill( auto self_device = self.device(); for (const std::optional& i : indices) { if (!i.has_value() || !(*i).defined()) { - num_ind++; + if (!mask.defined()) { + num_ind++; + } } else { const Tensor& index = *i; if ((index.scalar_type() != kByte && index.scalar_type() != kBool) || diff --git a/test/test_indexing.py b/test/test_indexing.py index a57f658025a3b..58854a995db6f 100644 --- a/test/test_indexing.py +++ b/test/test_indexing.py @@ -964,6 +964,9 @@ def test_multi_dimensional_bool_mask_assignment(self, device): mask = torch.tensor([[1, 0], [0, 1]], dtype=torch.bool, device=device) v[:, mask, :] = 0 self.assertEqual(v, torch.tensor([[[[0], [2]], [[3], [0]]]], device=device)) + v = torch.tensor([[[[1], [2]], [[3], [4]]]], device=device) + torch.ops.aten.index_put_(v, [None, mask, None], torch.tensor(0)) + self.assertEqual(v, torch.tensor([[[[0], [2]], [[3], [0]]]], device=device)) def test_byte_mask(self, device): v = torch.randn(5, 7, 3, device=device) From 415dfabe9b569b71098a2f874f3fc67ad2a4fc2e Mon Sep 17 00:00:00 2001 From: FFFrog Date: Wed, 16 Jul 2025 22:24:33 +0800 Subject: [PATCH 162/457] [Easy] Fix the format (#158450) When I modify the code located in test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg, some unrelated format error occurred. ```Python Lint for torch/_inductor/fx_passes/fuse_attention.py: Error (CODESPELL) spelling error Failed due to ValueError: /pytorch/pytorch/torch/_inductor/fx_passes/fuse_attention.py:587: differnt ==> different Please either fix the error or add the word(s) to the dictionary file. HINT: all-lowercase words in the dictionary can cover all case variations. Lint for torch/fx/traceback.py: Error (MYPY) [assignment] Incompatible types in assignment (expression has type "str", variable has type "None") 101 | 102 | def _get_action_string(self): 103 | if self._action_string is None: 104 | self._action_string = "+".join([a.name.lower() for a in self.action]) 105 | return self._action_string 106 | 107 | def print_readable(self, indent=0): Error (MYPY) [assignment] Incompatible types in assignment (expression has type "dict[str, Any]", variable has type "None") 121 | if self._dict is None: 122 | # Convert the object to a dictionary 123 | action_string = self._get_action_string() 124 | self._dict = { 125 | "name": self.name, 126 | "target": self.target, 127 | "graph_id": self.graph_id, Error (MYPY) [return-value] Incompatible return value type (got "None", expected "dict[Any, Any]") 130 | "from_node": [node.to_dict() for node in self.from_node], 131 | } 132 | 133 | return self._dict 134 | 135 | def __eq__(self, other: object): 136 | if not isinstance(other, NodeSource): ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/158450 Approved by: https://github.com/Skylion007 --- torch/fx/traceback.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torch/fx/traceback.py b/torch/fx/traceback.py index 836b41d661859..59187fedccfaa 100644 --- a/torch/fx/traceback.py +++ b/torch/fx/traceback.py @@ -83,8 +83,8 @@ def __init__( self.from_node = [] # cache the action string and dict representation for performance. - self._action_string = None - self._dict = None + self._action_string: Optional[str] = None + self._dict: Optional[dict[str, Any]] = None @property def name(self) -> str: @@ -132,6 +132,7 @@ def to_dict(self) -> dict: "from_node": [node.to_dict() for node in self.from_node], } + assert self._dict is not None return self._dict def __eq__(self, other: object): From 79d7c754ab8ae0e5c3a614521632d2cfbfa0fdba Mon Sep 17 00:00:00 2001 From: Laith Sakka Date: Thu, 10 Jul 2025 20:28:30 -0700 Subject: [PATCH 163/457] DDE-Free select with unbacked index. (#157605) When select has data dependent input, we cant tell if the actual index shall be index+size or index. to avoid throwing dde, we allocate a new unbacked symbol to represent the storage offset of the output view and we compute its value dynamically at runtime when inductor is lowered. Pull Request resolved: https://github.com/pytorch/pytorch/pull/157605 Approved by: https://github.com/ColinPeppler --- test/export/test_export.py | 22 ++++++ test/test_dynamic_shapes.py | 82 +++++++++++++++++++++ torch/_export/passes/_node_metadata_hook.py | 1 + torch/_inductor/codegen/cpp_wrapper_cpu.py | 14 ++++ torch/_inductor/codegen/wrapper.py | 8 ++ torch/_inductor/dependencies.py | 35 +++++++++ torch/_inductor/graph.py | 4 +- torch/_inductor/ir.py | 61 +++++++++++++-- torch/_inductor/lowering.py | 69 ++++++++++++++--- torch/_inductor/scheduler.py | 3 +- torch/_inductor/utils.py | 16 +++- torch/_meta_registrations.py | 33 --------- torch/_subclasses/fake_impls.py | 42 +++++++++++ torch/fx/experimental/symbolic_shapes.py | 2 + torch/fx/passes/runtime_assert.py | 10 +++ 15 files changed, 349 insertions(+), 53 deletions(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index b772667de105e..1c0279d565268 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -15782,6 +15782,28 @@ def forward(self, x, mask): ignore_empty_lines=True, ) + def test_unbacked_select_index(self): + class MyModel(torch.nn.Module): + def forward(self, x, y): + u0 = y.item() + return x.select(0, u0) + + example_inputs = ( + torch.randn((3, 3), dtype=torch.bfloat16), + torch.tensor([0]), + ) + + traced = export(MyModel(), example_inputs).run_decompositions({}) + self.assertExpectedInline( + traced.graph_module.code, + """\ +def forward(self, x, y): + item = torch.ops.aten.item.default(y); y = None + select = torch.ops.aten.select.int(x, 0, item); x = item = None + return (select,)""", + ignore_empty_lines=True, + ) + if __name__ == "__main__": run_tests() diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index 0f299cd6b6c79..59c08f71671e0 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -3529,6 +3529,88 @@ def func(x): ignore_empty_lines=True, ) + @fresh_cache() + @torch._dynamo.config.patch("capture_scalar_outputs", True) + def test_unbacked_select_index(self): + cnt = CompileCounterWithBackend("inductor") + + def func(x, y): + u0 = y.item() + return ( + torch.select(x, 0, u0), + torch.select(x, 1, u0), + torch.select(x, 2, u0), + ) + + compiled_func = torch.compile(fullgraph=True, backend=cnt, dynamic=True)(func) + x = torch.rand(3, 3, 3) + zero = torch.tensor([0]) + pos = torch.tensor([1]) + # code can handle both negative and positive indices. + neg = torch.tensor([-1]) + + log_stream, ctx = logs_to_string( + "torch._inductor.compile_fx", "post_grad_graphs" + ) + with ctx(): + self.assertEqual(compiled_func(x, zero), func(x, zero)) + output = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip() + self.assertExpectedInline( + output, + """\ + _local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(arg0_1); arg0_1 = None + select: "f32[s77, s77][s77, 1]cpu" = torch.ops.aten.select.int(arg2_1, 0, _local_scalar_dense) + select_1: "f32[s77, s77][s77**2, 1]cpu" = torch.ops.aten.select.int(arg2_1, 1, _local_scalar_dense) + select_2: "f32[s77, s77][s77**2, s77]cpu" = torch.ops.aten.select.int(arg2_1, 2, _local_scalar_dense); arg2_1 = _local_scalar_dense = None + return (select, select_1, select_2)""", # noqa: B950 + ignore_comments=True, + ignore_empty_lines=True, + ) + self.assertEqual(compiled_func(x, pos), func(x, pos)) + self.assertEqual(compiled_func(x, neg), func(x, neg)) + self.assertEqual(cnt.frame_count, 1) + + def func2(x, y): + u0, u1 = y.tolist() + return torch.select(x, 0, u0 + u1) + + compiled_func2 = torch.compile(fullgraph=True, backend=cnt, dynamic=False)( + func2 + ) + zero = torch.tensor([0, 0]) + pos = torch.tensor([1, 1]) + neg = torch.tensor([-1, -1]) + + self.assertEqual(compiled_func2(x, pos), func2(x, pos)) + self.assertEqual(compiled_func2(x, neg), func2(x, neg)) + self.assertEqual(compiled_func2(x, zero), func2(x, zero)) + self.assertEqual(cnt.frame_count, 2) + + @torch._dynamo.config.patch("capture_scalar_outputs", True) + def test_unbacked_select_index_with_check(self): + def func3(x, y): + u0 = y.item() + # Test that taking the non-unbacked path works fine also. + torch._check(u0 >= 0) + return (torch.select(x, 1, u0),) + + compiled_func3 = torch.compile( + fullgraph=True, backend="inductor", dynamic=True + )(func3) + x = torch.rand(3, 3, 3) + zero = torch.tensor([0]) + pos = torch.tensor([1]) + print(compiled_func3(x, pos)) + + self.assertEqual(compiled_func3(x, pos), func3(x, pos)) + self.assertEqual(compiled_func3(x, zero), func3(x, zero)) + + @fresh_cache() + @torch._dynamo.config.patch("capture_scalar_outputs", True) + @torch._inductor.config.patch("cpp_wrapper", True) + def test_unbacked_select_index_cpp_wrapper(self): + self.test_unbacked_select_index() + instantiate_parametrized_tests(TestUnbacked) diff --git a/torch/_export/passes/_node_metadata_hook.py b/torch/_export/passes/_node_metadata_hook.py index 41005e5009738..b1195cf421288 100644 --- a/torch/_export/passes/_node_metadata_hook.py +++ b/torch/_export/passes/_node_metadata_hook.py @@ -54,6 +54,7 @@ def _node_metadata_hook(node: torch.fx.Node, stack_trace: Optional[str] = None) ) }, ) + node.meta["torch_fn"] = ( f"{node.target.__name__}_0", f"{node.target.__class__.__name__}.{node.target.__name__}", diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index cbca6d9fe5d28..8a7f1b2aaa028 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -1447,6 +1447,20 @@ def codegen_dynamic_scalar(self, node): # record in unbacked_symbol_decls so we won't generate a declaration of the symbol again self.unbacked_symbol_decls.add(str(node.sym)) + def codegen_dynamic_select_index(self, node): + index_cpp_str = self.val_to_arg_str_for_prim_type(node.index, int) + + index_compute_str = ( + f"{index_cpp_str} < 0 ? {index_cpp_str} + " + f"{self.val_to_arg_str_for_prim_type(node.size, int)}: {index_cpp_str}" + ) + self.writeline( + f"auto {node.unbacked_offset_symbol} = {self.val_to_arg_str_for_prim_type(node.base_offset, int)} + " + f"{self.val_to_arg_str_for_prim_type(node.base_dim_stride, int)} * ({index_compute_str});" + ) + # record in unbacked_symbol_decls so we won't generate a declaration of the symbol again + self.unbacked_symbol_decls.add(str(node.unbacked_offset_symbol)) + def make_buffer_free(self, buffer): return ( "" diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 0b8ba86c3c185..e601cbb8ed894 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -1802,6 +1802,14 @@ def codegen_multi_output(self, node: ir.MultiOutput): arg_name = node.input_name(0) self.writeline(MultiOutputLine(self, result_name, arg_name, node.indices)) + def codegen_dynamic_select_index(self, node): + index_str = f"{node.index} + {node.size} if {node.index} < 0 else {node.index}" + self.writeline( + f"{node.unbacked_offset_symbol} = {node.base_offset} + {node.base_dim_stride} * ({index_str})" + ) + # record in unbacked_symbol_decls so we won't generate a declaration of the symbol again + self.unbacked_symbol_decls.add(str(node.unbacked_offset_symbol)) + def codegen_dynamic_scalar(self, node): (data,) = (t.codegen_reference() for t in node.inputs) if len(node.keypath) == 0: diff --git a/torch/_inductor/dependencies.py b/torch/_inductor/dependencies.py index 9de52061c6489..f948a7a534c8f 100644 --- a/torch/_inductor/dependencies.py +++ b/torch/_inductor/dependencies.py @@ -11,6 +11,7 @@ import sympy import torch +from torch._inductor.utils import get_free_symbols from torch.fx.experimental.symbolic_shapes import free_symbols, free_unbacked_symbols from torch.utils._ordered_set import OrderedSet @@ -38,6 +39,12 @@ class Dep(abc.ABC): name: str index: sympy.Expr + @abc.abstractmethod + def get_free_symbol_uses( + self, unbacked_only: bool = False + ) -> OrderedSet[sympy.Symbol]: + pass + @abc.abstractmethod def rename(self, renames: dict[str, str]) -> Self: pass @@ -70,6 +77,15 @@ class MemoryDep(Dep): size: tuple[sympy.Expr, ...] mode: Optional[str] = None + def get_free_symbol_uses( + self, unbacked_only: bool = False + ) -> OrderedSet[sympy.Symbol]: + return ( + get_free_symbols(self.index, unbacked_only) + | get_free_symbols(self.size, unbacked_only) + | get_free_symbols(self.var_names, unbacked_only) + ) + def __repr__(self) -> str: maybe_mode = "" if self.mode is not None: @@ -307,6 +323,11 @@ def rename(self, renames: dict[str, str]) -> "StarDep": return StarDep(renames[self.name], self.mode) return self + def get_free_symbol_uses( + self, unbacked_only: bool = False + ) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + def numbytes_hint(self) -> int: try: return V.graph.sizevars.size_hint(self.get_numel()) * get_dtype_size( @@ -343,6 +364,11 @@ class WeakDep(Dep): # Buffer that is doing the mutation mutating_buf: str + def get_free_symbol_uses( + self, unbacked_only: bool = False + ) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + @property def index(self) -> sympy.Expr: raise NotImplementedError("WeakDep does not have an index") @@ -440,6 +466,15 @@ def buffer_names(self, ignore_integer_index: bool = True) -> OrderedSet[str]: names.add(dep.name) return names + def get_free_symbol_uses( + self, unbacked_only: bool = False + ) -> OrderedSet[sympy.Symbol]: + result: OrderedSet[sympy.Symbol] = OrderedSet() + + for dep in self.reads_and_writes(): + result |= dep.get_free_symbol_uses(unbacked_only) + return result + class _RecordLoadStoreInner(V.MockHandler): # type: ignore[name-defined] def __init__(self, var_ranges: VarRanges, normalize: bool) -> None: diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index ac299d5b0c2d0..660b01b69233b 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -341,6 +341,7 @@ def __init__( shape_env.deferred_runtime_asserts.copy() ) self.bound_unbacked_symbols = OrderedSet[sympy.Symbol]() + self.sizevars = SizeVarAllocator(shape_env) self.graph_input_names: list[str] = [] self.graph_inputs: dict[str, Union[TensorBox, TorchBindObject, sympy.Expr]] = {} @@ -1821,7 +1822,7 @@ def debug(msg: str) -> None: shape_env = V.graph.sizevars.shape_env - # An input can an unbacked symint i.e.: when mark_unabcked is used. + # An input can be unbacked symint i.e.: when mark_unabcked is used. # in that case add it to new_unbacked_defs. if ( n.op == "placeholder" @@ -1888,6 +1889,7 @@ def format_new_defs() -> str: V.fake_mode.shape_env.unbacked_renamings.get(s, s) for s in unbacked_bindings.keys() ) + assert new_unbacked_defs >= renamed_unbacked_bindings, ( f"failed {new_unbacked_defs} >= {renamed_unbacked_bindings} (inductor >= fx)\n" f"fx node is: {n.format_node()}\n" diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index d6dd82aa52f2d..25f57a503dfaa 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -49,6 +49,7 @@ from torch._export.serde.serialize import GraphModuleSerializer from torch._higher_order_ops.auto_functionalize import can_auto_functionalize from torch._inductor import metrics +from torch._inductor.utils import get_free_symbols from torch._prims_common import ( compute_required_storage_length, is_boolean_dtype, @@ -62,7 +63,6 @@ compute_unbacked_bindings, free_symbols, free_unbacked_symbols, - IterateExprs, rebind_unbacked, resolve_unbacked_bindings, ShapeEnv, @@ -304,13 +304,6 @@ def reindex(index: Sequence[_T]) -> Sequence[_V]: return reindex -def get_free_symbols(x: IterateExprs, unbacked_only: bool) -> OrderedSet[sympy.Symbol]: - if unbacked_only: - return free_unbacked_symbols(x) - else: - return free_symbols(x) - - NHWC_STRIDE_ORDER = [3, 0, 2, 1] NHWDC_STRIDE_ORDER = [4, 0, 3, 2, 1] @@ -4329,6 +4322,13 @@ def get_read_names(self) -> OrderedSet[str]: return self.data.get_read_names() def get_read_writes(self) -> dependencies.ReadWrites: + if not isinstance(self.data, (Reduction, Scan, Sort, Pointwise)): + return dependencies.ReadWrites( + reads=OrderedSet(), + writes=OrderedSet(), + index_exprs=OrderedSet(), + ) + with patch.object(FlexibleLayout, "allow_indexing", True): if self.data.get_reduction_type(): return extract_read_writes( @@ -4367,6 +4367,7 @@ def get_free_symbol_uses( | get_free_symbols(self.get_stride(), unbacked_only) | get_free_symbols(self.get_offset(), unbacked_only) | self.data.get_free_symbol_uses(unbacked_only) + | self.get_read_writes().get_free_symbol_uses(unbacked_only) ) def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]: @@ -6975,6 +6976,50 @@ def codegen(self, wrapper: PythonWrapperCodegen) -> None: wrapper.codegen_device_copy(args[0], self.codegen_reference(), args[1]) +class DynamicSelectStorageOffset(ExternKernel): + """ + The result of computing a dynamic selection index is determined as follows: when the index in the + select operation is unbacked, the actual index calculation is ambiguous for negative indices + (index + size) versus non-negative indices (just index). To resolve this, we allocate an unbacked + SymInt to represent the storage offset and decompose the select operation into a call to as_strided, + computing the storage offset at runtime with this node. + """ + + def get_reads(self) -> OrderedSet[Dep]: + return OrderedSet() + + def should_allocate(self) -> bool: + return False + + def __init__( + self, + unbacked_offset_symbol: sympy.Symbol, + index: sympy.Symbol, + base_offset: Union[sympy.Symbol, int], + base_dim_stride: Union[sympy.Symbol, int], + size: Union[sympy.Symbol, int], + ) -> None: + super().__init__(None, NoneLayout(device=torch.device("cpu")), []) + # This node codegen the following: + # unbacked_offset_symbol = base_offset + base_dim_stride * (index if index >=0 else index + size) + self.unbacked_offset_symbol = unbacked_offset_symbol + self.index = index + self.base_offset = base_offset + self.base_dim_stride = base_dim_stride + self.size = size + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet([self.unbacked_offset_symbol]) + + def get_free_symbol_uses( + self, unbacked_only: bool = False + ) -> OrderedSet[sympy.Symbol]: + return get_free_symbols(self.index, unbacked_only) + + def codegen(self, wrapper: PythonWrapperCodegen) -> None: + wrapper.codegen_dynamic_select_index(self) + + class DynamicScalar(ExternKernel): """ The result of a call to aten._local_scalar_dense. diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index c4c8f70003c60..f6b08499e4d5c 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -40,7 +40,11 @@ Number, ) from torch.fx.experimental.sym_node import magic_methods, method_to_operator -from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols +from torch.fx.experimental.symbolic_shapes import ( + free_unbacked_symbols, + has_free_unbacked_symbols, + resolve_unbacked_bindings, +) from torch.utils._ordered_set import OrderedSet from torch.utils._sympy.functions import CeilDiv, FloorDiv, Identity, ModularIndexing @@ -990,10 +994,7 @@ def squeeze(x, dim=None): new_shape = [] for d, s in enumerate(x.get_size()): - if not ( - d in dims - and V.graph.sizevars.evaluate_expr(sympy.Eq(s, 1), size_oblivious=True) - ): + if not (d in dims and V.graph.sizevars.guard_or_false(sympy.Eq(s, 1))): new_shape.append(s) # squeeze does nothing if the size isn't 1 @@ -1759,8 +1760,60 @@ def diagonal_scatter(input, src, offset: int = 0, dim1: int = 0, dim2: int = 1): @register_lowering(aten.select, type_promotion_kind=None) def select(x, dim, idx): - idx = View.handle_negative_index(idx, x.get_size()[dim]) - return squeeze(slice_(x, dim, idx, idx + 1), dim) + idx = sympy.expand(idx) + size = sympy.expand(x.get_size()[dim]) + actual_index = None + + if V.graph.sizevars.guard_or_false(sympy.Lt(idx, 0)): + actual_index = idx + size + elif V.graph.sizevars.guard_or_false(sympy.Ge(idx, 0)): + actual_index = idx + + if actual_index is not None: + if has_free_unbacked_symbols(idx): + # Inductor could generate incorrect views for tensors with unbacked symbols here; + # Squeeze operations are translated to views, resulting in incorrect strides. + # Additionally, we want to avoid accidental unbacked unsqueeze semantics. To resolve this, + # we use as_strided instead. + # Removing this branch will cause test_unbacked_select_index_with_check to fail. + new_size = x.get_size() + new_stride = x.get_stride() + new_storage_offset = x.get_layout().offset + new_stride[dim] * actual_index + + del new_size[dim] + del new_stride[dim] + return as_strided(x, new_size, new_stride, new_storage_offset) + else: + slice_result = slice_(x, dim, actual_index, actual_index + 1) + return squeeze(slice_result, dim) + + # Unbacked Semantics: + # When the index idx is unbacked (e.g., u0), we compute the index dynamically + # during the lowering of the select operation using DynamicSelectStorageOffset. + + unbacked_bindings = resolve_unbacked_bindings( + V.graph.sizevars.shape_env, V.graph.current_node.meta["unbacked_bindings"] + ) + assert unbacked_bindings is not None + assert len(unbacked_bindings) == 1, unbacked_bindings + unbacked_offset_sym, _ = next(iter(unbacked_bindings.items())) + + new_size = x.get_size() + new_stride = x.get_stride() + new_storage_offset = unbacked_offset_sym + buffer = ir.DynamicSelectStorageOffset( + unbacked_offset_sym, + idx, + x.get_layout().offset, + new_stride[dim], + x.get_size()[dim], + ) + buffer.name = V.graph.register_buffer(buffer) + V.graph.register_operation(buffer) + + del new_size[dim] + del new_stride[dim] + return as_strided(x, new_size, new_stride, new_storage_offset) @register_lowering(aten.split, type_promotion_kind=None) @@ -3086,8 +3139,6 @@ def long_tensor(data): @register_lowering(aten._local_scalar_dense) def _local_scalar_dense(data): - from torch.fx.experimental.symbolic_shapes import resolve_unbacked_bindings - # This is interesting! Most lowerings return tensors, so you can just # return the buffer you allocated and it will get used (or not used, if # it's dead.) But _local_scalar_dense (aka item) returns an int, diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 34f15869085f0..a4507990400fd 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -2130,9 +2130,11 @@ def _init(self, nodes: list[ir.Operation]) -> None: self.logged_slow_fusion = OrderedSet[tuple[str, str]]() if config._pre_fusion_custom_pass is not None: self.nodes = config._pre_fusion_custom_pass(self.nodes) + self.nodes = self.fuse_nodes(self.nodes) if config._post_fusion_custom_pass is not None: self.nodes = config._post_fusion_custom_pass(self.nodes) + self.merge_loops() self.finalize_multi_template_buffers() if config.combo_kernels: @@ -2366,7 +2368,6 @@ def add_user( for node in self.nodes: log.debug("scheduling %s", node.node) - # unbacked symbols don't follow ordinary buffer dependencies, so # we track their def/uses separately assert node.node is not None diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 5f9ce0b814eba..7b3f495382f76 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -69,13 +69,20 @@ "inductor_autotune_lookup_table", ] +from torch.fx.experimental.symbolic_shapes import ( + free_symbols, + free_unbacked_symbols, + IterateExprs, + ShapeEnv, +) + + if TYPE_CHECKING: from collections.abc import Iterable, Sequence, ValuesView from torch import SymBool, SymFloat, SymInt from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND from torch.fx import GraphModule - from torch.fx.experimental.symbolic_shapes import ShapeEnv from torch.fx.node import Node from .codegen.common import WorkspaceArg @@ -3359,3 +3366,10 @@ def aoti_model_name_from_config() -> str: model_name = config.aot_inductor.model_name_for_generated_files model_name = "aoti_model" if model_name is None else model_name return model_name + + +def get_free_symbols(x: IterateExprs, unbacked_only: bool) -> OrderedSet[sympy.Symbol]: + if unbacked_only: + return free_unbacked_symbols(x) + else: + return free_symbols(x) diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index ae87e0e17fb37..2933a37c37fd8 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -5553,39 +5553,6 @@ def meta_zeros( ) -@register_meta(aten.select.int) -def meta_select(self, dim, index): - from torch.fx.experimental.symbolic_shapes import guard_size_oblivious - - ndim = self.dim() - torch._check_index( - ndim != 0, - lambda: "select() cannot be applied to a 0-dim tensor.", - ) - - dim = dim if dim >= 0 else dim + ndim - size = self.size(dim) - - torch._check_index( - not ( - guard_size_oblivious(-index > size) or guard_size_oblivious(index >= size) - ), - lambda: f"select(): index {index} out of range for tensor of size " - f"{self.size()} at dimension {dim}", - ) - - index = index if index >= 0 else index + size - - new_size = list(self.size()) - new_stride = list(self.stride()) - - new_storage_offset = self.storage_offset() + index * new_stride[dim] - del new_size[dim] - del new_stride[dim] - - return self.as_strided(new_size, new_stride, new_storage_offset) - - @register_meta(aten.select_scatter.default) def meta_select_scatter(self, src, dim, index): return utils.clone_preserve_strides(self) diff --git a/torch/_subclasses/fake_impls.py b/torch/_subclasses/fake_impls.py index e802d9a4389d4..e2e24cb59bc27 100644 --- a/torch/_subclasses/fake_impls.py +++ b/torch/_subclasses/fake_impls.py @@ -359,6 +359,48 @@ def unique2( return _unique(fake_mode, func, arg, None, sorted, return_inverse, return_counts) +@register_op_impl(aten.select.int) +def meta_select(fake_mode, func, self, dim, index): + from torch.fx.experimental.symbolic_shapes import guard_or_false + + if self.is_sparse: + return NotImplemented + + ndim = self.dim() + torch._check_index( + ndim != 0, + lambda: "select() cannot be applied to a 0-dim tensor.", + ) + + dim = dim if dim >= 0 else dim + ndim + size = self.size(dim) + + new_size = list(self.size()) + new_stride = list(self.stride()) + + new_storage_offset = None + if guard_or_false(index >= 0): + new_storage_offset = self.storage_offset() + index * new_stride[dim] + elif guard_or_false(index < 0): + new_storage_offset = self.storage_offset() + (index + size) * new_stride[dim] + + if new_storage_offset is None: + if fake_mode.shape_env is None or ( + not fake_mode.shape_env.allow_scalar_outputs + and not fake_mode.allow_scalar_outputs + ): + raise DataDependentOutputException(func) + + # index is data-dependent, we do not know which index we are accessing it could be index or index+size! + # we assign a new data-dependent symbol for the storage offset. + new_storage_offset = fake_mode.shape_env.create_unbacked_symint() + + del new_size[dim] + del new_stride[dim] + assert new_storage_offset is not None + return self.as_strided(new_size, new_stride, new_storage_offset) + + @register_op_impl(aten.unique_dim.default) def unique_dim( fake_mode, func, arg, dim, sorted=True, return_inverse=False, return_counts=False diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index e38e5f777d669..4814e2daefe33 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -1282,6 +1282,7 @@ def compute_unbacked_bindings( return None fs = shape_env.pending_fresh_unbacked_symbols + pending = set(fs) if not pending: return None @@ -4809,6 +4810,7 @@ def create_unbacked_symfloat(self) -> SymFloat: ) self.counter["create_unbacked_symbol"] += 1 if not self._ignore_fresh_unbacked_symbols_tls(): + print(f"adding {symbol}") self.pending_fresh_unbacked_symbols.append(symbol) self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1) vr = self.var_to_range[symbol] = ValueRanges.unknown() diff --git a/torch/fx/passes/runtime_assert.py b/torch/fx/passes/runtime_assert.py index 38c64c527aff0..bb71a25971da7 100644 --- a/torch/fx/passes/runtime_assert.py +++ b/torch/fx/passes/runtime_assert.py @@ -461,6 +461,7 @@ def go(node, keypath): ), keypath[2:], ) + return go( graph.call_method( keypath[0].name, (node, keypath[1].idx) @@ -468,6 +469,15 @@ def go(node, keypath): keypath[2:], ) elif isinstance(keypath[0], CallMethodKey): + if keypath[0].name == "storage_offset": + return go( + graph.call_function( + torch.ops.aten.sym_storage_offset.default, + (node,), + ), + keypath[1:], + ) + return go( graph.call_method(keypath[0].name, (node,)), keypath[1:] ) From 1a4268b8113d5160d71225bab980f03c2318a0a4 Mon Sep 17 00:00:00 2001 From: PaliC Date: Wed, 16 Jul 2025 15:07:08 -0700 Subject: [PATCH 164/457] [BE] remove torch deploy - conditionals (#158288) This PR is part of the work to deprecate torch::deploy in OSS. Effectively it does 3 things to get started. 1. Remove test_deploy_interaction as we no longer need to worry about this 2. Remove all torch._running_with_deploy checks and use the False path always (surfaced 1) 3. Remove `USE_DEPLOY` and switch to the default path always Note: MyPy does fail on a bunch of things here as a bunch of older files are touched. It may be better to fix these things on a separate PR Pull Request resolved: https://github.com/pytorch/pytorch/pull/158288 Approved by: https://github.com/albanD --- test/test_custom_ops.py | 56 --- test/test_sparse_csr.py | 8 +- torch/__init__.py | 47 +-- .../_dynamo/_trace_wrapped_higher_order_op.py | 81 ++-- torch/_dynamo/trace_rules.py | 1 - torch/_inductor/test_operators.py | 43 +-- torch/_library/custom_ops.py | 4 - torch/_library/utils.py | 10 - torch/_ops.py | 3 - torch/_utils_internal.py | 12 +- torch/csrc/lazy/python/init.cpp | 4 - torch/csrc/utils/python_dispatch.cpp | 9 - torch/cuda/__init__.py | 3 - torch/distributed/_functional_collectives.py | 132 +++---- torch/distributed/_tools/fake_collectives.py | 7 +- .../fsdp/_fully_shard/_fsdp_common.py | 36 +- .../fsdp/_fully_shard/_fsdp_param.py | 3 +- torch/distributed/tensor/_collective_utils.py | 31 +- torch/library.py | 21 - torch/utils/__init__.py | 8 +- torch/utils/_import_utils.py | 8 +- torch/utils/collect_env.py | 362 ++++++++++-------- 22 files changed, 382 insertions(+), 507 deletions(-) diff --git a/test/test_custom_ops.py b/test/test_custom_ops.py index 1f3e670c15396..d4f5c2a7c0523 100644 --- a/test/test_custom_ops.py +++ b/test/test_custom_ops.py @@ -544,62 +544,6 @@ def test_assert_raises_regex(self, device): class TestCustomOp(CustomOpTestCaseBase): test_ns = "_test_custom_op" - def test_deploy_interaction(self): - # run in a different process to avoid parallel issues when we monkeypatch torch._running_with_deploy - script = """ -import torch -torch._running_with_deploy = lambda: True - -# creating the library is a no-op, so you can DEF multiple times -m1 = torch.library.Library("mylib4392", "DEF") # noqa: TOR901 -m2 = torch.library.Library("mylib4392", "DEF") # noqa: TOR901 - -m = torch.library.Library("aten", "FRAGMENT") # noqa: TOR901 - -# define is a no-op -m.define("foobarbaz9996(Tensor x) -> Tensor") -assert not hasattr(torch.ops.aten, "foobarbaz9996"), "m.define should have been a noop" - -def sin_override(x): - raise AssertionError("m.impl should have been a noop") - -# impl is a no-op -m.impl("sin", sin_override, "CompositeImplicitAutograd") -x = torch.randn(3) -y = torch.sin(x) - -# should be a no-op -@torch.library.custom_op("mylib::foobar", mutates_args={}) -def foobar(x: torch.Tensor) -> torch.Tensor: - return x.sin() - -# should be a no-op -@foobar.register_fake -def _(x): - return torch.empty_like(x) - -# should be a no-op -m2.define("foobarbaz9996(Tensor x) -> Tensor") - -# should be a no-op -@torch.library.register_fake("mylib4392::foobarbaz9996") -def _(x): - return torch.empty_like(x) - """ - script = script.strip() - env = os.environ.copy() - try: - subprocess.check_output( - [sys.executable, "-c", script], - stderr=subprocess.STDOUT, - # On Windows, opening the subprocess with the default CWD makes `import torch` - # fail, so just set CWD to this script's directory - cwd=os.path.dirname(os.path.realpath(__file__)), - env=env, - ) - except subprocess.CalledProcessError as e: - self.fail(msg=("Subprocess exception:\n" + e.output.decode("utf-8"))) - @requires_compile def test_functionalize_error(self): with torch.library._scoped_library(self.test_ns, "FRAGMENT") as lib: diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py index cc313c586a090..8fb490e1b5bc7 100644 --- a/test/test_sparse_csr.py +++ b/test/test_sparse_csr.py @@ -3603,8 +3603,8 @@ def test_triton_bsr_softmax(self, device, dtype): @onlyCUDA @dtypes(torch.half, torch.bfloat16, torch.float) @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float) - @unittest.skipIf((not TEST_WITH_TORCHINDUCTOR) or (IS_FBCODE and IS_REMOTE_GPU) or torch._running_with_deploy(), - "Skipped for deploy and internal with remote GPUs") + @unittest.skipIf((not TEST_WITH_TORCHINDUCTOR) or (IS_FBCODE and IS_REMOTE_GPU), + "Skipped for internal with remote GPUs") def test_triton_bsr_dense_bmm(self, device, dtype, index_dtype, block_size): from functools import partial from torch.sparse._triton_ops import bsr_dense_mm @@ -3680,8 +3680,8 @@ def kernel_impl(*args, **kwargs): @onlyCUDA @dtypes(torch.half) - @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU or torch._running_with_deploy(), - "Skipped for deploy and internal with remote GPUs") + @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, + "Skipped for internal with remote GPUs") def test_triton_bsr_dense_bmm_error_messages(self, device, dtype): from torch.sparse._triton_ops import bsr_dense_mm diff --git a/torch/__init__.py b/torch/__init__.py index 99cb83db84b81..f124d1a5a1d6c 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -34,20 +34,10 @@ ) from typing_extensions import ParamSpec as _ParamSpec, TypeIs as _TypeIs -from . import version - if TYPE_CHECKING: from .types import Device, IntLikeType - -# multipy/deploy is setting this import before importing torch, this is the most # codespell:ignore multipy -# reliable way we have to detect if we're running within deploy. -# https://github.com/pytorch/multipy/blob/d60f34ad38c371e441fe7ffdb77a3c3dda5a5d19/multipy/runtime/interpreter/interpreter_impl.cpp#L134-L137 # codespell:ignore multipy # noqa: B950 -def _running_with_deploy() -> builtins.bool: - return sys.modules.get("torch._meta_registrations", None) is object - - from torch._utils import ( _functionalize_sync as _sync, _import_dotted_name, @@ -60,14 +50,9 @@ def _running_with_deploy() -> builtins.bool: USE_GLOBAL_DEPS, USE_RTLD_GLOBAL_WITH_LIBTORCH, ) +from torch.torch_version import __version__ as __version__ -# TODO(torch_deploy) figure out how to freeze version.py in fbcode build -if _running_with_deploy(): - __version__ = "torch-deploy-1.8" -else: - from torch.torch_version import __version__ as __version__ - __all__ = [ "BoolStorage", "BoolTensor", @@ -317,7 +302,7 @@ def _preload_cuda_deps(lib_folder: str, lib_name: str) -> None: # See Note [Global dependencies] def _load_global_deps() -> None: - if _running_with_deploy() or platform.system() == "Windows": + if platform.system() == "Windows": return # Determine the file extension based on the platform @@ -381,7 +366,7 @@ def _load_global_deps() -> None: if (USE_RTLD_GLOBAL_WITH_LIBTORCH or os.getenv("TORCH_USE_RTLD_GLOBAL")) and ( - _running_with_deploy() or platform.system() != "Windows" + platform.system() != "Windows" ): # Do it the hard way. You might want to load libtorch with RTLD_GLOBAL in a # few circumstances: @@ -2082,7 +2067,7 @@ def _dtype(self): # Shared memory manager needs to know the exact location of manager executable def _manager_path(): - if _running_with_deploy() or platform.system() == "Windows": + if platform.system() == "Windows": return b"" path = get_file_path("torch", "bin", "torch_shm_manager") prepare_multiprocessing_environment(get_file_path("torch")) @@ -2687,21 +2672,21 @@ def _register_device_module(device_type, module): # Register MPS specific decomps torch.backends.mps._init() -if not _running_with_deploy(): - from torch import compiler as compiler +from torch import compiler as compiler + - class _TritonLibrary: - lib = torch.library.Library("triton", "DEF") - ops_table: dict[tuple[str, str], _Callable] = {} +class _TritonLibrary: + lib = torch.library.Library("triton", "DEF") + ops_table: dict[tuple[str, str], _Callable] = {} - @classmethod - def registerOp(cls, op_key, full_schema, op_impl, dispatch_key): - if (op_key, dispatch_key) not in cls.ops_table: - cls.lib.define(full_schema) - cls.lib.impl("triton::" + op_key, op_impl, dispatch_key) - cls.ops_table[(op_key, dispatch_key)] = op_impl + @classmethod + def registerOp(cls, op_key, full_schema, op_impl, dispatch_key): + if (op_key, dispatch_key) not in cls.ops_table: + cls.lib.define(full_schema) + cls.lib.impl("triton::" + op_key, op_impl, dispatch_key) + cls.ops_table[(op_key, dispatch_key)] = op_impl - return cls.ops_table[(op_key, dispatch_key)] + return cls.ops_table[(op_key, dispatch_key)] # Deprecated attributes diff --git a/torch/_dynamo/_trace_wrapped_higher_order_op.py b/torch/_dynamo/_trace_wrapped_higher_order_op.py index 8fab0b2005491..17b664fc5e0ed 100644 --- a/torch/_dynamo/_trace_wrapped_higher_order_op.py +++ b/torch/_dynamo/_trace_wrapped_higher_order_op.py @@ -49,47 +49,46 @@ __all__ = ["trace_wrapped"] -if not torch._running_with_deploy(): - # torch.library.custom_op does not work with torch.deploy/multipy # codespell:ignore - - @torch.library.custom_op("flex_lib::zeros_and_scatter", mutates_args=()) # type: ignore[misc] - def zeros_and_scatter( - shape: list[int], - indices: list[Tensor], - vals: Tensor, - ) -> Tensor: - """Custom Op so that we can register a custom lowering for the new_output + scatter in the backwards pass""" - grad = torch.zeros(shape, device=vals.device, dtype=vals.dtype) - return torch.ops.aten.index_put(grad, indices, vals, accumulate=True) - - @zeros_and_scatter.register_fake # type: ignore[misc] - def _( - shape: list[int], - indices: list[Tensor], - vals: Tensor, - ) -> Tensor: - return vals.new_empty(shape) - - @zeros_and_scatter.register_vmap # type: ignore[misc] - def _(info, indims, shape, indices, value): # type: ignore[no-untyped-def] - """The batching rule is special in that it returns a tensor that is not batched""" - indices_indims = indims[1] - expanded_indices = [] - for idx, idx_indim in zip(indices, indices_indims): - # The index is not a being batched, we should unsqueeze and expand to val - if idx_indim is None: - expanded_indices.append(idx.expand(value.shape)) - else: - # the index is being part of the vmap batch, it should be the same size as val - assert idx.shape == value.shape - expanded_indices.append(idx) - - out = torch.ops.flex_lib.zeros_and_scatter( - shape, - expanded_indices, - value, - ) - return out, None +@torch.library.custom_op("flex_lib::zeros_and_scatter", mutates_args=()) # type: ignore[misc] +def zeros_and_scatter( + shape: list[int], + indices: list[Tensor], + vals: Tensor, +) -> Tensor: + """Custom Op so that we can register a custom lowering for the new_output + scatter in the backwards pass""" + grad = torch.zeros(shape, device=vals.device, dtype=vals.dtype) + return torch.ops.aten.index_put(grad, indices, vals, accumulate=True) + + +@zeros_and_scatter.register_fake # type: ignore[misc] +def _( + shape: list[int], + indices: list[Tensor], + vals: Tensor, +) -> Tensor: + return vals.new_empty(shape) + + +@zeros_and_scatter.register_vmap # type: ignore[misc] +def _(info, indims, shape, indices, value): # type: ignore[no-untyped-def] + """The batching rule is special in that it returns a tensor that is not batched""" + indices_indims = indims[1] + expanded_indices = [] + for idx, idx_indim in zip(indices, indices_indims): + # The index is not a being batched, we should unsqueeze and expand to val + if idx_indim is None: + expanded_indices.append(idx.expand(value.shape)) + else: + # the index is being part of the vmap batch, it should be the same size as val + assert idx.shape == value.shape + expanded_indices.append(idx) + + out = torch.ops.flex_lib.zeros_and_scatter( + shape, + expanded_indices, + value, + ) + return out, None class ModIndex(torch.autograd.Function): diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 0883525f47282..8005b08d24465 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -2408,7 +2408,6 @@ "torch._lowrank.svd_lowrank", "torch._preload_cuda_deps", "torch._register_device_module", - "torch._running_with_deploy", "torch._utils._dummy_type", "torch._utils._flatten_dense_tensors", "torch._utils._unflatten_dense_tensors", diff --git a/torch/_inductor/test_operators.py b/torch/_inductor/test_operators.py index bf49f3f5d04a1..d3d2705f8c788 100644 --- a/torch/_inductor/test_operators.py +++ b/torch/_inductor/test_operators.py @@ -5,25 +5,24 @@ from torch.autograd import Function -if not torch._running_with_deploy(): - _test_lib_def = torch.library.Library("_inductor_test", "DEF") - _test_lib_def.define( - "realize(Tensor self) -> Tensor", tags=torch.Tag.pt2_compliant_tag - ) - - _test_lib_impl = torch.library.Library("_inductor_test", "IMPL") - for dispatch_key in ("CPU", "CUDA", "MPS", "Meta"): - _test_lib_impl.impl("realize", lambda x: x.clone(), dispatch_key) - - class Realize(Function): - @staticmethod - def forward(ctx: object, x: Tensor) -> Tensor: - return torch.ops._inductor_test.realize(x) - - @staticmethod - # types need to stay consistent with _SingleLevelFunction - def backward(ctx: Any, *grad_output: Any) -> Any: - return grad_output[0] - - def realize(x: Tensor) -> Tensor: - return Realize.apply(x) +_test_lib_def = torch.library.Library("_inductor_test", "DEF") +_test_lib_def.define("realize(Tensor self) -> Tensor", tags=torch.Tag.pt2_compliant_tag) + +_test_lib_impl = torch.library.Library("_inductor_test", "IMPL") +for dispatch_key in ("CPU", "CUDA", "MPS", "Meta"): + _test_lib_impl.impl("realize", lambda x: x.clone(), dispatch_key) + + +class Realize(Function): + @staticmethod + def forward(ctx: object, x: Tensor) -> Tensor: + return torch.ops._inductor_test.realize(x) + + @staticmethod + # types need to stay consistent with _SingleLevelFunction + def backward(ctx: Any, *grad_output: Any) -> Any: + return grad_output[0] + + +def realize(x: Tensor) -> Tensor: + return Realize.apply(x) diff --git a/torch/_library/custom_ops.py b/torch/_library/custom_ops.py index 547d305c47afd..1d8d0fc5377b1 100644 --- a/torch/_library/custom_ops.py +++ b/torch/_library/custom_ops.py @@ -595,10 +595,6 @@ def register_autograd( self._setup_context_fn = setup_context def _register_to_dispatcher(self, tags: Sequence[_C.Tag]) -> None: - if torch._running_with_deploy(): - utils.warn_deploy(stacklevel=5) - return - lib = self._lib schema_str = self._name + self._schema cpp_schema = _C.parse_schema(schema_str) diff --git a/torch/_library/utils.py b/torch/_library/utils.py index 17e128bdbe0f3..9403185204520 100644 --- a/torch/_library/utils.py +++ b/torch/_library/utils.py @@ -2,7 +2,6 @@ import dataclasses import inspect import sys -import warnings from collections.abc import Iterable, Iterator from typing import Any, Callable, Union @@ -12,15 +11,6 @@ from torch._ops import OpOverload -def warn_deploy(stacklevel=3): - warnings.warn( - "Python torch.library APIs do nothing under torch::deploy (multipy). " # codespell:ignore multipy - "Please instead use C++ custom operator registration APIs.", - RuntimeWarning, - stacklevel=stacklevel, - ) - - @dataclasses.dataclass class Kernel: """Models a (function, source location)""" diff --git a/torch/_ops.py b/torch/_ops.py index 9995aafb249a5..e51343cff972c 100644 --- a/torch/_ops.py +++ b/torch/_ops.py @@ -1478,9 +1478,6 @@ def load_library(self, path): Args: path (str): A path to a shared library to load. """ - if torch._running_with_deploy(): - return - path = _utils_internal.resolve_library_path(path) with dl_open_guard(): # Import the shared library into the process, thus running its diff --git a/torch/_utils_internal.py b/torch/_utils_internal.py index 1833b918e180e..e067a587497b1 100644 --- a/torch/_utils_internal.py +++ b/torch/_utils_internal.py @@ -33,16 +33,10 @@ # use is the FB build environment, where this source file is replaced # by an equivalent. -if torch._running_with_deploy(): - # __file__ is meaningless in the context of frozen torch used in torch deploy. - # setting empty torch_parent should allow below functions to operate without crashing, - # but it's unclear if there is a valid use case for them in the context of deploy. - torch_parent = "" +if os.path.basename(os.path.dirname(__file__)) == "shared": + torch_parent = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) else: - if os.path.basename(os.path.dirname(__file__)) == "shared": - torch_parent = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) - else: - torch_parent = os.path.dirname(os.path.dirname(__file__)) + torch_parent = os.path.dirname(os.path.dirname(__file__)) def get_file_path(*path_components: str) -> str: diff --git a/torch/csrc/lazy/python/init.cpp b/torch/csrc/lazy/python/init.cpp index f2b14cbfd7bb4..4807aa6a4c7d1 100644 --- a/torch/csrc/lazy/python/init.cpp +++ b/torch/csrc/lazy/python/init.cpp @@ -331,13 +331,9 @@ void initLazyBindings(PyObject* module) { // So far this problem has only been observed internally, so we will just // block it off there. -#if !(defined(USE_DEPLOY)) - // When libtorch_python is loaded, we register the python frame getter // otherwise, debug util simply omits python frames GetPythonFramesFunction() = GetPythonFrames; - -#endif // USE_DEPLOY } } // namespace torch::lazy diff --git a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp index 34fbfec49c919..b2b0e848a7e79 100644 --- a/torch/csrc/utils/python_dispatch.cpp +++ b/torch/csrc/utils/python_dispatch.cpp @@ -187,15 +187,6 @@ class PythonKernelHolder : public c10::OperatorKernel { auto arguments = torch::jit::pop(*stack, op.schema().arguments().size()); py::gil_scoped_acquire g; - // Jan 2024: We're slated to get rid of multipy, // codespell:ignore multipy - // so stop forcing hermetic mode unconditionally in all situations when - // you're using multipy. // codespell:ignore multipy - // Eventually just delete this entirely. (Note that you may break - // multipy anyway this way with dispatcher // codespell:ignore multipy - // registered functions that require hermetic to be off.) -#if defined(USE_DEPLOY) - EnableHermeticPyObject g2; -#endif auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments); auto func = py::reinterpret_borrow(func_.ptr(getPyInterpreter())); diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py index 4602aa1ee172e..6a8fc7dfb12ef 100644 --- a/torch/cuda/__init__.py +++ b/torch/cuda/__init__.py @@ -1693,9 +1693,6 @@ def __call__(self, *args, **kwargs): def _register_triton_kernels(): - if torch._running_with_deploy(): - return - @_WrappedTritonKernel def kernel_impl(*args, **kwargs): from torch.sparse._triton_ops import bsr_dense_mm diff --git a/torch/distributed/_functional_collectives.py b/torch/distributed/_functional_collectives.py index 0ffae8a9c9fe3..73cdcf4217895 100644 --- a/torch/distributed/_functional_collectives.py +++ b/torch/distributed/_functional_collectives.py @@ -19,22 +19,16 @@ from torch.utils._pytree import tree_map_only # type: ignore[no-redef] -if torch._running_with_deploy(): +try: + from torch.compiler import is_dynamo_compiling as is_torchdynamo_compiling +except Exception: + warnings.warn( + "Unable to import torchdynamo util `is_torchdynamo_compiling`, so won't support torchdynamo correctly" + ) - def is_torchdynamo_compiling(): - """Can't import torchdynamo in torchdeploy builds currently.""" + def is_torchdynamo_compiling(): # type: ignore[misc] + return False return False - -else: - try: - from torch.compiler import is_dynamo_compiling as is_torchdynamo_compiling - except Exception: - warnings.warn( - "Unable to import torchdynamo util `is_torchdynamo_compiling`, so won't support torchdynamo correctly" - ) - - def is_torchdynamo_compiling(): - return False """ @@ -987,66 +981,58 @@ def _reduce_scatter_tensor_coalesced_native_meta( ] -if not torch._running_with_deploy(): - # Library MUST be defined at module scope or it doesn't work - # Creating a "DEF" Library always crashes torch::deploy so we create our - # Library instances here guarded against running inside it - lib_impl = torch.library.Library("_c10d_functional", "IMPL") - lib_impl.impl("all_reduce", _all_reduce_meta, "Meta") - lib_impl.impl("all_reduce_", _all_reduce__meta, "Meta") - lib_impl.impl("all_reduce_coalesced", _all_reduce_coalesced_meta, "Meta") - lib_impl.impl("all_reduce_coalesced_", _all_reduce_coalesced__meta, "Meta") - lib_impl.impl("wait_tensor", _wait_tensor_meta, "Meta") - lib_impl.impl( - "all_gather_into_tensor_out", _all_gather_into_tensor_out_native_meta, "Meta" - ) - lib_impl.impl("all_gather_into_tensor", _all_gather_into_tensor_native_meta, "Meta") - lib_impl.impl( - "all_gather_into_tensor_coalesced", - _all_gather_into_tensor_coalesced_native_meta, - "Meta", - ) - lib_impl.impl("reduce_scatter_tensor", _reduce_scatter_tensor_native_meta, "Meta") - lib_impl.impl( - "reduce_scatter_tensor_coalesced", - _reduce_scatter_tensor_coalesced_native_meta, - "Meta", - ) - lib_impl.impl("all_to_all_single", _all_to_all_single_meta, "Meta") - lib_impl.impl("broadcast", _broadcast_meta, "Meta") - lib_impl.impl("broadcast_", _broadcast__meta, "Meta") - - # mark these ops has side effect so that they won't be removed by DCE - torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor.default) - torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor) - - # Register legacy ops for backward compatibility - # TODO(yifu): remove these in functional collective beta release - legacy_lib = torch.library.Library("c10d_functional", "DEF") - legacy_lib_impl = torch.library.Library("c10d_functional", "IMPL") - ops_defs = [ - "broadcast(Tensor self, int src, str tag, int[] ranks, int group_size) -> Tensor", - "all_reduce(Tensor self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor", - "all_reduce_coalesced(Tensor[] self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor[]", - "wait_tensor(Tensor self) -> Tensor", - "all_gather_into_tensor(Tensor shard, str tag, int[] ranks, int group_size) -> Tensor", - "all_gather_into_tensor_coalesced(Tensor[] input, str tag, int[] ranks, int group_size) -> Tensor[]", - "reduce_scatter_tensor(Tensor input, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor", - "reduce_scatter_tensor_coalesced(Tensor[] inputs, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor[]", - "all_to_all_single(Tensor input, SymInt[]? output_split_sizes, SymInt[]? input_split_sizes, str tag, int[] ranks, int group_size) -> Tensor", # noqa: B950 - ] - - my_module = sys.modules[__name__] - for op_def in ops_defs: - op_name = op_def[0 : op_def.index("(")] - backend_impl = getattr(fun_col_impl, f"_{op_name}") - legacy_lib.define(op_def, tags=torch.Tag.pt2_compliant_tag) - legacy_lib_impl.impl(op_name, backend_impl, "CompositeImplicitAutograd") +# Library MUST be defined at module scope or it doesn't work +lib_impl = torch.library.Library("_c10d_functional", "IMPL") +lib_impl.impl("all_reduce", _all_reduce_meta, "Meta") +lib_impl.impl("all_reduce_", _all_reduce__meta, "Meta") +lib_impl.impl("all_reduce_coalesced", _all_reduce_coalesced_meta, "Meta") +lib_impl.impl("all_reduce_coalesced_", _all_reduce_coalesced__meta, "Meta") +lib_impl.impl("wait_tensor", _wait_tensor_meta, "Meta") +lib_impl.impl( + "all_gather_into_tensor_out", _all_gather_into_tensor_out_native_meta, "Meta" +) +lib_impl.impl("all_gather_into_tensor", _all_gather_into_tensor_native_meta, "Meta") +lib_impl.impl( + "all_gather_into_tensor_coalesced", + _all_gather_into_tensor_coalesced_native_meta, + "Meta", +) +lib_impl.impl("reduce_scatter_tensor", _reduce_scatter_tensor_native_meta, "Meta") +lib_impl.impl( + "reduce_scatter_tensor_coalesced", + _reduce_scatter_tensor_coalesced_native_meta, + "Meta", +) +lib_impl.impl("all_to_all_single", _all_to_all_single_meta, "Meta") +lib_impl.impl("broadcast", _broadcast_meta, "Meta") +lib_impl.impl("broadcast_", _broadcast__meta, "Meta") + +# mark these ops has side effect so that they won't be removed by DCE +torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor.default) +torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor) + +# Register legacy ops for backward compatibility +# TODO(yifu): remove these in functional collective beta release +legacy_lib = torch.library.Library("c10d_functional", "DEF") +legacy_lib_impl = torch.library.Library("c10d_functional", "IMPL") +ops_defs = [ + "broadcast(Tensor self, int src, str tag, int[] ranks, int group_size) -> Tensor", + "all_reduce(Tensor self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor", + "all_reduce_coalesced(Tensor[] self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor[]", + "wait_tensor(Tensor self) -> Tensor", + "all_gather_into_tensor(Tensor shard, str tag, int[] ranks, int group_size) -> Tensor", + "all_gather_into_tensor_coalesced(Tensor[] input, str tag, int[] ranks, int group_size) -> Tensor[]", + "reduce_scatter_tensor(Tensor input, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor", + "reduce_scatter_tensor_coalesced(Tensor[] inputs, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor[]", + "all_to_all_single(Tensor input, SymInt[]? output_split_sizes, SymInt[]? input_split_sizes, str tag, int[] ranks, int group_size) -> Tensor", # noqa: B950 +] -else: - warnings.warn( - "PyTorch Distributed functional collectives do not work with torch::deploy." - ) +my_module = sys.modules[__name__] +for op_def in ops_defs: + op_name = op_def[0 : op_def.index("(")] + backend_impl = getattr(fun_col_impl, f"_{op_name}") + legacy_lib.define(op_def, tags=torch.Tag.pt2_compliant_tag) + legacy_lib_impl.impl(op_name, backend_impl, "CompositeImplicitAutograd") """ diff --git a/torch/distributed/_tools/fake_collectives.py b/torch/distributed/_tools/fake_collectives.py index f6cb23a06b671..3b201b395334b 100644 --- a/torch/distributed/_tools/fake_collectives.py +++ b/torch/distributed/_tools/fake_collectives.py @@ -63,10 +63,9 @@ def create_fakework(args, return_first_arg=True): # type: ignore[no-untyped-def "recv_any_source_": lambda *args: create_fakework(args, return_first_arg=False), } -if not torch._running_with_deploy(): - lib_impl = torch.library.Library("c10d", "IMPL") # noqa: TOR901 - for op, meta_func in _META_FUNCTIONS.items(): - lib_impl.impl(op, meta_func, "Meta") +lib_impl = torch.library.Library("c10d", "IMPL") # noqa: TOR901 +for op, meta_func in _META_FUNCTIONS.items(): + lib_impl.impl(op, meta_func, "Meta") # List of collective operation functions including functional collectives # Note: The following collectives might be deprecated soon hence not adding them diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_common.py b/torch/distributed/fsdp/_fully_shard/_fsdp_common.py index fdcf32e22a338..b599f48d77d1d 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_common.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_common.py @@ -15,32 +15,24 @@ _compiled_autograd_enabled: bool = False -if torch._running_with_deploy(): - def detect_compiled_autograd(): - pass +def detect_compiled_autograd(): + assert not torch.compiler.is_compiling(), ( + "`detect_compiled_autograd()` is designed to be called in eager mode" + ) + global _compiled_autograd_enabled + import torch._dynamo.compiled_autograd as ca - def compiled_autograd_enabled(): - return False + _compiled_autograd_enabled = ( + ca.compiled_autograd_enabled + or ca.compiled_autograd_enabled_force_eager + or ca.in_compiled_autograd_region + ) -else: - def detect_compiled_autograd(): - assert not torch.compiler.is_compiling(), ( - "`detect_compiled_autograd()` is designed to be called in eager mode" - ) - global _compiled_autograd_enabled - import torch._dynamo.compiled_autograd as ca - - _compiled_autograd_enabled = ( - ca.compiled_autograd_enabled - or ca.compiled_autograd_enabled_force_eager - or ca.in_compiled_autograd_region - ) - - def compiled_autograd_enabled(): - global _compiled_autograd_enabled - return _compiled_autograd_enabled +def compiled_autograd_enabled(): + global _compiled_autograd_enabled + return _compiled_autograd_enabled @dataclass diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_param.py b/torch/distributed/fsdp/_fully_shard/_fsdp_param.py index 7649c32ec1c0e..b7c8f4ea7c78a 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_param.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_param.py @@ -140,8 +140,7 @@ def copy__functionalize(tensor, data): torch.ops.fsdp.copy_.default(tensor_inner, data_inner) -if not torch._running_with_deploy(): - torch.fx.node.has_side_effect(torch.ops.fsdp.copy_.default) +torch.fx.node.has_side_effect(torch.ops.fsdp.copy_.default) class ShardedState(Enum): diff --git a/torch/distributed/tensor/_collective_utils.py b/torch/distributed/tensor/_collective_utils.py index 01505fddd0fd1..a1e38aec651bf 100644 --- a/torch/distributed/tensor/_collective_utils.py +++ b/torch/distributed/tensor/_collective_utils.py @@ -25,26 +25,17 @@ logger = logging.getLogger(__name__) -if not torch._running_with_deploy(): - - @torch.library.register_fake("_dtensor::shard_dim_alltoall") - def _shard_dim_alltoall_meta(input, gather_dim, shard_dim, group_name): - group_size = _get_group_size_by_name(group_name) - stacked_list = [torch.empty_like(input) for _ in range(group_size)] - group = _resolve_process_group(group_name) - group_rank = get_group_rank(group, get_rank()) - - return ( - torch.cat(stacked_list, dim=gather_dim) - .chunk(group_size, dim=shard_dim)[group_rank] - .contiguous() - ) - -else: - import warnings - - warnings.warn( - "PyTorch Distributed functional collectives do not work with torch::deploy." +@torch.library.register_fake("_dtensor::shard_dim_alltoall") +def _shard_dim_alltoall_meta(input, gather_dim, shard_dim, group_name): + group_size = _get_group_size_by_name(group_name) + stacked_list = [torch.empty_like(input) for _ in range(group_size)] + group = _resolve_process_group(group_name) + group_rank = get_group_rank(group, get_rank()) + + return ( + torch.cat(stacked_list, dim=gather_dim) + .chunk(group_size, dim=shard_dim)[group_rank] + .contiguous() ) diff --git a/torch/library.py b/torch/library.py index a30cdb9bb48ac..23a7acf1662c5 100644 --- a/torch/library.py +++ b/torch/library.py @@ -102,9 +102,6 @@ def __init__(self, ns, kind, dispatch_key=""): ns, " is a reserved namespace. Please try creating a library with another name.", ) - if torch._running_with_deploy(): - _library.utils.warn_deploy() - return frame = traceback.extract_stack(limit=3)[0] filename, lineno = frame.filename, frame.lineno @@ -156,9 +153,6 @@ def define(self, schema, alias_analysis="", *, tags=()): >>> my_lib = Library("mylib", "DEF") >>> my_lib.define("sum(Tensor self) -> Tensor") """ - if torch._running_with_deploy(): - _library.utils.warn_deploy() - return # This is added because we also want to disallow PURE_FUNCTION alias analysis which is a valid # AliasAnalysis type in C++ @@ -191,9 +185,6 @@ def define(self, schema, alias_analysis="", *, tags=()): def _register_fake(self, op_name, fn, _stacklevel=1, *, allow_override=False): r"""Registers the fake impl for an operator defined in the library.""" - if torch._running_with_deploy(): - _library.utils.warn_deploy() - return source = torch._library.utils.get_source(_stacklevel + 1) frame = sys._getframe(_stacklevel) @@ -237,9 +228,6 @@ def _register_torch_dispatch_rule(self, op_name, torch_dispatch_class, fn): If it is a TorchDispatchMode, we expect fn to have the following signature: (mode, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any """ - if torch._running_with_deploy(): - _library.utils.warn_deploy() - return qualname = f"{self.ns}::{op_name}" entry = torch._library.simple_registry.singleton.find(qualname) @@ -259,9 +247,6 @@ def _impl_with_aoti_compile(self, op_name, dispatch_key=""): >>> my_lib = Library("aten", "IMPL") >>> my_lib._impl_with_aoti_compile("div.Tensor", "CPU") """ - if torch._running_with_deploy(): - _library.utils.warn_deploy() - return if dispatch_key == "": dispatch_key = self.dispatch_key @@ -324,9 +309,6 @@ def impl( >>> return self * (1 / other) >>> my_lib.impl("div.Tensor", div_cpu, "CPU") """ - if torch._running_with_deploy(): - _library.utils.warn_deploy() - return if not callable(fn): raise TypeError( @@ -409,9 +391,6 @@ def fallback(self, fn, dispatch_key="", *, with_keyset=False): >>> # ... >>> my_lib.fallback(fallback_kernel, "Autocast") """ - if torch._running_with_deploy(): - _library.utils.warn_deploy() - return if dispatch_key == "": dispatch_key = self.dispatch_key diff --git a/torch/utils/__init__.py b/torch/utils/__init__.py index 23188bba9b800..1c3ec15790063 100644 --- a/torch/utils/__init__.py +++ b/torch/utils/__init__.py @@ -29,13 +29,7 @@ def set_module(obj, mod): obj.__module__ = mod -if torch._running_with_deploy(): - # not valid inside torch_deploy interpreter, no paths exists for frozen modules - cmake_prefix_path = None -else: - cmake_prefix_path = _osp.join( - _osp.dirname(_osp.dirname(__file__)), "share", "cmake" - ) +cmake_prefix_path = _osp.join(_osp.dirname(_osp.dirname(__file__)), "share", "cmake") def swap_tensors(t1, t2): diff --git a/torch/utils/_import_utils.py b/torch/utils/_import_utils.py index dc2d7d4f0382c..240f92acacb9d 100644 --- a/torch/utils/_import_utils.py +++ b/torch/utils/_import_utils.py @@ -3,8 +3,6 @@ from types import ModuleType from typing import Optional -import torch - def _check_module_exists(name: str) -> bool: r"""Returns if a top-level module with :attr:`name` exists *without** @@ -22,11 +20,7 @@ def _check_module_exists(name: str) -> bool: @functools.lru_cache def dill_available() -> bool: - return ( - _check_module_exists("dill") - # dill fails to import under torchdeploy - and not torch._running_with_deploy() - ) + return _check_module_exists("dill") @functools.lru_cache diff --git a/torch/utils/collect_env.py b/torch/utils/collect_env.py index 9bb80c65076b8..c6473220bc00a 100644 --- a/torch/utils/collect_env.py +++ b/torch/utils/collect_env.py @@ -6,49 +6,53 @@ import datetime import json import locale +import os import re import subprocess import sys -import os -from typing import cast as _cast from collections import namedtuple +from typing import cast as _cast try: import torch + TORCH_AVAILABLE = True except (ImportError, NameError, AttributeError, OSError): TORCH_AVAILABLE = False # System Environment Information -SystemEnv = namedtuple('SystemEnv', [ - 'torch_version', - 'is_debug_build', - 'cuda_compiled_version', - 'gcc_version', - 'clang_version', - 'cmake_version', - 'os', - 'libc_version', - 'python_version', - 'python_platform', - 'is_cuda_available', - 'cuda_runtime_version', - 'cuda_module_loading', - 'nvidia_driver_version', - 'nvidia_gpu_models', - 'cudnn_version', - 'is_xpu_available', - 'pip_version', # 'pip' or 'pip3' - 'pip_packages', - 'conda_packages', - 'hip_compiled_version', - 'hip_runtime_version', - 'miopen_runtime_version', - 'caching_allocator_config', - 'is_xnnpack_available', - 'cpu_info', -]) +SystemEnv = namedtuple( + "SystemEnv", + [ + "torch_version", + "is_debug_build", + "cuda_compiled_version", + "gcc_version", + "clang_version", + "cmake_version", + "os", + "libc_version", + "python_version", + "python_platform", + "is_cuda_available", + "cuda_runtime_version", + "cuda_module_loading", + "nvidia_driver_version", + "nvidia_gpu_models", + "cudnn_version", + "is_xpu_available", + "pip_version", # 'pip' or 'pip3' + "pip_packages", + "conda_packages", + "hip_compiled_version", + "hip_runtime_version", + "miopen_runtime_version", + "caching_allocator_config", + "is_xnnpack_available", + "cpu_info", + ], +) COMMON_PATTERNS = [ "torch", @@ -116,12 +120,13 @@ def run(command): """Return (return-code, stdout, stderr).""" shell = True if type(command) is str else False - p = subprocess.Popen(command, stdout=subprocess.PIPE, - stderr=subprocess.PIPE, shell=shell) + p = subprocess.Popen( + command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=shell + ) raw_output, raw_err = p.communicate() rc = p.returncode - if get_platform() == 'win32': - enc = 'oem' + if get_platform() == "win32": + enc = "oem" else: enc = locale.getpreferredencoding() output = raw_output.decode(enc) @@ -147,18 +152,19 @@ def run_and_parse_first_match(run_lambda, command, regex): return None return match.group(1) + def run_and_return_first_line(run_lambda, command): """Run command using run_lambda and returns first line if output is not empty.""" rc, out, _ = run_lambda(command) if rc != 0: return None - return out.split('\n')[0] + return out.split("\n")[0] def get_conda_packages(run_lambda, patterns=None): if patterns is None: patterns = CONDA_PATTERNS + COMMON_PATTERNS + NVIDIA_PATTERNS + ONEAPI_PATTERNS - conda = os.environ.get('CONDA_EXE', 'conda') + conda = os.environ.get("CONDA_EXE", "conda") out = run_and_read_all(run_lambda, "{} list".format(conda)) if out is None: return out @@ -166,32 +172,40 @@ def get_conda_packages(run_lambda, patterns=None): return "\n".join( line for line in out.splitlines() - if not line.startswith("#") - and any(name in line for name in patterns) + if not line.startswith("#") and any(name in line for name in patterns) ) + def get_gcc_version(run_lambda): - return run_and_parse_first_match(run_lambda, 'gcc --version', r'gcc (.*)') + return run_and_parse_first_match(run_lambda, "gcc --version", r"gcc (.*)") + def get_clang_version(run_lambda): - return run_and_parse_first_match(run_lambda, 'clang --version', r'clang version (.*)') + return run_and_parse_first_match( + run_lambda, "clang --version", r"clang version (.*)" + ) def get_cmake_version(run_lambda): - return run_and_parse_first_match(run_lambda, 'cmake --version', r'cmake (.*)') + return run_and_parse_first_match(run_lambda, "cmake --version", r"cmake (.*)") def get_nvidia_driver_version(run_lambda): - if get_platform() == 'darwin': - cmd = 'kextstat | grep -i cuda' - return run_and_parse_first_match(run_lambda, cmd, - r'com[.]nvidia[.]CUDA [(](.*?)[)]') + if get_platform() == "darwin": + cmd = "kextstat | grep -i cuda" + return run_and_parse_first_match( + run_lambda, cmd, r"com[.]nvidia[.]CUDA [(](.*?)[)]" + ) smi = get_nvidia_smi() - return run_and_parse_first_match(run_lambda, smi, r'Driver Version: (.*?) ') + return run_and_parse_first_match(run_lambda, smi, r"Driver Version: (.*?) ") def get_gpu_info(run_lambda): - if get_platform() == 'darwin' or (TORCH_AVAILABLE and hasattr(torch.version, 'hip') and torch.version.hip is not None): + if get_platform() == "darwin" or ( + TORCH_AVAILABLE + and hasattr(torch.version, "hip") + and torch.version.hip is not None + ): if TORCH_AVAILABLE and torch.cuda.is_available(): if torch.version.hip is not None: prop = torch.cuda.get_device_properties(0) @@ -204,42 +218,42 @@ def get_gpu_info(run_lambda): return torch.cuda.get_device_name(None) + gcnArch return None smi = get_nvidia_smi() - uuid_regex = re.compile(r' \(UUID: .+?\)') - rc, out, _ = run_lambda(smi + ' -L') + uuid_regex = re.compile(r" \(UUID: .+?\)") + rc, out, _ = run_lambda(smi + " -L") if rc != 0: return None # Anonymize GPUs by removing their UUID - return re.sub(uuid_regex, '', out) + return re.sub(uuid_regex, "", out) def get_running_cuda_version(run_lambda): - return run_and_parse_first_match(run_lambda, 'nvcc --version', r'release .+ V(.*)') + return run_and_parse_first_match(run_lambda, "nvcc --version", r"release .+ V(.*)") def get_cudnn_version(run_lambda): """Return a list of libcudnn.so; it's hard to tell which one is being used.""" - if get_platform() == 'win32': - system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows') - cuda_path = os.environ.get('CUDA_PATH', "%CUDA_PATH%") - where_cmd = os.path.join(system_root, 'System32', 'where') + if get_platform() == "win32": + system_root = os.environ.get("SYSTEMROOT", "C:\\Windows") + cuda_path = os.environ.get("CUDA_PATH", "%CUDA_PATH%") + where_cmd = os.path.join(system_root, "System32", "where") cudnn_cmd = '{} /R "{}\\bin" cudnn*.dll'.format(where_cmd, cuda_path) - elif get_platform() == 'darwin': + elif get_platform() == "darwin": # CUDA libraries and drivers can be found in /usr/local/cuda/. See # https://docs.nvidia.com/cuda/archive/9.0/cuda-installation-guide-mac-os-x/index.html#installation # https://docs.nvidia.com/deeplearning/cudnn/installation/latest/ # Use CUDNN_LIBRARY when cudnn library is installed elsewhere. - cudnn_cmd = 'ls /usr/local/cuda/lib/libcudnn*' + cudnn_cmd = "ls /usr/local/cuda/lib/libcudnn*" else: cudnn_cmd = 'ldconfig -p | grep libcudnn | rev | cut -d" " -f1 | rev' rc, out, _ = run_lambda(cudnn_cmd) # find will return 1 if there are permission errors or if not found if len(out) == 0 or (rc != 1 and rc != 0): - l = os.environ.get('CUDNN_LIBRARY') + l = os.environ.get("CUDNN_LIBRARY") if l is not None and os.path.isfile(l): return os.path.realpath(l) return None files_set = set() - for fn in out.split('\n'): + for fn in out.split("\n"): fn = os.path.realpath(fn) # eliminate symbolic links if os.path.isfile(fn): files_set.add(fn) @@ -249,18 +263,20 @@ def get_cudnn_version(run_lambda): files = sorted(files_set) if len(files) == 1: return files[0] - result = '\n'.join(files) - return 'Probably one of the following:\n{}'.format(result) + result = "\n".join(files) + return "Probably one of the following:\n{}".format(result) def get_nvidia_smi(): # Note: nvidia-smi is currently available only on Windows and Linux - smi = 'nvidia-smi' - if get_platform() == 'win32': - system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows') - program_files_root = os.environ.get('PROGRAMFILES', 'C:\\Program Files') - legacy_path = os.path.join(program_files_root, 'NVIDIA Corporation', 'NVSMI', smi) - new_path = os.path.join(system_root, 'System32', smi) + smi = "nvidia-smi" + if get_platform() == "win32": + system_root = os.environ.get("SYSTEMROOT", "C:\\Windows") + program_files_root = os.environ.get("PROGRAMFILES", "C:\\Program Files") + legacy_path = os.path.join( + program_files_root, "NVIDIA Corporation", "NVSMI", smi + ) + new_path = os.path.join(system_root, "System32", smi) smis = [new_path, legacy_path] for candidate_smi in smis: if os.path.exists(candidate_smi): @@ -411,7 +427,9 @@ def get_intel_gpu_detected(run_lambda): if device_count == 0: return "N/A" - devices = [f"* [{i}] {torch.xpu.get_device_properties(i)}" for i in range(device_count)] + devices = [ + f"* [{i}] {torch.xpu.get_device_properties(i)}" for i in range(device_count) + ] return "\n".join(devices) @@ -490,11 +508,12 @@ def get_intel_gpu_detected(run_lambda): # ProcessorType=3 # Revision=27142 + def get_cpu_info(run_lambda): - rc, out, err = 0, '', '' - if get_platform() == 'linux': - rc, out, err = run_lambda('lscpu') - elif get_platform() == 'win32': + rc, out, err = 0, "", "" + if get_platform() == "linux": + rc, out, err = run_lambda("lscpu") + elif get_platform() == "win32": rc, out, err = run_lambda( 'powershell.exe "gwmi -Class Win32_Processor | Select-Object -Property Name,Manufacturer,Family,\ Architecture,ProcessorType,DeviceID,CurrentClockSpeed,MaxClockSpeed,L2CacheSize,L2CacheSpeed,Revision\ @@ -514,9 +533,9 @@ def get_cpu_info(run_lambda): lst.append(out) lst.append(str(e)) out = "\n".join(lst) - elif get_platform() == 'darwin': + elif get_platform() == "darwin": rc, out, err = run_lambda("sysctl -n machdep.cpu.brand_string") - cpu_info = 'None' + cpu_info = "None" if rc == 0: cpu_info = out else: @@ -525,20 +544,20 @@ def get_cpu_info(run_lambda): def get_platform(): - if sys.platform.startswith('linux'): - return 'linux' - elif sys.platform.startswith('win32'): - return 'win32' - elif sys.platform.startswith('cygwin'): - return 'cygwin' - elif sys.platform.startswith('darwin'): - return 'darwin' + if sys.platform.startswith("linux"): + return "linux" + elif sys.platform.startswith("win32"): + return "win32" + elif sys.platform.startswith("cygwin"): + return "cygwin" + elif sys.platform.startswith("darwin"): + return "darwin" else: return sys.platform def get_mac_version(run_lambda): - return run_and_parse_first_match(run_lambda, 'sw_vers -productVersion', r'(.*)') + return run_and_parse_first_match(run_lambda, "sw_vers -productVersion", r"(.*)") def get_windows_version(run_lambda): @@ -556,39 +575,43 @@ def get_windows_version(run_lambda): def get_lsb_version(run_lambda): - return run_and_parse_first_match(run_lambda, 'lsb_release -a', r'Description:\t(.*)') + return run_and_parse_first_match( + run_lambda, "lsb_release -a", r"Description:\t(.*)" + ) def check_release_file(run_lambda): - return run_and_parse_first_match(run_lambda, 'cat /etc/*-release', - r'PRETTY_NAME="(.*)"') + return run_and_parse_first_match( + run_lambda, "cat /etc/*-release", r'PRETTY_NAME="(.*)"' + ) def get_os(run_lambda): from platform import machine + platform = get_platform() if platform in ["win32", "cygwin"]: return get_windows_version(run_lambda) - if platform == 'darwin': + if platform == "darwin": version = get_mac_version(run_lambda) if version is None: return None - return 'macOS {} ({})'.format(version, machine()) + return "macOS {} ({})".format(version, machine()) - if platform == 'linux': + if platform == "linux": # Ubuntu/Debian based desc = get_lsb_version(run_lambda) if desc is not None: - return '{} ({})'.format(desc, machine()) + return "{} ({})".format(desc, machine()) # Try reading /etc/*-release desc = check_release_file(run_lambda) if desc is not None: - return '{} ({})'.format(desc, machine()) + return "{} ({})".format(desc, machine()) - return '{} ({})'.format(platform, machine()) + return "{} ({})".format(platform, machine()) # Unknown platform return platform @@ -596,14 +619,16 @@ def get_os(run_lambda): def get_python_platform(): import platform + return platform.platform() def get_libc_version(): import platform - if get_platform() != 'linux': - return 'N/A' - return '-'.join(platform.libc_ver()) + + if get_platform() != "linux": + return "N/A" + return "-".join(platform.libc_ver()) def get_pip_packages(run_lambda, patterns=None): @@ -611,35 +636,35 @@ def get_pip_packages(run_lambda, patterns=None): if patterns is None: patterns = PIP_PATTERNS + COMMON_PATTERNS + NVIDIA_PATTERNS + ONEAPI_PATTERNS - pip_version = 'pip3' if sys.version_info.major == 3 else 'pip' + pip_version = "pip3" if sys.version_info.major == 3 else "pip" - os.environ['PIP_DISABLE_PIP_VERSION_CHECK'] = '1' + os.environ["PIP_DISABLE_PIP_VERSION_CHECK"] = "1" # People generally have pip as `pip` or `pip3` # But here it is invoked as `python -mpip` - out = run_and_read_all(run_lambda, [sys.executable, '-mpip', 'list', '--format=freeze']) + out = run_and_read_all( + run_lambda, [sys.executable, "-mpip", "list", "--format=freeze"] + ) if out is None: return pip_version, out - filtered_out = '\n'.join( - line - for line in out.splitlines() - if any(name in line for name in patterns) + filtered_out = "\n".join( + line for line in out.splitlines() if any(name in line for name in patterns) ) return pip_version, filtered_out def get_cachingallocator_config(): - ca_config = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', '') + ca_config = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "") if not ca_config: - ca_config = os.environ.get('PYTORCH_HIP_ALLOC_CONF', '') + ca_config = os.environ.get("PYTORCH_HIP_ALLOC_CONF", "") return ca_config def get_cuda_module_loading_config(): if TORCH_AVAILABLE and torch.cuda.is_available(): torch.cuda.init() - config = os.environ.get('CUDA_MODULE_LOADING', '') + config = os.environ.get("CUDA_MODULE_LOADING", "") return config else: return "N/A" @@ -648,10 +673,12 @@ def get_cuda_module_loading_config(): def is_xnnpack_available(): if TORCH_AVAILABLE: import torch.backends.xnnpack + return str(torch.backends.xnnpack.enabled) # type: ignore[attr-defined] else: return "N/A" + def get_env_info(): """ Collects environment information to aid in debugging. @@ -678,26 +705,31 @@ def get_env_info(): cuda_version_str = torch.version.cuda xpu_available_str = str(torch.xpu.is_available()) if torch.xpu.is_available(): - xpu_available_str = f'{xpu_available_str}\n' + \ - f'XPU used to build PyTorch: {torch.version.xpu}\n' + \ - f'Intel GPU driver version:\n{get_intel_gpu_driver_version(run_lambda)}\n' + \ - f'Intel GPU models onboard:\n{get_intel_gpu_onboard(run_lambda)}\n' + \ - f'Intel GPU models detected:\n{get_intel_gpu_detected(run_lambda)}' - if not hasattr(torch.version, 'hip') or torch.version.hip is None: # cuda version - hip_compiled_version = hip_runtime_version = miopen_runtime_version = 'N/A' + xpu_available_str = ( + f"{xpu_available_str}\n" + + f"XPU used to build PyTorch: {torch.version.xpu}\n" + + f"Intel GPU driver version:\n{get_intel_gpu_driver_version(run_lambda)}\n" + + f"Intel GPU models onboard:\n{get_intel_gpu_onboard(run_lambda)}\n" + + f"Intel GPU models detected:\n{get_intel_gpu_detected(run_lambda)}" + ) + if ( + not hasattr(torch.version, "hip") or torch.version.hip is None + ): # cuda version + hip_compiled_version = hip_runtime_version = miopen_runtime_version = "N/A" else: # HIP version + def get_version_or_na(cfg, prefix): _lst = [s.rsplit(None, 1)[-1] for s in cfg if prefix in s] - return _lst[0] if _lst else 'N/A' + return _lst[0] if _lst else "N/A" - cfg = torch._C._show_config().split('\n') - hip_runtime_version = get_version_or_na(cfg, 'HIP Runtime') - miopen_runtime_version = get_version_or_na(cfg, 'MIOpen') - cuda_version_str = 'N/A' + cfg = torch._C._show_config().split("\n") + hip_runtime_version = get_version_or_na(cfg, "HIP Runtime") + miopen_runtime_version = get_version_or_na(cfg, "MIOpen") + cuda_version_str = "N/A" hip_compiled_version = torch.version.hip else: - version_str = debug_mode_str = cuda_available_str = cuda_version_str = xpu_available_str = 'N/A' - hip_compiled_version = hip_runtime_version = miopen_runtime_version = 'N/A' + version_str = debug_mode_str = cuda_available_str = cuda_version_str = xpu_available_str = "N/A" # type: ignore[assignment] + hip_compiled_version = hip_runtime_version = miopen_runtime_version = "N/A" sys_version = sys.version.replace("\n", " ") @@ -706,7 +738,9 @@ def get_version_or_na(cfg, prefix): return SystemEnv( torch_version=version_str, is_debug_build=debug_mode_str, - python_version='{} ({}-bit runtime)'.format(sys_version, sys.maxsize.bit_length() + 1), + python_version="{} ({}-bit runtime)".format( + sys_version, sys.maxsize.bit_length() + 1 + ), python_platform=get_python_platform(), is_cuda_available=cuda_available_str, cuda_compiled_version=cuda_version_str, @@ -732,6 +766,7 @@ def get_version_or_na(cfg, prefix): cpu_info=get_cpu_info(run_lambda), ) + env_info_fmt = """ PyTorch version: {torch_version} Is debug build: {is_debug_build} @@ -767,14 +802,14 @@ def get_version_or_na(cfg, prefix): def pretty_str(envinfo): - def replace_nones(dct, replacement='Could not collect'): + def replace_nones(dct, replacement="Could not collect"): for key in dct.keys(): if dct[key] is not None: continue dct[key] = replacement return dct - def replace_bools(dct, true='Yes', false='No'): + def replace_bools(dct, true="Yes", false="No"): for key in dct.keys(): if dct[key] is True: dct[key] = true @@ -782,42 +817,48 @@ def replace_bools(dct, true='Yes', false='No'): dct[key] = false return dct - def prepend(text, tag='[prepend]'): - lines = text.split('\n') + def prepend(text, tag="[prepend]"): + lines = text.split("\n") updated_lines = [tag + line for line in lines] - return '\n'.join(updated_lines) + return "\n".join(updated_lines) - def replace_if_empty(text, replacement='No relevant packages'): + def replace_if_empty(text, replacement="No relevant packages"): if text is not None and len(text) == 0: return replacement return text def maybe_start_on_next_line(string): # If `string` is multiline, prepend a \n to it. - if string is not None and len(string.split('\n')) > 1: - return '\n{}\n'.format(string) + if string is not None and len(string.split("\n")) > 1: + return "\n{}\n".format(string) return string mutable_dict = envinfo._asdict() # If nvidia_gpu_models is multiline, start on the next line - mutable_dict['nvidia_gpu_models'] = \ - maybe_start_on_next_line(envinfo.nvidia_gpu_models) + mutable_dict["nvidia_gpu_models"] = maybe_start_on_next_line( + envinfo.nvidia_gpu_models + ) # If the machine doesn't have CUDA, report some fields as 'No CUDA' dynamic_cuda_fields = [ - 'cuda_runtime_version', - 'nvidia_gpu_models', - 'nvidia_driver_version', + "cuda_runtime_version", + "nvidia_gpu_models", + "nvidia_driver_version", ] - all_cuda_fields = dynamic_cuda_fields + ['cudnn_version'] + all_cuda_fields = dynamic_cuda_fields + ["cudnn_version"] all_dynamic_cuda_fields_missing = all( - mutable_dict[field] is None for field in dynamic_cuda_fields) - if TORCH_AVAILABLE and not torch.cuda.is_available() and all_dynamic_cuda_fields_missing: + mutable_dict[field] is None for field in dynamic_cuda_fields + ) + if ( + TORCH_AVAILABLE + and not torch.cuda.is_available() + and all_dynamic_cuda_fields_missing + ): for field in all_cuda_fields: - mutable_dict[field] = 'No CUDA' + mutable_dict[field] = "No CUDA" if envinfo.cuda_compiled_version is None: - mutable_dict['cuda_compiled_version'] = 'None' + mutable_dict["cuda_compiled_version"] = "None" # Replace True with Yes, False with No mutable_dict = replace_bools(mutable_dict) @@ -826,18 +867,20 @@ def maybe_start_on_next_line(string): mutable_dict = replace_nones(mutable_dict) # If either of these are '', replace with 'No relevant packages' - mutable_dict['pip_packages'] = replace_if_empty(mutable_dict['pip_packages']) - mutable_dict['conda_packages'] = replace_if_empty(mutable_dict['conda_packages']) + mutable_dict["pip_packages"] = replace_if_empty(mutable_dict["pip_packages"]) + mutable_dict["conda_packages"] = replace_if_empty(mutable_dict["conda_packages"]) # Tag conda and pip packages with a prefix # If they were previously None, they'll show up as ie '[conda] Could not collect' - if mutable_dict['pip_packages']: - mutable_dict['pip_packages'] = prepend(mutable_dict['pip_packages'], - '[{}] '.format(envinfo.pip_version)) - if mutable_dict['conda_packages']: - mutable_dict['conda_packages'] = prepend(mutable_dict['conda_packages'], - '[conda] ') - mutable_dict['cpu_info'] = envinfo.cpu_info + if mutable_dict["pip_packages"]: + mutable_dict["pip_packages"] = prepend( + mutable_dict["pip_packages"], "[{}] ".format(envinfo.pip_version) + ) + if mutable_dict["conda_packages"]: + mutable_dict["conda_packages"] = prepend( + mutable_dict["conda_packages"], "[conda] " + ) + mutable_dict["cpu_info"] = envinfo.cpu_info return env_info_fmt.format(**mutable_dict) @@ -861,18 +904,29 @@ def main(): output = get_pretty_env_info() print(output) - if TORCH_AVAILABLE and hasattr(torch, 'utils') and hasattr(torch.utils, '_crash_handler'): + if ( + TORCH_AVAILABLE + and hasattr(torch, "utils") + and hasattr(torch.utils, "_crash_handler") + ): minidump_dir = torch.utils._crash_handler.DEFAULT_MINIDUMP_DIR if sys.platform == "linux" and os.path.exists(minidump_dir): - dumps = [os.path.join(minidump_dir, dump) for dump in os.listdir(minidump_dir)] + dumps = [ + os.path.join(minidump_dir, dump) for dump in os.listdir(minidump_dir) + ] latest = max(dumps, key=os.path.getctime) ctime = os.path.getctime(latest) - creation_time = datetime.datetime.fromtimestamp(ctime).strftime('%Y-%m-%d %H:%M:%S') - msg = "\n*** Detected a minidump at {} created on {}, ".format(latest, creation_time) + \ - "if this is related to your bug please include it when you file a report ***" + creation_time = datetime.datetime.fromtimestamp(ctime).strftime( + "%Y-%m-%d %H:%M:%S" + ) + msg = ( + "\n*** Detected a minidump at {} created on {}, ".format( + latest, creation_time + ) + + "if this is related to your bug please include it when you file a report ***" + ) print(msg, file=sys.stderr) - -if __name__ == '__main__': +if __name__ == "__main__": main() From a6de309ca15cda6b2792fc74e82814dc8d2f9dd9 Mon Sep 17 00:00:00 2001 From: PaliC Date: Wed, 16 Jul 2025 15:07:08 -0700 Subject: [PATCH 165/457] [BE] Remove torch deploy | remove torch deploy specific files (#158290) This PR removes specific files found in pytorch which are only used for torch::deploy. This is mostly testing code and a debugger. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158290 Approved by: https://github.com/albanD ghstack dependencies: #158288 --- docs/source/deploy.md | 8 - test/test_deploy.py | 43 ----- tools/lldb/deploy_debugger.py | 38 ----- torch/_deploy.py | 104 ------------ torch/csrc/deploy/README.md | 2 - torch/utils/_freeze.py | 292 ---------------------------------- 6 files changed, 487 deletions(-) delete mode 100644 docs/source/deploy.md delete mode 100644 test/test_deploy.py delete mode 100644 tools/lldb/deploy_debugger.py delete mode 100644 torch/_deploy.py delete mode 100644 torch/csrc/deploy/README.md delete mode 100644 torch/utils/_freeze.py diff --git a/docs/source/deploy.md b/docs/source/deploy.md deleted file mode 100644 index ef5131717bf7b..0000000000000 --- a/docs/source/deploy.md +++ /dev/null @@ -1,8 +0,0 @@ ---- -orphan: true ---- - -# torch::deploy has been moved to pytorch/multipy - - -``torch::deploy`` has been moved to its new home at [https://github.com/pytorch/multipy](https://github.com/pytorch/multipy). diff --git a/test/test_deploy.py b/test/test_deploy.py deleted file mode 100644 index b852802c0c20f..0000000000000 --- a/test/test_deploy.py +++ /dev/null @@ -1,43 +0,0 @@ -# Owner(s): ["oncall: package/deploy"] - -import textwrap -import types - -from torch.testing._internal.common_utils import run_tests, TestCase -from torch.utils._freeze import Freezer, PATH_MARKER - - -class TestFreezer(TestCase): - """Tests the freeze.py script""" - - def test_compile_string(self): - freezer = Freezer(True) - code_str = textwrap.dedent( - """ - class MyCls: - def __init__(self) -> None: - pass - """ - ) - co = freezer.compile_string(code_str) - num_co = 0 - - def verify_filename(co: types.CodeType): - nonlocal num_co - - if not isinstance(co, types.CodeType): - return - - self.assertEqual(PATH_MARKER, co.co_filename) - num_co += 1 - - for nested_co in co.co_consts: - verify_filename(nested_co) - - verify_filename(co) - # there is at least one nested code object besides the top level one - self.assertTrue(num_co >= 2) - - -if __name__ == "__main__": - run_tests() diff --git a/tools/lldb/deploy_debugger.py b/tools/lldb/deploy_debugger.py deleted file mode 100644 index 7a28c72a6caf2..0000000000000 --- a/tools/lldb/deploy_debugger.py +++ /dev/null @@ -1,38 +0,0 @@ -import lldb # type: ignore[import] - - -# load into lldb instance with: -# command script import tools/lldb/deploy_debugger.py - -target = lldb.debugger.GetSelectedTarget() -bp = target.BreakpointCreateByRegex("__deploy_register_code") -bp.SetScriptCallbackBody( - """\ -process = frame.thread.GetProcess() -target = process.target -symbol_addr = frame.module.FindSymbol("__deploy_module_info").GetStartAddress() -info_addr = symbol_addr.GetLoadAddress(target) -e = lldb.SBError() -ptr_size = 8 -str_addr = process.ReadPointerFromMemory(info_addr, e) -file_addr = process.ReadPointerFromMemory(info_addr + ptr_size, e) -file_size = process.ReadPointerFromMemory(info_addr + 2*ptr_size, e) -load_bias = process.ReadPointerFromMemory(info_addr + 3*ptr_size, e) -name = process.ReadCStringFromMemory(str_addr, 512, e) -r = process.ReadMemory(file_addr, file_size, e) -from tempfile import NamedTemporaryFile -from pathlib import Path -stem = Path(name).stem -with NamedTemporaryFile(prefix=stem, suffix='.so', delete=False) as tf: - tf.write(r) - print("torch_deploy registering debug information for ", tf.name) - cmd1 = f"target modules add {tf.name}" - # print(cmd1) - lldb.debugger.HandleCommand(cmd1) - cmd2 = f"target modules load -f {tf.name} -s {hex(load_bias)}" - # print(cmd2) - lldb.debugger.HandleCommand(cmd2) - -return False -""" -) diff --git a/torch/_deploy.py b/torch/_deploy.py deleted file mode 100644 index 0443a2447d00d..0000000000000 --- a/torch/_deploy.py +++ /dev/null @@ -1,104 +0,0 @@ -# mypy: allow-untyped-defs -import io - -import torch -from torch.package import Importer, OrderedImporter, PackageImporter, sys_importer -from torch.package._package_pickler import create_pickler -from torch.package._package_unpickler import PackageUnpickler -from torch.serialization import _maybe_decode_ascii - - -def _save_storages(importer, obj): - serialized_storages = [] - serialized_dtypes = [] - - importer = importer if isinstance(importer, torch.package.PackageImporter) else None - importers: Importer - if importer is not None: - importers = OrderedImporter(importer, sys_importer) - else: - importers = sys_importer - - def persistent_id(obj): - if torch.is_storage(obj) or isinstance(obj, torch.storage.TypedStorage): - if isinstance(obj, torch.storage.TypedStorage): - # TODO: Once we decide to break serialization FC, we can - # remove this case - dtype = obj.dtype - else: - dtype = torch.uint8 - - serialized_storages.append(obj) - serialized_dtypes.append(dtype) - return ("storage", len(serialized_storages) - 1) - - if hasattr(obj, "__reduce_deploy__"): - if _serialized_reduces.get(id(obj)) is None: - _serialized_reduces[id(obj)] = ( - "reduce_deploy", - id(obj), - *obj.__reduce_deploy__(importers), - ) - return _serialized_reduces[id(obj)] - - return None - - # Write the pickle data for `obj` - data_buf = io.BytesIO() - pickler = create_pickler(data_buf, importers) - pickler.persistent_id = persistent_id - pickler.dump(obj) - data_value = data_buf.getvalue() - return ( - data_value, - serialized_storages, - serialized_dtypes, - importer.zip_reader if importer else None, - ) - - -def _load_storages(id, zip_reader, obj_bytes, serialized_storages, serialized_dtypes): - def persistent_load(saved_id): - assert isinstance(saved_id, tuple) - typename = _maybe_decode_ascii(saved_id[0]) - data = saved_id[1:] - - if typename == "storage": - # TODO: Once we decide to break serialization FC, we can - # stop wrapping with TypedStorage - storage = serialized_storages[data[0]] - dtype = serialized_dtypes[data[0]] - return torch.storage.TypedStorage( - wrap_storage=storage.untyped(), dtype=dtype - ) - - if typename == "reduce_deploy": - reduce_id, func, args = data - if reduce_id not in _loaded_reduces: - _loaded_reduces[reduce_id] = func(_raw_packages[zip_reader], *args) - return _loaded_reduces[reduce_id] - - return None - - importer: Importer - if zip_reader is not None: - importer = OrderedImporter(_get_package(zip_reader), sys_importer) - else: - importer = sys_importer - - unpickler = PackageUnpickler(importer, io.BytesIO(obj_bytes)) - unpickler.persistent_load = persistent_load # type: ignore[method-assign] - result = _deploy_objects[id] = unpickler.load() - return result - - -def _get_package(zip_reader): - if zip_reader not in _raw_packages: - _raw_packages[zip_reader] = PackageImporter(zip_reader) - return _raw_packages[zip_reader] - - -_raw_packages: dict = {} -_deploy_objects: dict = {} -_serialized_reduces: dict = {} -_loaded_reduces: dict = {} diff --git a/torch/csrc/deploy/README.md b/torch/csrc/deploy/README.md deleted file mode 100644 index 2d40ca8361ff4..0000000000000 --- a/torch/csrc/deploy/README.md +++ /dev/null @@ -1,2 +0,0 @@ -# torch::deploy has been moved to pytorch/multipy -Please check out [https://github.com/pytorch/multipy](https://github.com/pytorch/multipy) to find the new home for torch::deploy. diff --git a/torch/utils/_freeze.py b/torch/utils/_freeze.py deleted file mode 100644 index 8696065adb9f9..0000000000000 --- a/torch/utils/_freeze.py +++ /dev/null @@ -1,292 +0,0 @@ -# mypy: allow-untyped-decorators -# mypy: allow-untyped-defs -""" -Freeze Python packages. - - - - -Freezing makes it possible to ship arbitrary Python modules as part of a C++ -library. The Python source of the module is compiled to bytecode and written -to `.c` files, to be imported by Python's built-in FrozenImporter. - -In a normal Python installation, FrozenImporter is only used to bootstrap the -initialization of the import machinery. Python's importers are defined in -Python (see `_bootstrap.py` and `_bootstrap_external.py`) but need to be -retrieved before any importers are available. Freezing the module bytecode -resolves this circular dependency. - -This script will freeze the Python standard library. It produces two things: -- Bytecode files: A set of `.c` that define C variables containing Python bytecode. -- Main file: A `main.c` file listing all of these modules in the right form to be - consumed by FrozenImporter. - -The library that wishes to these modules make them available to the local -Python instance by extending `PyImport_FrozenModules` appropriately (see -https://docs.python.org/3/c-api/import.html#c.PyImport_FrozenModules). -""" - -import argparse -import functools -import itertools -import marshal -import os -import types -from dataclasses import dataclass -from pathlib import Path - - -PATH_MARKER = "" -MAIN_INCLUDES = """#include - -""" - -MAIN_PREFIX_TEMPLATE = """ -// Compiled standard library modules. These should be appended to the existing -// `PyImport_FrozenModules` that ships with CPython. -struct _frozen {}[] = {{ -""" - -FAKE_PREFIX = MAIN_PREFIX_TEMPLATE.format("_PyImport_FrozenModules") - -MAIN_SUFFIX = """\ - {0, 0, 0} /* sentinel */ -}; -""" - -# Exclude some standard library modules to: -# 1. Slim down the final frozen lib. -# 2. Remove functionality we don't want to support. -DENY_LIST = [ - # Interface to unix databases - "dbm", - # ncurses bindings (terminal interfaces) - "curses", - # Tcl/Tk GUI - "tkinter", - "tkinter", - # Tests for the standard library - "test", - "tests", - "idle_test", - "__phello__.foo.py", - # importlib frozen modules. These are already baked into CPython. - "_bootstrap.py", - "_bootstrap_external.py", -] - -NUM_BYTECODE_FILES = 5 - - -def indent_msg(fn): - @functools.wraps(fn) - def wrapper(*args, **kwargs): - args[0].indent += 1 - ret = fn(*args, **kwargs) - args[0].indent -= 1 - return ret - - return wrapper - - -@dataclass -class FrozenModule: - # The fully qualified module name, e.g. 'foo.bar.baz' - module_name: str - # The name of the C variable that holds the bytecode, e.g. 'M_foo__bar__baz' - c_name: str - # The size of the C variable. Negative if this module is a package. - size: int - # The frozen bytecode - bytecode: bytes - - -class Freezer: - def __init__(self, verbose: bool): - self.frozen_modules: list[FrozenModule] = [] - self.indent: int = 0 - self.verbose: bool = verbose - - def msg(self, path: Path, code: str): - if not self.verbose: - return - # P: package dir - # F: python file - # S: skipped (not a package dir) - # X: skipped (deny-listed) - # N: skipped (not a python file) - print(" " * self.indent, end="") - print(f"{code} {path}") - - def write_bytecode(self, install_root): - """ - Write the `.c` files containing the frozen bytecode. - - Shared frozen modules evenly across the files. - """ - bytecode_file_names = [f"bytecode_{i}.c" for i in range(NUM_BYTECODE_FILES)] - bytecode_files = [ - open(os.path.join(install_root, name), "w") for name in bytecode_file_names - ] - it = itertools.cycle(bytecode_files) - for m in self.frozen_modules: - self.write_frozen(m, next(it)) - - for f in bytecode_files: - f.close() - - def write_main(self, install_root, oss, symbol_name): - """Write the `main.c` file containing a table enumerating all the frozen modules.""" - with open(os.path.join(install_root, "main.c"), "w") as outfp: - outfp.write(MAIN_INCLUDES) - for m in self.frozen_modules: - outfp.write(f"extern unsigned char {m.c_name}[];\n") - - outfp.write(MAIN_PREFIX_TEMPLATE.format(symbol_name)) - for m in self.frozen_modules: - outfp.write(f'\t{{"{m.module_name}", {m.c_name}, {m.size}}},\n') - outfp.write(MAIN_SUFFIX) - if oss: - outfp.write(FAKE_PREFIX) - outfp.write(MAIN_SUFFIX) - - def write_frozen(self, m: FrozenModule, outfp): - """Write a single frozen module's bytecode out to a C variable.""" - outfp.write(f"unsigned char {m.c_name}[] = {{") - for i in range(0, len(m.bytecode), 16): - outfp.write("\n\t") - for c in bytes(m.bytecode[i : i + 16]): - outfp.write(f"{c:d},") - outfp.write("\n};\n") - - def compile_path(self, path: Path, top_package_path: Path): - """Entry point for compiling a Path object.""" - if path.is_dir(): - self.compile_package(path, top_package_path) - else: - self.compile_file(path, top_package_path) - - @indent_msg - def compile_package(self, path: Path, top_package_path: Path): - """Compile all the files within a Python package dir.""" - assert path.is_dir() - if path.name in DENY_LIST: - self.msg(path, "X") - return - - # Python packages are directories that have __init__.py in them. - is_package_dir = any(child.name == "__init__.py" for child in path.iterdir()) - if not is_package_dir: - self.msg(path, "S") - return - - self.msg(path, "P") - # Recursively compile all children in this dir - for child in path.iterdir(): - self.compile_path(child, top_package_path) - - def get_module_qualname(self, file_path: Path, top_package_path: Path) -> list[str]: - # `path` looks like 'Lib/foo/bar/baz.py' - - # chop off 'Lib/' to get something that represents a Python module hierarchy. - # e.g. 'foo/bar/baz.py', which maps to 'foo.bar.baz' - normalized_path = file_path.relative_to(top_package_path.parent) - - if normalized_path.name == "__init__.py": - # Special handling for `__init__.py`. In this case, this file - # specifies that the containing directory should be treated as a package. - # For 'foo/bar/baz/__init__.py': - # - The module name is 'baz' - module_basename = normalized_path.parent.name - # - The parent is foo.bar (need to shave off the 'baz') - module_parent = normalized_path.parent.parent.parts - else: - module_basename = normalized_path.stem - module_parent = normalized_path.parent.parts - return list(module_parent) + [module_basename] - - def compile_string(self, file_content: str) -> types.CodeType: - # instead of passing in the real build time path to 'compile', we - # pass in a marker instead. This prevents the build time path being - # leaked to runtime. That path may not be available at runtime. - # Setting the path to a mark make sure it's a hard error rather - # than a flaky error when inspect module tries to retrieve python source - # code during torchscripting. - path_marker = PATH_MARKER - return compile(file_content, path_marker, "exec") - - @indent_msg - def compile_file(self, path: Path, top_package_path: Path): - """ - Compile a Python source file to frozen bytecode. - - Append the result to `self.frozen_modules`. - """ - assert path.is_file() - if path.suffix != ".py": - self.msg(path, "N") - return - - if path.name in DENY_LIST: - self.msg(path, "X") - return - - self.msg(path, "F") - module_qualname = self.get_module_qualname(path, top_package_path) - module_mangled_name = "__".join(module_qualname) - c_name = "M_" + module_mangled_name - - with open(path) as src_file: - co = self.compile_string(src_file.read()) - - bytecode = marshal.dumps(co) - size = len(bytecode) - if path.name == "__init__.py": - # Python packages are signified by negative size. - size = -size - self.frozen_modules.append( - FrozenModule(".".join(module_qualname), c_name, size, bytecode) - ) - - -def main() -> None: - parser = argparse.ArgumentParser(description="Compile py source") - parser.add_argument("paths", nargs="*", help="Paths to freeze.") - parser.add_argument("--verbose", action="store_true", help="Print debug logs") - parser.add_argument( - "--install-dir", "--install_dir", help="Root directory for all output files" - ) - parser.add_argument( - "--oss", - action="store_true", - help="If it's OSS build, add a fake _PyImport_FrozenModules", - ) - parser.add_argument( - "--symbol-name", - "--symbol_name", - help="The name of the frozen module array symbol to generate", - default="_PyImport_FrozenModules_torch", - ) - - args = parser.parse_args() - - f = Freezer(args.verbose) - - for p in args.paths: - path = Path(p) - if path.is_dir() and not Path.exists(path / "__init__.py"): - # this 'top level path p' is a standard directory containing modules, - # not a module itself - # each 'mod' could be a dir containing __init__.py or .py file - # NB: sorted to make sure this is deterministic - for mod in sorted(path.glob("*")): - f.compile_path(mod, mod) - else: - f.compile_path(path, path) - - f.write_bytecode(args.install_dir) - f.write_main(args.install_dir, args.oss, args.symbol_name) - - -if __name__ == "__main__": - main() # pragma: no cover From 0b9fb91f17edfbc51ae36584dcb8350b2d8bb23b Mon Sep 17 00:00:00 2001 From: PaliC Date: Wed, 16 Jul 2025 15:07:09 -0700 Subject: [PATCH 166/457] [BE] Remove __reduce_deploy__ (#158291) This PR removes the integration point torch.fx had with torch::deploy (and another minor change). Note: This PR has some broken mypy errors, but I believe those should have been in the code base beforehand, and should be fixed in a separate PR Pull Request resolved: https://github.com/pytorch/pytorch/pull/158291 Approved by: https://github.com/albanD ghstack dependencies: #158288, #158290 --- docs/source/conf.py | 1 - ...t-fx_backcompat_function_signatures.expect | 1 - torch/_dynamo/trace_rules.py | 1 - torch/fx/_lazy_graph_module.py | 5 ----- torch/fx/graph_module.py | 21 ------------------- 5 files changed, 29 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index acb2b088af727..a19d6b7102a3e 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -1082,7 +1082,6 @@ "z3op", "z3str", # torch.fx.graph_module - "reduce_deploy_graph_module", "reduce_graph_module", "reduce_package_graph_module", # torch.fx.node diff --git a/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect b/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect index fab0dbd066761..67ed33950249d 100644 --- a/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect +++ b/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect @@ -29,7 +29,6 @@ torch.fx.graph_module.GraphModule.add_submodule(self, target: str, m: torch.nn.m torch.fx.graph_module.GraphModule.delete_all_unused_submodules(self) -> None torch.fx.graph_module.GraphModule.delete_submodule(self, target: str) -> bool torch.fx.graph_module.GraphModule.recompile(self) -> torch.fx.graph.PythonCode -torch.fx.graph_module.reduce_deploy_graph_module(importer: Callable, body: Dict[Any, Any], import_block: str) -> torch.nn.modules.module.Module torch.fx.graph_module.reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.modules.module.Module torch.fx.graph_module.reduce_package_graph_module(importer: Callable, body: Dict[Any, Any], generated_module_name: str) -> torch.nn.modules.module.Module torch.fx.interpreter.Interpreter.__init__(self, module: torch.nn.modules.module.Module, garbage_collect_values: bool = True, graph: Optional[torch.fx.graph.Graph] = None) diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 8005b08d24465..7df18543ddc44 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -3468,7 +3468,6 @@ def _module_dir(m: types.ModuleType): "torch._custom_op", "torch._custom_ops", "torch._decomp", - "torch._deploy", "torch._dispatch", "torch._dynamo", "torch._export", diff --git a/torch/fx/_lazy_graph_module.py b/torch/fx/_lazy_graph_module.py index 377faf327fc9d..83ce51fddd040 100644 --- a/torch/fx/_lazy_graph_module.py +++ b/torch/fx/_lazy_graph_module.py @@ -127,11 +127,6 @@ def _lazy_forward(self, *args, **kwargs): forward = _lazy_forward - # TODO: we should handle __reduce_deploy__ the same way as __reduce_package__, - # or __reduce__ by calling _real_recompile. But I don't find a good way - # to test __reduce_deploy__ out. Also it's very unlikely that LazyGraphModule - # will be used in torch::deploy. So it's skipped for now. - def __reduce_package__(self, exporter: PackageExporter): """ Follow GraphModule.__reduce__ but call 'self._real_recompile' rather diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py index 2e1a0963f53b6..065cf82983e53 100644 --- a/torch/fx/graph_module.py +++ b/torch/fx/graph_module.py @@ -30,7 +30,6 @@ __all__ = [ "reduce_graph_module", "reduce_package_graph_module", - "reduce_deploy_graph_module", "GraphModule", ] @@ -147,18 +146,6 @@ def reduce_package_graph_module( return _deserialize_graph_module(forward, body) -@compatibility(is_backward_compatible=True) -def reduce_deploy_graph_module( - importer: PackageImporter, body: dict[Any, Any], import_block: str -) -> torch.nn.Module: - ns = {} - ns["__builtins__"] = importer.patched_builtins - fn_src = body.get("_code") - assert fn_src is not None - forward = _forward_from_src(import_block + fn_src, ns) - return _deserialize_graph_module(forward, body) - - # We create a dummy class here because symbolic_trace pulls the forward() # function off of the class, rather than the instance. This class is used # in _deserialize_graph_module() below. @@ -853,14 +840,6 @@ def call_wrapped(self, *args, **kwargs): # Passing Tracer as argument allows subclasses extending fx.GraphModule # define their own Tracer (extending fx.Tracer). - def __reduce_deploy__(self, importer: Importer): - dict_without_graph = self.__dict__.copy() - dict_without_graph["_graphmodule_cls_name"] = self.__class__.__name__ - del dict_without_graph["_graph"] - - python_code = self.recompile() - import_block = _format_import_block(python_code.globals, importer) - return (reduce_deploy_graph_module, (dict_without_graph, import_block)) def __reduce_package__(self, exporter: PackageExporter): dict_without_graph = self.__dict__.copy() From d9426a81d2ab54f809a3b32a6ab2e606073fe66f Mon Sep 17 00:00:00 2001 From: PaliC Date: Wed, 16 Jul 2025 15:07:09 -0700 Subject: [PATCH 167/457] [BE] Modify PyObjectSlot the assume only a single interpreter is in use (#158407) This PR makes some less risky changes to PyObjectSlot as there is a lot of stuff we do not need since there is only one interpreter. Specifically `check_interpreter` and `has_pyobj_nonhermetic` are removed Pull Request resolved: https://github.com/pytorch/pytorch/pull/158407 Approved by: https://github.com/albanD ghstack dependencies: #158288, #158290, #158291 --- c10/core/impl/PyObjectSlot.cpp | 14 +------ c10/core/impl/PyObjectSlot.h | 74 +++------------------------------- torch/csrc/Storage.cpp | 11 ----- 3 files changed, 7 insertions(+), 92 deletions(-) diff --git a/c10/core/impl/PyObjectSlot.cpp b/c10/core/impl/PyObjectSlot.cpp index 400903bc7a651..62af2eae8e37a 100644 --- a/c10/core/impl/PyObjectSlot.cpp +++ b/c10/core/impl/PyObjectSlot.cpp @@ -44,19 +44,7 @@ PyInterpreter& PyObjectSlot::load_pyobj_interpreter() const { if (interpreter) { return *interpreter; } - TORCH_CHECK( - false, - "cannot access PyObject for Tensor on interpreter ", - (*pyobj_interpreter_.load())->name()); -} - -bool PyObjectSlot::check_interpreter(PyInterpreter* interpreter) { - return interpreter == pyobj_interpreter(); -} - -bool PyObjectSlot::has_pyobj_nonhermetic() { - return check_pyobj(pyobj_interpreter(), /*ignore_hermetic_tls=*/true) - .has_value(); + TORCH_CHECK(false, "cannot access PyObject for Tensor - no interpreter set"); } bool PyObjectSlot::owns_pyobj() { diff --git a/c10/core/impl/PyObjectSlot.h b/c10/core/impl/PyObjectSlot.h index 4b9bcf1e4a1c3..af8b9fa4d0ec7 100644 --- a/c10/core/impl/PyObjectSlot.h +++ b/c10/core/impl/PyObjectSlot.h @@ -28,48 +28,7 @@ struct C10_API PyObjectSlot { PyInterpreter* self_interpreter, PyObject* pyobj, PyInterpreterStatus status) { - impl::PyInterpreter* expected = nullptr; - switch (status) { - case impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED: - // caller guarantees there is no multithreaded access; if there is - // no data race OK to do a relaxed store - pyobj_interpreter_.store(self_interpreter, std::memory_order_relaxed); - break; - case impl::PyInterpreterStatus::TAGGED_BY_US: - // no tagging is necessary, the tag is already correct - break; - case impl::PyInterpreterStatus::MAYBE_UNINITIALIZED: - // attempt to claim this TensorImpl with the specified interpreter - // tag - if (pyobj_interpreter_.compare_exchange_strong( - expected, self_interpreter, std::memory_order_acq_rel)) { - break; - } - // test if, actually, it was already tagged by us! this situation can't - // be caused by a race, but it could be caused by a situation - // where someone conservatively tagged the tensor as MAYBE_UNINITIALIZED - // (because they didn't pre-check the tag) when actually it was - // owned by the interpreter - if (expected == self_interpreter) { - break; - } - // fallthrough, we lost the race. We are guaranteed not to lose the - // race with ourself, as calls to init_pyobj with the same interpreter - // ID must be sequentialized by the GIL - [[fallthrough]]; - case impl::PyInterpreterStatus::TAGGED_BY_OTHER: - TORCH_CHECK( - false, - "cannot allocate PyObject for Tensor on interpreter ", - self_interpreter, - " that has already been used by another torch deploy interpreter ", - pyobj_interpreter_.load()); - } - - // we are the ONLY thread that can have gotten to this point. It is not - // possible to conflict with another zero interpreter as access is protected - // by GIL - // NB: owns_pyobj tag is initially false + pyobj_interpreter_.store(self_interpreter, std::memory_order_relaxed); pyobj_ = pyobj; } @@ -97,30 +56,16 @@ struct C10_API PyObjectSlot { std::optional check_pyobj( PyInterpreter* self_interpreter, bool ignore_hermetic_tls = false) const { - // Note [Memory ordering on Python interpreter tag] impl::PyInterpreter* interpreter = pyobj_interpreter_.load(std::memory_order_acquire); if (interpreter == nullptr) { - // NB: This never returns DEFINITELY_UNINITIALIZED because there is - // always the possibility that another thread races to initialize - // after we query here. The only time when we can conclude a tensor - // is definitely uninitialized is when we have just allocated it and - // it cannot have escaped to other threads yet return std::nullopt; - } else if (interpreter == self_interpreter) { - // NB: pyobj_ could still be null! - if (!ignore_hermetic_tls && c10::impl::HermeticPyObjectTLS::get_state()) { - return std::nullopt; - } else { - return _unchecked_untagged_pyobj(); - } + } + + if (!ignore_hermetic_tls && c10::impl::HermeticPyObjectTLS::get_state()) { + return std::nullopt; } else { - TORCH_CHECK( - false, - "cannot access PyObject for Tensor on interpreter ", - (*self_interpreter)->name(), - " that has already been used by another torch deploy interpreter ", - (*pyobj_interpreter_.load())->name()); + return _unchecked_untagged_pyobj(); } } @@ -130,13 +75,6 @@ struct C10_API PyObjectSlot { PyInterpreter& load_pyobj_interpreter() const; - // Check if the PyObjectSlot's interpreter is the same as the specified - // interpreter - bool check_interpreter(PyInterpreter* interpreter); - - // Check if the PyObjectSlot is holding a PyObject, owned or non-owned - bool has_pyobj_nonhermetic(); - bool owns_pyobj(); void set_owns_pyobj(bool b); diff --git a/torch/csrc/Storage.cpp b/torch/csrc/Storage.cpp index d566dc666ebfe..cc682a2644af2 100644 --- a/torch/csrc/Storage.cpp +++ b/torch/csrc/Storage.cpp @@ -98,17 +98,6 @@ PyObject* THPStorage_Wrap(c10::Storage storage) { } c10::impl::PyObjectSlot* pyobj_slot = storage_impl->pyobj_slot(); - // If the StorageImpl has a PyObject that is managed by a different - // interpreter than the current one, create a new StorageImpl that points to - // the same data and then create the Python storage from that. - // NOTE: This is only supposed to happen in MultiPy // codespell:ignore - if (pyobj_slot->has_pyobj_nonhermetic() && - !pyobj_slot->check_interpreter(getPyInterpreter())) { - return THPStorage_NewWithStorage( - THPStorageClass, - c10::newStorageImplFromRefcountedDataPtr(storage), - c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED); - } std::optional maybe_pyobj = pyobj_slot->check_pyobj( getPyInterpreter(), /*ignore_hermetic_tls=*/false); c10::impl::PyInterpreterStatus status = From 9636e2cfd3e995ef977f670ad47e8e895296d992 Mon Sep 17 00:00:00 2001 From: drisspg Date: Wed, 16 Jul 2025 18:54:28 -0700 Subject: [PATCH 168/457] Move off of deprecated API in 2.9 (#158527) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158527 Approved by: https://github.com/danielvegamyhre --- torch/_inductor/kernel/flex_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index 9a7507631cc49..aed4a03edd186 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -138,7 +138,7 @@ def maybe_realize(args: list[Optional[IRNode]]): def get_float32_precision(): if ( - torch.get_float32_matmul_precision() == "highest" + torch.backends.cuda.matmul.fp32_precision == "ieee" or torch.version.hip or torch.mtia.is_available() ): From 9f37cce69334bccebf4b21503f0047d0c0bb320c Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 17 Jul 2025 06:28:49 +0000 Subject: [PATCH 169/457] Revert "[Docker builds] Move from Miniconda to Miniforge (#158370)" This reverts commit 0a99b026d6bd0f67dc2c0a20fe3228ddc4144854. Reverted https://github.com/pytorch/pytorch/pull/158370 on behalf of https://github.com/laithsakka due to this fail pr time benchmarks ([comment](https://github.com/pytorch/pytorch/pull/158370#issuecomment-3082744071)) --- .ci/docker/common/install_conda.sh | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/.ci/docker/common/install_conda.sh b/.ci/docker/common/install_conda.sh index 481de54a50f2c..185837b7e98a2 100755 --- a/.ci/docker/common/install_conda.sh +++ b/.ci/docker/common/install_conda.sh @@ -4,8 +4,12 @@ set -ex # Optionally install conda if [ -n "$ANACONDA_PYTHON_VERSION" ]; then - BASE_URL="https://github.com/conda-forge/miniforge/releases/latest/download" # @lint-ignore - CONDA_FILE="Miniforge3-Linux-$(uname -m).sh" + BASE_URL="https://repo.anaconda.com/miniconda" + CONDA_FILE="Miniconda3-latest-Linux-x86_64.sh" + if [[ $(uname -m) == "aarch64" ]] || [[ "$BUILD_ENVIRONMENT" == *xpu* ]] || [[ "$BUILD_ENVIRONMENT" == *rocm* ]]; then + BASE_URL="https://github.com/conda-forge/miniforge/releases/latest/download" # @lint-ignore + CONDA_FILE="Miniforge3-Linux-$(uname -m).sh" + fi MAJOR_PYTHON_VERSION=$(echo "$ANACONDA_PYTHON_VERSION" | cut -d . -f 1) MINOR_PYTHON_VERSION=$(echo "$ANACONDA_PYTHON_VERSION" | cut -d . -f 2) @@ -17,6 +21,7 @@ if [ -n "$ANACONDA_PYTHON_VERSION" ]; then exit 1 ;; esac + mkdir -p /opt/conda chown jenkins:jenkins /opt/conda From a38f433be2e94a64b095a44ba39879d02d0c2316 Mon Sep 17 00:00:00 2001 From: Andrey Talman Date: Thu, 17 Jul 2025 06:33:08 +0000 Subject: [PATCH 170/457] [Docker builds] Move from Miniconda to Miniforge (#158370) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This is related to: https://www.anaconda.com/legal/terms/terms-of-service Trying to fix outage with docker builds. https://github.com/pytorch/pytorch/actions/runs/16298993712/job/46033590799 Rocm and XPU builds since they use Miniforge are not affected ``` #22 ERROR: process "/bin/sh -c bash ./install_conda.sh && rm install_conda.sh install_magma_conda.sh common_utils.sh /opt/conda/requirements-ci.txt /opt/conda/requirements-docs.txt" did not complete successfully: exit code: 1 ------ > [base 14/42] RUN bash ./install_conda.sh && rm install_conda.sh install_magma_conda.sh common_utils.sh /opt/conda/requirements-ci.txt /opt/conda/requirements-docs.txt: 11.93 CondaToSNonInteractiveError: Terms of Service have not been accepted for the following channels. Please accept or remove them before proceeding: 11.93 • https://repo.anaconda.com/pkgs/main 11.93 • https://repo.anaconda.com/pkgs/r 11.93 11.93 To accept a channel's Terms of Service, run the following and replace `CHANNEL` with the channel name/URL: 11.93 ‣ conda tos accept --override-channels --channel CHANNEL ``` Hence solution is: 1. using `` conda tos accept --override-channels --channel defaults`` 2. use Miniforge instead of Miniconda. Using solution 2. Solution Tried that don't work: 1. Using ``CONDA_ALWAYS_YES = true `` 4. Using older version of miniconda ``` [Miniconda3-py310_25.5.1-0-Linux-x86_64.sh](https://repo.anaconda.com/miniconda/Miniconda3-py310_25.5.1-0-Linux-x86_64.sh) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/158370 Approved by: https://github.com/seemethere Co-authored-by: Eli Uriegas <1700823+seemethere@users.noreply.github.com> --- .ci/docker/common/install_conda.sh | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/.ci/docker/common/install_conda.sh b/.ci/docker/common/install_conda.sh index 185837b7e98a2..481de54a50f2c 100755 --- a/.ci/docker/common/install_conda.sh +++ b/.ci/docker/common/install_conda.sh @@ -4,12 +4,8 @@ set -ex # Optionally install conda if [ -n "$ANACONDA_PYTHON_VERSION" ]; then - BASE_URL="https://repo.anaconda.com/miniconda" - CONDA_FILE="Miniconda3-latest-Linux-x86_64.sh" - if [[ $(uname -m) == "aarch64" ]] || [[ "$BUILD_ENVIRONMENT" == *xpu* ]] || [[ "$BUILD_ENVIRONMENT" == *rocm* ]]; then - BASE_URL="https://github.com/conda-forge/miniforge/releases/latest/download" # @lint-ignore - CONDA_FILE="Miniforge3-Linux-$(uname -m).sh" - fi + BASE_URL="https://github.com/conda-forge/miniforge/releases/latest/download" # @lint-ignore + CONDA_FILE="Miniforge3-Linux-$(uname -m).sh" MAJOR_PYTHON_VERSION=$(echo "$ANACONDA_PYTHON_VERSION" | cut -d . -f 1) MINOR_PYTHON_VERSION=$(echo "$ANACONDA_PYTHON_VERSION" | cut -d . -f 2) @@ -21,7 +17,6 @@ if [ -n "$ANACONDA_PYTHON_VERSION" ]; then exit 1 ;; esac - mkdir -p /opt/conda chown jenkins:jenkins /opt/conda From 09db3a22e8783c4841697317688ba9467c7cc457 Mon Sep 17 00:00:00 2001 From: Jane Xu Date: Wed, 16 Jul 2025 11:16:19 -0700 Subject: [PATCH 171/457] [BE] Get rid of final mentions of BUILD_SPLIT_CUDA (#158453) BUILD_SPLIT_CUDA logic has been removed for a while Differential Revision: [D78418191](https://our.internmc.facebook.com/intern/diff/D78418191/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158453 Approved by: https://github.com/albanD ghstack dependencies: #158358, #158365 --- .ci/pytorch/windows/internal/smoke_test.bat | 9 +------ .../ATen/templates/RegisterDispatchKey.cpp | 4 +--- cmake/Summary.cmake | 1 - torch/headeronly/macros/Export.h | 24 ++++++------------- 4 files changed, 9 insertions(+), 29 deletions(-) diff --git a/.ci/pytorch/windows/internal/smoke_test.bat b/.ci/pytorch/windows/internal/smoke_test.bat index b7463f855428f..f671a9d0e0abb 100644 --- a/.ci/pytorch/windows/internal/smoke_test.bat +++ b/.ci/pytorch/windows/internal/smoke_test.bat @@ -148,14 +148,7 @@ if "%NVIDIA_GPU_EXISTS%" == "0" ( goto end ) -set BUILD_SPLIT_CUDA= -if exist "%install_root%\lib\torch_cuda_cu.lib" if exist "%install_root%\lib\torch_cuda_cpp.lib" set BUILD_SPLIT_CUDA=ON - -if "%BUILD_SPLIT_CUDA%" == "ON" ( - cl %PYTORCH_ROOT%\.ci\pytorch\test_example_code\check-torch-cuda.cpp torch_cpu.lib c10.lib torch_cuda_cu.lib torch_cuda_cpp.lib /EHsc /std:c++17 /link /INCLUDE:?warp_size@cuda@at@@YAHXZ /INCLUDE:?_torch_cuda_cu_linker_symbol_op_cuda@native@at@@YA?AVTensor@2@AEBV32@@Z -) else ( - cl %PYTORCH_ROOT%\.ci\pytorch\test_example_code\check-torch-cuda.cpp torch_cpu.lib c10.lib torch_cuda.lib /EHsc /std:c++17 /link /INCLUDE:?warp_size@cuda@at@@YAHXZ -) +cl %PYTORCH_ROOT%\.ci\pytorch\test_example_code\check-torch-cuda.cpp torch_cpu.lib c10.lib torch_cuda.lib /EHsc /std:c++17 /link /INCLUDE:?warp_size@cuda@at@@YAHXZ .\check-torch-cuda.exe if ERRORLEVEL 1 exit /b 1 diff --git a/aten/src/ATen/templates/RegisterDispatchKey.cpp b/aten/src/ATen/templates/RegisterDispatchKey.cpp index 968729f85267c..158277dd5d53b 100644 --- a/aten/src/ATen/templates/RegisterDispatchKey.cpp +++ b/aten/src/ATen/templates/RegisterDispatchKey.cpp @@ -9,9 +9,7 @@ #if defined(CAFFE2_BUILD_MAIN_LIB) || \ defined(TORCH_CUDA_BUILD_MAIN_LIB) || \ defined(TORCH_HIP_BUILD_MAIN_LIB) || \ - defined(TORCH_XPU_BUILD_MAIN_LIB) || \ - defined(TORCH_CUDA_CU_BUILD_MAIN_LIB) || \ - defined(TORCH_CUDA_CPP_BUILD_MAIN_LIB) + defined(TORCH_XPU_BUILD_MAIN_LIB) #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #endif diff --git a/cmake/Summary.cmake b/cmake/Summary.cmake index a80365f353dfd..3c2ec74f14d17 100644 --- a/cmake/Summary.cmake +++ b/cmake/Summary.cmake @@ -70,7 +70,6 @@ function(caffe2_print_configuration_summary) message(STATUS " USE_CPP_CODE_COVERAGE : ${USE_CPP_CODE_COVERAGE}") message(STATUS " USE_CUDA : ${USE_CUDA}") if(${USE_CUDA}) - message(STATUS " Split CUDA : ${BUILD_SPLIT_CUDA}") message(STATUS " CUDA static link : ${CAFFE2_STATIC_LINK_CUDA}") message(STATUS " USE_CUDNN : ${USE_CUDNN}") message(STATUS " USE_CUSPARSELT : ${USE_CUSPARSELT}") diff --git a/torch/headeronly/macros/Export.h b/torch/headeronly/macros/Export.h index 8c4e207d0dada..8dd25419efb4e 100644 --- a/torch/headeronly/macros/Export.h +++ b/torch/headeronly/macros/Export.h @@ -100,8 +100,12 @@ #define TORCH_API C10_IMPORT #endif -// You may be wondering: Whose brilliant idea was it to split torch_cuda into -// two pieces with confusing names? +// You may be wondering why we have TORCH_CUDA_CPP_API and TORCH_CUDA_CU_API +// belonging to the same library instead of just one TORCH_CUDA_API. Well, it +// can indeed just be one TORCH_CUDA_API (and used to be)! TORCH_CUDA_CPP_API +// and TORCH_CUDA_CU_API are artifacts of when we needed a split build to +// avoid relocation marker linking errors. The context is as follows: +// // Once upon a time, there _was_ only TORCH_CUDA_API. All was happy until we // tried to compile PyTorch for CUDA 11.1, which ran into relocation marker // issues when linking big binaries. @@ -116,26 +120,12 @@ // relocation marker issues, we could link our static libraries to a smaller // part of torch_cuda (torch_cuda_cpp) and avoid the issues. -// libtorch_cuda_cu.so -#ifdef TORCH_CUDA_CU_BUILD_MAIN_LIB -#define TORCH_CUDA_CU_API C10_EXPORT -#elif defined(BUILD_SPLIT_CUDA) -#define TORCH_CUDA_CU_API C10_IMPORT -#endif - -// libtorch_cuda_cpp.so -#ifdef TORCH_CUDA_CPP_BUILD_MAIN_LIB -#define TORCH_CUDA_CPP_API C10_EXPORT -#elif defined(BUILD_SPLIT_CUDA) -#define TORCH_CUDA_CPP_API C10_IMPORT -#endif - // libtorch_cuda.so (where torch_cuda_cu and torch_cuda_cpp are a part of the // same api) #ifdef TORCH_CUDA_BUILD_MAIN_LIB #define TORCH_CUDA_CPP_API C10_EXPORT #define TORCH_CUDA_CU_API C10_EXPORT -#elif !defined(BUILD_SPLIT_CUDA) +#else #define TORCH_CUDA_CPP_API C10_IMPORT #define TORCH_CUDA_CU_API C10_IMPORT #endif From 04349f9ee541c7d07cc057bbe739f46bd4c30dcc Mon Sep 17 00:00:00 2001 From: Kevin Fu Date: Thu, 17 Jul 2025 06:47:43 +0000 Subject: [PATCH 172/457] [PT2]: Skip AOTI Weight Loading during Init (#158416) Summary: AOTI already has weights embedded in .so file. So for the initial load, no need to load the weights again. This allows lowered modules can have different set of weights on different hardwares. Test Plan: ``` MODEL_TYPE=ads_mtml_offsite_cvr_oba_optout_dedicated_model MODEL_ENTITY_ID=895279202 SNAPSHOT_ID=0 MODULE=merge buck2 run mode/dev-nosan -c fbcode.nvcc_arch=a100,h100 -c fbcode.enable_gpu_sections=true fbcode//caffe2/torch/fb/model_transform/fx2trt/packaging:load_net_predictor -- --loadMode=Benchmark --inputNetFile=/data/users/$USER/models/${MODEL_ENTITY_ID}/${SNAPSHOT_ID}/${MODEL_ENTITY_ID}_${SNAPSHOT_ID}.predictor.disagg.gpu.${MODULE} --moduleName ${MODULE} --predictor-hardware-type 1 --submodToDevice "" --benchmarkDontRebatchSamples=true --benchmarkNumIterations 1000 ``` Rollback Plan: Differential Revision: D78383881 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158416 Approved by: https://github.com/henryoier, https://github.com/SherlockNoMad --- torch/nativert/executor/DelegateExecutor.h | 2 ++ torch/nativert/executor/Executor.cpp | 17 ++++++++++++++--- torch/nativert/executor/Executor.h | 2 ++ 3 files changed, 18 insertions(+), 3 deletions(-) diff --git a/torch/nativert/executor/DelegateExecutor.h b/torch/nativert/executor/DelegateExecutor.h index b8c3d506c4313..7d88f98987764 100644 --- a/torch/nativert/executor/DelegateExecutor.h +++ b/torch/nativert/executor/DelegateExecutor.h @@ -46,6 +46,8 @@ class DelegateExecutor { // This call activate the processed weights. virtual void commitWeights() = 0; + virtual void initWeights(std::shared_ptr weights) = 0; + virtual std::vector run(std::vector& inputs) = 0; }; diff --git a/torch/nativert/executor/Executor.cpp b/torch/nativert/executor/Executor.cpp index a90b93bd17c7d..01ce24636fb77 100644 --- a/torch/nativert/executor/Executor.cpp +++ b/torch/nativert/executor/Executor.cpp @@ -72,9 +72,7 @@ void Executor::initialize( delegateExecutors_ = std::move(executionKernels.delegateExecutors); constFoldingExecutions_ = std::move(executionKernels.constFoldingExecutions); - // initialize weights_ - processWeights(weights); - atomicSwapWeights(weights); + initWeights(weights); if (executorConfig_.layoutPlannerSettings.enabled()) { layoutPlanner_ = std::make_unique( @@ -142,6 +140,19 @@ void Executor::processWeights(const std::shared_ptr& weights) { } } +void Executor::initWeights(const std::shared_ptr& weights) { + maybeRunConstantFolding(weights); + if (constantFolder_.has_value()) { + constantFolder_->evaluate(*weights); + } + + weights_.withLock([&](auto& w) { w = std::move(weights); }); + + for (auto& delegateExecutor : delegateExecutors_) { + delegateExecutor->initWeights(weights); + } +} + namespace { void validateInput( const std::string& inputName, diff --git a/torch/nativert/executor/Executor.h b/torch/nativert/executor/Executor.h index 3ab206b01e0c1..cd15e846a3c96 100644 --- a/torch/nativert/executor/Executor.h +++ b/torch/nativert/executor/Executor.h @@ -177,6 +177,8 @@ class Executor { // Helper method to get current timestamp in seconds int64_t getCurrentTimestampSeconds() const; + void initWeights(const std::shared_ptr& weights); + std::unique_ptr graphExecutor_; const Placement placement_; From d76323d41742cbc05ec6857319b267d2c7ea8fd9 Mon Sep 17 00:00:00 2001 From: Sherlock Huang Date: Thu, 17 Jul 2025 06:48:21 +0000 Subject: [PATCH 173/457] [NativeRT] Remove normalizeDevice (#158489) Summary: In pytorch, tensor.to("cuda") behaves differently from tensor.to("cuda:0). tensor.to("cuda") will read from thread local DeviceGuard, aka cuda::current_device(), to infer the device index. TBEPermute is relying on this behavior to route output tensor to a device specified by current thread. For this reason, we remove the normalizeDevice(), and disallow index-less cuda device in Placement. Device-to-device mapping must be done between concrete device! Test Plan: CI Rollback Plan: Differential Revision: D78443109 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158489 Approved by: https://github.com/henryoier --- test/cpp/nativert/test_placement.cpp | 21 ++------------------- torch/nativert/executor/Placement.cpp | 21 +++++++++++++++++---- torch/nativert/executor/Placement.h | 15 --------------- torch/nativert/executor/PlacementUtils.cpp | 14 -------------- 4 files changed, 19 insertions(+), 52 deletions(-) diff --git a/test/cpp/nativert/test_placement.cpp b/test/cpp/nativert/test_placement.cpp index e88ae20e1de04..ab65bfc07b917 100644 --- a/test/cpp/nativert/test_placement.cpp +++ b/test/cpp/nativert/test_placement.cpp @@ -8,23 +8,6 @@ using namespace ::testing; namespace torch::nativert { -TEST(PlacementTest, NormalizeDevice) { - c10::Device cpuDevice = c10::Device(c10::DeviceType::CPU); - c10::Device cpuDevice1 = c10::Device(c10::DeviceType::CPU); - cpuDevice1.set_index(1); - - EXPECT_EQ(normalizeDevice(cpuDevice), cpuDevice); - EXPECT_NE(normalizeDevice(cpuDevice1), cpuDevice1); - - c10::Device cudaDevice = c10::Device(c10::DeviceType::CUDA); - c10::Device cudaDevice1 = c10::Device(c10::DeviceType::CUDA, 1); - EXPECT_EQ(normalizeDevice(cudaDevice), c10::Device(c10::DeviceType::CUDA, 0)); - EXPECT_EQ( - normalizeDevice(cudaDevice1), c10::Device(c10::DeviceType::CUDA, 1)); - - EXPECT_NE( - normalizeDevice(cudaDevice1), c10::Device(c10::DeviceType::CUDA, 0)); -} TEST(PlacementTest, IsSameDevice) { c10::Device cpuDevice = c10::Device(c10::DeviceType::CPU); @@ -90,11 +73,11 @@ TEST(PlacementTest, Placement) { {c10::Device("cuda:0"), c10::Device("cuda:1")}}; Placement p1(deviceMap1); EXPECT_EQ(p1.getMappedDevice(c10::Device("cpu")), c10::Device("cpu")); - EXPECT_EQ(p1.getMappedDevice(c10::Device("cuda")), c10::Device("cuda:1")); + EXPECT_EQ(p1.getMappedDevice(c10::Device("cuda")), c10::Device("cuda")); EXPECT_EQ(p1.getMappedDevice(c10::Device("cuda:0")), c10::Device("cuda:1")); std::unordered_map deviceMap2 = { - {c10::Device("cpu"), c10::Device("cuda")}}; + {c10::Device("cpu"), c10::Device("cuda:0")}}; Placement p2(deviceMap2); EXPECT_EQ(p2.getMappedDevice(c10::Device("cpu")), c10::Device("cuda:0")); EXPECT_EQ(p2.getMappedDevice(c10::Device("cuda:0")), c10::Device("cuda:0")); diff --git a/torch/nativert/executor/Placement.cpp b/torch/nativert/executor/Placement.cpp index be8b6e6df9669..0432ecdc2a7c3 100644 --- a/torch/nativert/executor/Placement.cpp +++ b/torch/nativert/executor/Placement.cpp @@ -32,6 +32,15 @@ std::ostream& operator<<(std::ostream& os, const Placement& placement) { return os; } +namespace { +void assertCudaDeviceHasIndex(const c10::Device& device) { + if (device.is_cuda()) { + TORCH_CHECK( + device.has_index(), "CUDA device in placement must have an index"); + } +} +} // namespace + Placement::Placement(std::optional defaultDevice) : Placement({}, defaultDevice) {} @@ -39,16 +48,20 @@ Placement::Placement( const std::unordered_map& deviceMap, std::optional defaultDevice) { for (const auto& [srcDevice, dstDevice] : deviceMap) { - deviceMap_.try_emplace( - normalizeDevice(srcDevice), normalizeDevice(dstDevice)); + assertCudaDeviceHasIndex(srcDevice); + assertCudaDeviceHasIndex(dstDevice); + + deviceMap_.try_emplace(srcDevice, dstDevice); } + if (defaultDevice.has_value()) { - defaultDevice_ = normalizeDevice(defaultDevice.value()); + assertCudaDeviceHasIndex(defaultDevice.value()); + defaultDevice_ = defaultDevice.value(); } } c10::Device Placement::getMappedDevice(const c10::Device& srcDevice) const { - auto it = deviceMap_.find(normalizeDevice(srcDevice)); + auto it = deviceMap_.find(srcDevice); if (it != deviceMap_.end()) { return it->second; } diff --git a/torch/nativert/executor/Placement.h b/torch/nativert/executor/Placement.h index 9f9a2c627d258..6ea86348973ee 100644 --- a/torch/nativert/executor/Placement.h +++ b/torch/nativert/executor/Placement.h @@ -8,21 +8,6 @@ namespace torch::nativert { -/** - * This function returns a normalized version of the input device: - * - For CPU devices, the returned device will have no index (i.e., the default - * CPU device). - * - For CUDA devices, if no index is specified, index 0 is assumed. - * - For other device types, the function will raise an error. - * - * @param device The input c10::Device to normalize. - * @return A normalized c10::Device with standardized indexing. - * - * @throws c10::Error If the device type is not CPU or CUDA. - */ - -c10::Device normalizeDevice(const c10::Device& device); - /** * Returns true if the two devices are the same and has the same device index * (if cuda). diff --git a/torch/nativert/executor/PlacementUtils.cpp b/torch/nativert/executor/PlacementUtils.cpp index 988c9997ed037..e73224b4f4f52 100644 --- a/torch/nativert/executor/PlacementUtils.cpp +++ b/torch/nativert/executor/PlacementUtils.cpp @@ -4,20 +4,6 @@ namespace torch::nativert { -c10::Device normalizeDevice(const c10::Device& device) { - // cpu device doesn't have index - // cuda device index must have a index - if (device.is_cpu()) { - return c10::Device(c10::DeviceType::CPU); - } else if (device.is_cuda()) { - return c10::Device( - c10::DeviceType::CUDA, - device.has_index() ? device.index() : static_cast(0)); - } else { - TORCH_CHECK(false, "Unsupported device type", device); - } -} - bool isSameDevice(const c10::Device& a, const c10::Device& b) { if (a.is_cpu()) { return b.is_cpu(); From 39ac189808c61588f3594dbc2fc1d69bb6194c47 Mon Sep 17 00:00:00 2001 From: Luca Wehrstedt Date: Thu, 17 Jul 2025 08:19:55 +0000 Subject: [PATCH 174/457] Support DeepSeek-style blockwise scaling scaled-mm for fp8 on Hopper+ (#158037) cuBLAS added support for them in CUDA 12.9. It's rather easy to call into them, the hardest thing is allowing the lhs and rhs operands to have different scaling types, as that changes the whole callstack. The scaling format is still detected from the sizes of the scale tensors. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158037 Approved by: https://github.com/eqy, https://github.com/drisspg --- aten/src/ATen/ceil_div.h | 17 +- aten/src/ATen/cuda/CUDABlas.cpp | 116 +++++++--- aten/src/ATen/cuda/CUDABlas.h | 14 +- aten/src/ATen/cuda/tunable/GemmCommon.h | 8 +- aten/src/ATen/cuda/tunable/GemmHipblaslt.h | 63 ++++-- aten/src/ATen/cuda/tunable/TunableGemm.h | 5 +- aten/src/ATen/native/cuda/Blas.cpp | 243 +++++++++++---------- test/test_matmul_cuda.py | 101 +++++++-- 8 files changed, 363 insertions(+), 204 deletions(-) diff --git a/aten/src/ATen/ceil_div.h b/aten/src/ATen/ceil_div.h index 37d67b232a22c..777fc09a7049d 100644 --- a/aten/src/ATen/ceil_div.h +++ b/aten/src/ATen/ceil_div.h @@ -7,8 +7,15 @@ namespace at { /** Computes ceil(a / b) */ -template >> -C10_ALWAYS_INLINE C10_HOST_DEVICE T ceil_div(T a, T b) { +template < + typename Res = void, + typename T, + typename U, + typename = std::enable_if_t< + std::conjunction_v, std::is_integral>>> +C10_ALWAYS_INLINE C10_HOST_DEVICE + std::conditional_t, std::common_type_t, Res> + ceil_div(T a, U b) { return (a + b - 1) / b; } @@ -16,8 +23,10 @@ C10_ALWAYS_INLINE C10_HOST_DEVICE T ceil_div(T a, T b) { Computes ceil(a / b) * b; i.e., rounds up `a` to the next highest multiple of b */ -template -C10_ALWAYS_INLINE C10_HOST_DEVICE T round_up(T a, T b) { +template +C10_ALWAYS_INLINE C10_HOST_DEVICE + std::conditional_t, std::common_type_t, Res> + round_up(T a, U b) { return ceil_div(a, b) * b; } diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp index d009520d05ab8..acb1d5ed8b0da 100644 --- a/aten/src/ATen/cuda/CUDABlas.cpp +++ b/aten/src/ATen/cuda/CUDABlas.cpp @@ -1843,6 +1843,69 @@ template bool gemm_and_bias( int64_t result_ld, GEMMAndBiasActivationEpilogue activation); +int get_scale_mode(ScalingType scaling_type, ScalarType scale_dtype, bool use_fast_accum) { + switch (scaling_type) { + case ScalingType::BlockWise1x32: + TORCH_CHECK(scale_dtype == kFloat8_e8m0fnu); +#if CUDA_VERSION >= 12080 + return CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0; +#else + TORCH_CHECK(false, "scaled_gemm with `torch.float8_e8m0fnu` scales of 1x32 blocks is only supported for CUDA 12.8 and above"); +#endif // if CUDA_VERSION >= 12080 + + case ScalingType::BlockWise1x16: + TORCH_CHECK(scale_dtype == kFloat8_e4m3fn); +#if CUDA_VERSION >= 12080 + return CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3; +#else + TORCH_CHECK(false, "scaled_gemm with `torch.float8_e4m3fn` scales of 1x16 blocks is only supported for CUDA 12.8 and above"); +#endif // if CUDA_VERSION >= 12080 + + case ScalingType::RowWise: + TORCH_CHECK(scale_dtype == kFloat); +#if CUDA_VERSION >= 12090 || (defined(USE_ROCM) && defined(HIPBLASLT_OUTER_VEC)) + return CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F; +#elif defined(USE_ROCM) && defined(HIPBLASLT_VEC_EXT) + // Return the default, since in old hipblaslt this is activated via + // the SCALE_POINTER_VEC_EXT attributed. + return 0; +#else + TORCH_CHECK(false, "scaled_gemm with rowwise scaling is only supported for CUDA 12.9 and above"); +#endif // if CUDA_VERSION >= 12090 + + case ScalingType::BlockWise1x128: + TORCH_CHECK(scale_dtype == kFloat); + TORCH_CHECK(!use_fast_accum, "scaled_gemm doesn't support fast accum with 1x128 blockwise scaling") +#if CUDA_VERSION >= 12090 + return CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F; +#else + TORCH_CHECK(false, "scaled_gemm with 1x128 blockwise scaling is only supported for CUDA 12.9 and above"); +#endif // if CUDA_VERSION >= 12090 + + case ScalingType::BlockWise128x128: + TORCH_CHECK(scale_dtype == kFloat); + TORCH_CHECK(!use_fast_accum, "scaled_gemm doesn't support fast accum with 128x128 blockwise scaling") +#if CUDA_VERSION >= 12090 + return CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F; +#else + TORCH_CHECK(false, "scaled_gemm with 128x128 blockwise scaling is only supported for CUDA 12.9 and above"); +#endif // if CUDA_VERSION >= 12090 + +case ScalingType::TensorWise: + TORCH_CHECK(scale_dtype == kFloat); +#if CUDA_VERSION >= 12080 + return CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F; +#else + // The macro isn't defined, thus we inline its value. + return 0; +#endif // if CUDA_VERSION >= 12080 + + default: + TORCH_CHECK(false); + return -1; + } +} + void scaled_gemm( char transa, char transb, @@ -1854,19 +1917,20 @@ void scaled_gemm( int64_t mat1_ld, ScalarType mat1_dtype, ScalarType mat1_scale_dtype, + ScalingType mat1_scaling_type, const void* mat2_ptr, const void* mat2_scale_ptr, int64_t mat2_ld, ScalarType mat2_dtype, ScalarType mat2_scale_dtype, + ScalingType mat2_scaling_type, const void* bias_ptr, ScalarType bias_dtype, void* result_ptr, const void *result_scale_ptr, int64_t result_ld, ScalarType result_dtype, - bool use_fast_accum, - bool use_rowwise) { + bool use_fast_accum) { // Note: see `cublasCommonArgs` for various non-intuitive manupulations // of input arguments to this function. #if CUDA_VERSION >= 11080 || defined(USE_ROCM) @@ -1879,19 +1943,15 @@ void scaled_gemm( computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, _cublasOpFromChar(transb)); cublasLtMatmulDescAttributes_t matmulDescA = CUBLASLT_MATMUL_DESC_A_SCALE_POINTER; cublasLtMatmulDescAttributes_t matmulDescB = CUBLASLT_MATMUL_DESC_B_SCALE_POINTER; -#if defined(USE_ROCM) -#if defined(HIPBLASLT_OUTER_VEC) - // this case is handled later as hipified CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F -#elif defined(HIPBLASLT_VEC_EXT) - if (use_rowwise) { + // hipblaslt supported row-wise before cublas, and did so their own way (via + // the SCALE_POINTERSs), but then migrated to match how cublas does it (via + // the SCALE_MODEs). Here we check for this early custom mode. +#if defined(USE_ROCM) && !defined(HIPBLASLT_OUTER_VEC) && defined(HIPBLASLT_VEC_EXT) + if (mat1_scaling_type == ScalingType::RowWise && mat2_scaling_type == ScalingType::RowWise) { matmulDescA = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT; matmulDescB = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT; } -#else - // rowwise isn't supported using older hipblaslt - TORCH_INTERNAL_ASSERT(use_rowwise == false, "rowwise scaled_gemm not supported with older hipblaslt"); -#endif -#endif // defined(USE_ROCM) +#endif // if defined(USE_ROCM) && !defined(HIPBLASLT_OUTER_VEC) && defined(HIPBLASLT_VEC_EXT) computeDesc.setAttribute(matmulDescA, mat1_scale_ptr); computeDesc.setAttribute(matmulDescB, mat2_scale_ptr); if (result_scale_ptr != nullptr) { @@ -1931,30 +1991,14 @@ void scaled_gemm( computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, ScalarTypeToCudaDataType(bias_dtype)); } - if (mat1_scale_dtype == kFloat8_e8m0fnu && mat2_scale_dtype == kFloat8_e8m0fnu) { -#if CUDA_VERSION >= 12080 - computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0); - computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0); -#else - TORCH_CHECK(false, "scaled_gemm with `torch.float8_e8m0fnu` scales is only supported for CUDA 12.8 and above"); -#endif // if CUDA_VERSION >= 12080 - } else if (mat1_scale_dtype == kFloat8_e4m3fn && mat2_scale_dtype == kFloat8_e4m3fn) { -#if CUDA_VERSION >= 12080 - computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3); - computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3); -#else - TORCH_CHECK(false, "scaled_gemm with `torch.float8_e4m3fn` scales is only supported for CUDA 12.8 and above"); -#endif // if CUDA_VERSION >= 12080 - } else if (mat1_scale_dtype == kFloat && mat2_scale_dtype == kFloat && use_rowwise) { -#if CUDA_VERSION >= 12090 || (defined(USE_ROCM) && defined(HIPBLASLT_OUTER_VEC)) - computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F); - computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F); -#elif defined(USE_ROCM) && defined(HIPBLASLT_VEC_EXT) - // no-op here for older hipblaslt ext enums, to avoid TORCH_CHECK below -#else - TORCH_CHECK(false, "scaled_gemm with `torch.float` outer vector scaling is only supported for CUDA 12.9 and above"); -#endif // if CUDA_VERSION >= 12090 - } + // The SCALE_MODE attrs only exist in cuBLAS 12.8+ or in recent hipblaslt, + // but we must invoke get_scale_mode anyways to trigger the version checks. + int a_scale_mode = get_scale_mode(mat1_scaling_type, mat1_scale_dtype, use_fast_accum); + int b_scale_mode = get_scale_mode(mat2_scaling_type, mat2_scale_dtype, use_fast_accum); +#if CUDA_VERSION >= 12080 || (defined(USE_ROCM) && defined(HIPBLASLT_OUTER_VEC)) + computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_MODE, a_scale_mode); + computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, b_scale_mode); +#endif CuBlasLtMatmulPreference preference; auto ltworkspace = CublasLtWorkspace(); diff --git a/aten/src/ATen/cuda/CUDABlas.h b/aten/src/ATen/cuda/CUDABlas.h index b1dac2162dc42..5021917fe0950 100644 --- a/aten/src/ATen/cuda/CUDABlas.h +++ b/aten/src/ATen/cuda/CUDABlas.h @@ -136,6 +136,15 @@ void int8_gemm( int32_t* result_ptr, int64_t result_ld); +enum class ScalingType : std::uint8_t { + TensorWise, // fp32 scales + RowWise, // fp32 scales + BlockWise1x16, // fp8_e4m3fn scales + BlockWise1x32, // fp8_e8m0fnu scales + BlockWise1x128, // fp32 scales + BlockWise128x128, // fp32 scales +}; + void scaled_gemm( char transa, char transb, @@ -147,19 +156,20 @@ void scaled_gemm( int64_t mat1_ld, ScalarType mat1_dtype, ScalarType mat1_scale_dtype, + ScalingType mat1_scaling_type, const void* mat2_ptr, const void* mat2_scale_ptr, int64_t mat2_ld, ScalarType mat2_dtype, ScalarType mat2_scale_dtype, + ScalingType mat2_scaling_type, const void* bias_ptr, ScalarType bias_dtype, void* result_ptr, const void* result_scale_ptr, int64_t result_ld, ScalarType result_dtype, - bool use_fast_accum, - bool use_rowwise); + bool use_fast_accum); #define CUDABLAS_BGEMM_ARGTYPES(Dtype) CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(Dtype, Dtype) diff --git a/aten/src/ATen/cuda/tunable/GemmCommon.h b/aten/src/ATen/cuda/tunable/GemmCommon.h index 6f896f1a22bfc..6d19907aba4ad 100644 --- a/aten/src/ATen/cuda/tunable/GemmCommon.h +++ b/aten/src/ATen/cuda/tunable/GemmCommon.h @@ -29,6 +29,8 @@ namespace at::cuda::tunable { +using at::cuda::blas::ScalingType; + enum class BlasOp { N = 0, T = 1 @@ -598,7 +600,8 @@ struct ScaledGemmParams : OpParams { // // In TunableOp, we must distinguish in param signature these two cases: with and without a bias vector. return fmt::sprintf("%c%c_%ld_%ld_%ld_ld_%ld_%ld_%ld_rw_%d_bias_%s", - transa, transb, m, n, k, lda, ldb, ldc, use_rowwise, + transa, transb, m, n, k, lda, ldb, ldc, + a_scaling_type == ScalingType::RowWise && b_scaling_type == ScalingType::RowWise, bias_ptr == nullptr ? "None" : at::toString(bias_dtype)); } @@ -673,11 +676,13 @@ struct ScaledGemmParams : OpParams { int64_t lda{}; ScalarType a_dtype{}; ScalarType a_scale_dtype{}; + ScalingType a_scaling_type{}; const void* b{}; const void* b_scale_ptr{}; int64_t ldb{}; ScalarType b_dtype{}; ScalarType b_scale_dtype{}; + ScalingType b_scaling_type{}; const void* bias_ptr{}; ScalarType bias_dtype{}; void* c{}; @@ -686,7 +691,6 @@ struct ScaledGemmParams : OpParams { ScalarType c_dtype{}; void* amax_ptr{}; bool use_fast_accum{}; - bool use_rowwise{}; private: bool duplicate_inputs_{false}; }; diff --git a/aten/src/ATen/cuda/tunable/GemmHipblaslt.h b/aten/src/ATen/cuda/tunable/GemmHipblaslt.h index 32fb7c2774fff..809ba51009f0a 100644 --- a/aten/src/ATen/cuda/tunable/GemmHipblaslt.h +++ b/aten/src/ATen/cuda/tunable/GemmHipblaslt.h @@ -206,23 +206,43 @@ float GetBetaFromParams(const ScaledGemmParams* params) { } template -bool GetUseRowwiseFromParams(const GemmParams* params) { - return false; +ScalingType GetAScalingTypeFromParams(const GemmParams* params) { + return ScalingType::TensorWise; } template -bool GetUseRowwiseFromParams(const GemmAndBiasParams* params) { - return false; +ScalingType GetBScalingTypeFromParams(const GemmParams* params) { + return ScalingType::TensorWise; } template -bool GetUseRowwiseFromParams(const GemmStridedBatchedParams* params) { - return false; +ScalingType GetAScalingTypeFromParams(const GemmAndBiasParams* params) { + return ScalingType::TensorWise; } template -bool GetUseRowwiseFromParams(const ScaledGemmParams* params) { - return params->use_rowwise; +ScalingType GetBScalingTypeFromParams(const GemmAndBiasParams* params) { + return ScalingType::TensorWise; +} + +template +ScalingType GetAScalingTypeFromParams(const GemmStridedBatchedParams* params) { + return ScalingType::TensorWise; +} + +template +ScalingType GetBScalingTypeFromParams(const GemmStridedBatchedParams* params) { + return ScalingType::TensorWise; +} + +template +ScalingType GetAScalingTypeFromParams(const ScaledGemmParams* params) { + return params->a_scaling_type; +} + +template +ScalingType GetBScalingTypeFromParams(const ScaledGemmParams* params) { + return params->b_scaling_type; } template @@ -489,23 +509,24 @@ class HipblasltGemmOp : public Callable { const void* mat2_scale_ptr = GetBScalePointerFromParams(params); const void* result_scale_ptr = GetDScalePointerFromParams(params); if (mat1_scale_ptr && mat2_scale_ptr) { -#ifdef HIPBLASLT_VEC_EXT - if (GetUseRowwiseFromParams(params)) { - matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT, mat1_scale_ptr); - matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT, mat2_scale_ptr); - } - else + hipblasLtMatmulDescAttributes_t a_scale_ptr_desc = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER; + hipblasLtMatmulDescAttributes_t b_scale_ptr_desc = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER; + if (GetAScalingTypeFromParams(params) == ScalingType::RowWise) { +#if defined(HIPBLASLT_OUTER_VEC) + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_MODE, HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F); +#elif defined(HIPBLASLT_VEC_EXT) + a_scale_ptr_desc = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT; #endif - { - matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER, mat1_scale_ptr); - matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER, mat2_scale_ptr); } -#ifdef HIPBLASLT_OUTER_VEC - if (GetUseRowwiseFromParams(params)) { - matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_MODE, HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F); + if (GetBScalingTypeFromParams(params) == ScalingType::RowWise) { +#if defined(HIPBLASLT_OUTER_VEC) matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_MODE, HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F); - } +#elif defined(HIPBLASLT_VEC_EXT) + b_scale_ptr_desc = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT; #endif + } + matmul.setAttribute(a_scale_ptr_desc, mat1_scale_ptr); + matmul.setAttribute(b_scale_ptr_desc, mat2_scale_ptr); } if (result_scale_ptr) { matmul.setAttribute(HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr); diff --git a/aten/src/ATen/cuda/tunable/TunableGemm.h b/aten/src/ATen/cuda/tunable/TunableGemm.h index d7e2835b1b109..d941c230630c4 100644 --- a/aten/src/ATen/cuda/tunable/TunableGemm.h +++ b/aten/src/ATen/cuda/tunable/TunableGemm.h @@ -96,19 +96,20 @@ class DefaultScaledGemmOp : public Callable> { params->lda, params->a_dtype, params->a_scale_dtype, + params->a_scaling_type, params->b, params->b_scale_ptr, params->ldb, params->b_dtype, params->b_scale_dtype, + params->b_scaling_type, params->bias_ptr, params->bias_dtype, params->c, params->c_scale_ptr, params->ldc, params->c_dtype, - params->use_fast_accum, - params->use_rowwise); + params->use_fast_accum); return OK; } }; diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp index c46e1cc633119..60becebfb81e5 100644 --- a/aten/src/ATen/native/cuda/Blas.cpp +++ b/aten/src/ATen/native/cuda/Blas.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #ifndef AT_PER_OPERATOR_HEADERS #include @@ -99,6 +100,7 @@ c10::MaybeOwned inline prepare_matrix_for_cublas(const Tensor& tensor, b } } +using at::cuda::blas::ScalingType; /** * @brief Prepares matrices for CUBLAS operation @@ -140,7 +142,9 @@ struct cublasCommonArgs { Tensor& c, const std::optional& scale_a = std::nullopt, const std::optional& scale_b = std::nullopt, - const std::optional& scale_result = std::nullopt) { + const std::optional& scale_result = std::nullopt, + const std::optional& scaling_choice_a = std::nullopt, + const std::optional& scaling_choice_b = std::nullopt) { bool transpose_result = false, transpose_a = false, transpose_b = false; result = prepare_matrix_for_cublas(c, transpose_result); mata = prepare_matrix_for_cublas(transpose_result ? mat2 : mat1, transpose_a, transpose_result); @@ -152,8 +156,10 @@ struct cublasCommonArgs { // as B.T @ A.T, check transpose_result to determine if we flip the scales scale_mata_ptr = transpose_result ? scale_b->data_ptr() : scale_a->data_ptr(); scale_mata_dtype = transpose_result ? scale_b->scalar_type() : scale_a->scalar_type(); + scaling_mata_type = transpose_result ? scaling_choice_b : scaling_choice_a; scale_matb_ptr = transpose_result ? scale_a->data_ptr() : scale_b->data_ptr(); scale_matb_dtype = transpose_result ? scale_a->scalar_type() : scale_b->scalar_type(); + scaling_matb_type = transpose_result ? scaling_choice_a : scaling_choice_b; } if (scale_result) { @@ -199,7 +205,9 @@ struct cublasCommonArgs { void* scale_matb_ptr = nullptr; void* scale_result_ptr = nullptr; std::optional scale_mata_dtype; + std::optional scaling_mata_type; std::optional scale_matb_dtype; + std::optional scaling_matb_type; std::optional scale_result_dtype; }; } // namespace @@ -1075,133 +1083,114 @@ static bool _scaled_mm_is_fnuz() { namespace{ -enum class ScalingType : std::uint8_t { - TensorWise, - RowWise, - BlockWise, - Error -}; /* * Scaling Type Determination: * --------------------------- * Conditions and corresponding Scaling Types: * - * - If scale tensors are both `Float8_e8m0fnu` or `Float8_e4m3fn`: + * - If scale tensor is `Float8_e8m0fnu` or `Float8_e4m3fn`: * - Returns BlockWise (with additional size checks). * - * - If scale_a.numel() == 1 && scale_b.numel() == 1: + * - Else if scale.numel() == 1: * - Returns TensorWise. * - * - Else if scale_a.dim() == 2 && scale_a.size(0) == dim_m && scale_b.size(0) == dim_n: + * - Else if scale.dim() == 2 && scale.size(0) == outer_dim && scale.size(1) == 1: * - Returns RowWise. * + * - Else if scale.dim() == 2 && scale.size(0) == outer_dim && scale.size(1) == inner_dim / 128: + * - Returns BlockWise 1x128. + * + * - Else if scale.dim() == 2 && scale.size(0) == outer_dim / 128 && scale.size(1) == inner_dim / 128: + * - Returns BlockWise 128x128. + * * - Otherwise: * - Returns Error. */ -// Validates the scale tensors to scaled_mm -// And returns the type of scaling/which kernel to use -ScalingType get_scaling_type( - const at::Tensor& scale_a, - const at::Tensor& scale_b, - int64_t dim_m, - int64_t dim_k, - int64_t dim_n) { - // Check for BlockWise scaling (FP8_E8M0 and FP8_E4M3 types) - if ((scale_a.scalar_type() == scale_b.scalar_type()) && - ((scale_a.scalar_type() == at::kFloat8_e8m0fnu) || (scale_a.scalar_type() == at::kFloat8_e4m3fn))) { - const bool is_nvfp4 = scale_a.scalar_type() == at::kFloat8_e4m3fn; - - // cuBLAS's mxfp8 gemm: block_size is 1 scale per 32 elements - // cuBLAS's nvfp4 gemm: block_size is 1 scale per 16 unpacked elements. - const auto BLOCK_SIZE_K = is_nvfp4 ? 16 : 32; - - constexpr int64_t BLOCK_SIZE_MN = 128; - - // adjust for fp4x2 packing if necessary - const auto dim_k_unpacked = is_nvfp4 ? dim_k * 2 : dim_k; - - auto ceil_div = [](auto a, auto b) { return (a + b - 1) / b; }; - auto num_k_blocks = ceil_div(dim_k_unpacked, BLOCK_SIZE_K); - auto padded_num_k_blocks = ceil_div(num_k_blocks, 4) * 4; - - // TODO: We might want to enforce some structure on the shapes of the scale - // tensors - - // Check expected sizes for block-wise scaling - auto expected_a_size = - BLOCK_SIZE_MN * ceil_div(dim_m, BLOCK_SIZE_MN) * padded_num_k_blocks; - auto expected_b_size = - BLOCK_SIZE_MN * ceil_div(dim_n, BLOCK_SIZE_MN) * padded_num_k_blocks; - - TORCH_CHECK(scale_a.numel() == expected_a_size, - "For BlockWise scaling: Expected scale_a size to be ", - expected_a_size, " but got ", scale_a.numel()); - TORCH_CHECK(scale_b.numel() == expected_b_size, - "For BlockWise scaling: Expected scale_b size to be ", - expected_b_size, " but got ", scale_b.numel()); - - TORCH_CHECK( - scale_a.is_contiguous() && scale_b.is_contiguous(), - "For BlockWise scaling: Both scale_a and scale_b must be contiguous"); - - return ScalingType::BlockWise; - } - // Both Per-Tensor and Row-wise scaling expect fp32 tensors - TORCH_CHECK( - scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat, - "Both scale_a and scale_b must be float (fp32) tensors."); +using at::cuda::blas::ScalingType; - // Check the singluar scale case for per-tensor scaling - if (scale_a.numel() == 1 && scale_b.numel() == 1) { - return ScalingType::TensorWise; - } +bool is_tensorwise_scaling(const at::Tensor& t, const at::Tensor& scale) { + return isFloat8Type(t.scalar_type()) && scale.scalar_type() == kFloat && scale.numel() == 1; +} - // For non-TensorWise scaling, enforce 2D input tensors - TORCH_CHECK( - scale_a.dim() == 2 && scale_b.dim() == 2, - "For non-TensorWise scaling, scale tensors must be 2-dimensional, " - "but got scale_a.dim()=", - scale_a.dim(), - " and scale_b.dim()=", - scale_b.dim()); - - // Check for RowWise scaling - if (scale_a.size(0) == dim_m && scale_a.size(1) == 1 && - scale_b.size(0) == 1 && scale_b.size(1) == dim_n) { -#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || \ - (defined(USE_ROCM) && (defined(HIPBLASLT_VEC_EXT) || defined(HIPBLASLT_OUTER_VEC))) - TORCH_CHECK( - scale_a.is_contiguous() && scale_b.is_contiguous(), - "Both scale_a and scale_b must be contiguous for RowWise scaling."); - return ScalingType::RowWise; -#else - TORCH_CHECK(false, "Per-row scaling is not supported for this platform!"); - return ScalingType::Error; -#endif +bool is_rowwise_scaling(const at::Tensor& t, const at::Tensor& scale) { + return (isFloat8Type(t.scalar_type()) && scale.scalar_type() == kFloat && scale.dim() == 2 + && scale.size(0) == t.size(0) && scale.size(1) == 1 + && scale.is_contiguous()); +} + +// 1x16 blocks for packed nvfp4 data and fp8_e4m3fn scales +bool is_blockwise_1x16_scaling(const at::Tensor& t, const at::Tensor& scale) { + // Multiply t.size(1) by 2 to adjust for fp4x2 packing + // TODO: We might want to enforce some structure on the shapes of the scale + // tensors + return (t.scalar_type() == ScalarType::Float4_e2m1fn_x2 && scale.scalar_type() == at::kFloat8_e4m3fn + && scale.numel() == round_up(t.size(0), 128) * round_up(ceil_div(t.size(1) * 2, 16), 4) + && scale.is_contiguous()); +} + +// 1x32 blocks for microscaled fp8 data and fp8_e8m0fnu scales +bool is_blockwise_1x32_scaling(const at::Tensor& t, const at::Tensor& scale) { + // TODO: We might want to enforce some structure on the shapes of the scale + // tensors + return (isFloat8Type(t.scalar_type()) && scale.scalar_type() == at::kFloat8_e8m0fnu + && scale.numel() == round_up(t.size(0), 128) * round_up(ceil_div(t.size(1), 32), 4) + && scale.is_contiguous()); +} + +bool is_blockwise_1x128_scaling(const at::Tensor& t, const at::Tensor& scale) { + return (isFloat8Type(t.scalar_type()) && scale.scalar_type() == kFloat && scale.dim() == 2 + && scale.size(0) == t.size(0) && scale.size(1) == ceil_div(t.size(1), 128) + && scale.stride(0) == 1 && scale.stride(1) == t.size(0)); +} + +bool is_blockwise_128x128_scaling(const at::Tensor& t, const at::Tensor& scale) { + return (isFloat8Type(t.scalar_type()) && scale.scalar_type() == kFloat && scale.dim() == 2 + && scale.size(0) == ceil_div(t.size(0), 128) && scale.size(1) == ceil_div(t.size(1), 128) + && scale.stride(0) == round_up(ceil_div(t.size(1), 128), 4) && scale.stride(1) == 1); +} + +bool is_desired_scaling(const at::Tensor& t, const at::Tensor& scale, ScalingType desired_scaling) { + switch (desired_scaling) { + case ScalingType::TensorWise: + return is_tensorwise_scaling(t, scale); + case ScalingType::RowWise: + return is_rowwise_scaling(t, scale); + case ScalingType::BlockWise1x16: + return is_blockwise_1x16_scaling(t, scale); + case ScalingType::BlockWise1x32: + return is_blockwise_1x32_scaling(t, scale); + case ScalingType::BlockWise1x128: + return is_blockwise_1x128_scaling(t, scale); + case ScalingType::BlockWise128x128: + return is_blockwise_128x128_scaling(t, scale); + default: + TORCH_CHECK(false); + return false; } +} - // If we reach here, the input doesn't match any valid scaling type +std::pair get_joint_scaling( + std::initializer_list> options, + const at::Tensor& a, const at::Tensor& b, + const at::Tensor& scale_a, const at::Tensor& scale_b) { + for (auto [lhs, rhs] : options) { + if (is_desired_scaling(a, scale_a, lhs) && is_desired_scaling(b.t(), scale_b.t(), rhs)) { + return {lhs, rhs}; + } + } TORCH_CHECK( - false, - "Invalid scaling configuration. For TensorWise scaling, both scales should be scalar. " - "For RowWise scaling, scale_a should be (", - dim_m, - ", 1) and scale_b should be (1, ", - dim_n, - "). " - "Got scale_a.size()=(", - scale_a.size(0), - ", ", - scale_a.size(1), - ") and ", - "scale_b.size()=(", - scale_b.size(0), - ", ", - scale_b.size(1), - ")"); - - return ScalingType::Error; + false, + "Invalid scaling configuration.\n" + "- For TensorWise scaling, a and b should be float8, scales should be float and singletons.\n" + "- For RowWise scaling, a and b should be float8, scales should be float, scale_a should be (", a.size(0), ", 1) and scale_b should be (1, ", b.size(1), "), and both should be contiguous.\n" + "- For BlockWise 1x128 scaling, a and b should be float8, scales should be float, scale_a should be (", a.size(0), ", ", ceil_div(a.size(1), 128), ") and scale_b should be (", ceil_div(b.size(0), 128), ", ", b.size(1), "), and both should be outer-dim-major.\n" + "- For BlockWise 128x128 scaling, a and b should be float8, scales should be float, scale_a should be (", ceil_div(a.size(0), 128), ", ", ceil_div(a.size(1), 128), ") and scale_b should be (", ceil_div(b.size(0), 128), ", ", ceil_div(b.size(1), 128), "), and both should be near-inner-dim-major (with 16-byte aligned strides).\n" + "- For Blockwise 1x32 scaling, a and b should be float8, scales should be float8_e8m0fnu, scale_a should have ", round_up(a.size(0), 128) * round_up(ceil_div(a.size(1), 32), 4), " elements and scale_b should have ", round_up(b.size(1), 128) * round_up(ceil_div(b.size(0), 32), 4), " elements, and both should be contiguous.\n" + "- For Blockwise 1x16 scaling, a and b should be float4 (packed 2x), scales should be float8_e4m3fn, scale_a should have ", round_up(a.size(0), 128) * round_up(ceil_div(a.size(1) * 2, 16), 4), " elements and scale_b should have ", round_up(b.size(1), 128) * round_up(ceil_div(b.size(0) * 2, 16), 4), " elements, and both should be contiguous.\n" + "Got a.dtype()=", a.scalar_type(), ", scale_a.dtype()=", scale_a.scalar_type(), ", scale_a.size()=", scale_a.sizes(), ", scale_a.stride()=", scale_a.strides(), ", ", + "b.dtype()=", b.scalar_type(), ", scale_b.dtype()=", scale_b.scalar_type(), ", scale_b.size()=", scale_b.sizes(), " and scale_b.stride()=", scale_b.strides() + ); } } // namespace @@ -1219,8 +1208,8 @@ ScalingType get_scaling_type( // - `mat2`: the second operand of the matrix multiply, can be type `torch.float8_e4m3fn` or `torch.float8_e5m2` // - `bias`: the bias, can be type `torch.float16` or `torch.bfloat16` // - `out_dtype`: the output dtype, can either be a float8 or a higher precision floating point type -// - `scale_a`: a scalar or 1-dimensional tensor with the inverse scale of `mat1`, only needed if `mat1` is a float8 type -// - `scale_b`: a scalar or 1-dimensional tensor with the inverse scale of `mat2`, only needed if `mat2` is a float8 type +// - `scale_a`: a tensor with the inverse scale of `mat1`, whose shape/strides/dtype depend on the scaling scheme +// - `scale_b`: a tensor with the inverse scale of `mat2`, whose shape/strides/dtype depend on the scaling scheme // - `scale_result`: a scalar tensor with the scale of the output, only utilized if the output is a float8 type // - `use_fast_accum`: if true, enables fast float8 accumulation // - `out`: a reference to the output tensor @@ -1243,9 +1232,21 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (", mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")"); - // Check what type of scaling we are doing based on inputs - ScalingType scaling_choice = get_scaling_type(scale_a, scale_b, mat1.size(0), mat1.size(1), mat2.size(1)); - TORCH_INTERNAL_ASSERT(scaling_choice != ScalingType::Error, "Scaling type not supported"); + // Check what type of scaling we are doing based on inputs. This list is sorted + // by decreasing priority. We prefer "simpler" schemes as they are supported + // more broadly (more GPU archs, more CUDA versions) and because they are more + // efficient. This tends to matter only for small matmuls (e.g., 1x1x128). + auto [scaling_choice_a, scaling_choice_b] = get_joint_scaling( + { + std::make_pair(ScalingType::TensorWise, ScalingType::TensorWise), + std::make_pair(ScalingType::RowWise, ScalingType::RowWise), + std::make_pair(ScalingType::BlockWise128x128, ScalingType::BlockWise1x128), + std::make_pair(ScalingType::BlockWise1x128, ScalingType::BlockWise128x128), + std::make_pair(ScalingType::BlockWise1x128, ScalingType::BlockWise1x128), + std::make_pair(ScalingType::BlockWise1x32, ScalingType::BlockWise1x32), + std::make_pair(ScalingType::BlockWise1x16, ScalingType::BlockWise1x16) + }, + mat1, mat2, scale_a, scale_b); TORCH_CHECK(!scale_result || (scale_result->numel() == 1 && scale_result->scalar_type() == kFloat), "scale_result must be a float scalar"); @@ -1316,7 +1317,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, #ifndef USE_ROCM // We are doing row-wise scaling auto dprops = at::cuda::getCurrentDeviceProperties(); - if (scaling_choice == ScalingType::RowWise + if (scaling_choice_a == ScalingType::RowWise && scaling_choice_b == ScalingType::RowWise && (dprops->major < 9 || CUBLAS_VERSION < 120900 || cublasLtGetVersion() < 120900)) { TORCH_CHECK(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling."); at::cuda::detail::f8f8bf16_rowwise( @@ -1330,7 +1331,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, return out; } #else - if (scaling_choice == ScalingType::RowWise) { + if (scaling_choice_a == ScalingType::RowWise && scaling_choice_b == ScalingType::RowWise) { // For ROCm, match behavior of f8f8bf16_rowwise type checking, for unit test purposes. Tensor b = mat2; if (_scaled_mm_is_fnuz()) { @@ -1345,7 +1346,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, } #endif - cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, scale_result); + cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, scale_result, scaling_choice_a, scaling_choice_b); const auto out_dtype_ = args.result->scalar_type(); TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt"); @@ -1422,10 +1423,14 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, params.a_scale_ptr = args.scale_mata_ptr; params.lda = args.lda; params.a_dtype = args.mata->scalar_type(); + params.a_scale_dtype = args.scale_mata_dtype.value(); + params.a_scaling_type = args.scaling_mata_type.value(); params.b = args.matb->data_ptr(); params.b_scale_ptr = args.scale_matb_ptr; params.ldb = args.ldb; params.b_dtype = args.matb->scalar_type(); + params.b_scale_dtype = args.scale_matb_dtype.value(); + params.b_scaling_type = args.scaling_matb_type.value(); params.bias_ptr = bias ? bias->data_ptr(): nullptr; params.bias_dtype = bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_; params.c = args.result->data_ptr(); @@ -1433,7 +1438,6 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, params.ldc = args.result_ld; params.c_dtype = out_dtype_; params.use_fast_accum = use_fast_accum; - params.use_rowwise = scaling_choice == ScalingType::RowWise; if (transa_ && transb_) { TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::T) } @@ -1467,19 +1471,20 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, args.lda, args.mata->scalar_type(), args.scale_mata_dtype.value(), + args.scaling_mata_type.value(), args.matb->data_ptr(), args.scale_matb_ptr, args.ldb, args.matb->scalar_type(), args.scale_matb_dtype.value(), + args.scaling_matb_type.value(), bias ? bias->data_ptr(): nullptr, bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_, args.result->data_ptr(), args.scale_result_ptr, args.result_ld, out_dtype_, - use_fast_accum, - scaling_choice == ScalingType::RowWise); + use_fast_accum); } return out; diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index 31f36681bc3a4..30526c2a84826 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -785,7 +785,7 @@ def amax_to_scale( if float8_dtype == e4m3_type: res = E4M3_MAX_POS / torch.clamp(amax, min=EPS) elif float8_dtype == e5m2_type: - res = E4M3_MAX_POS / torch.clamp(amax, min=EPS) + res = E5M2_MAX_POS / torch.clamp(amax, min=EPS) else: raise ValueError(f"Unsupported float8_dtype: {float8_dtype}") @@ -806,6 +806,20 @@ def tensor_to_scale(x: torch.Tensor, float8_dtype: torch.dtype, dim=None): return amax_to_scale(amax, float8_dtype, x.dtype) +def tensor_to_scale_block( + x: torch.Tensor, + float8_dtype: torch.dtype, + block_outer: int, + block_inner: int, +) -> tuple[torch.Tensor, torch.Tensor]: + x = x.unflatten(1, (-1, block_inner)).unflatten(0, (-1, block_outer)) + amax = x.abs().amax(dim=[1, 3], keepdim=True).float() + scale = torch.finfo(float8_dtype).max / amax + x = x.mul(scale).to(float8_dtype) + x = x.flatten(2, 3).flatten(0, 1) + scale = scale.flatten(2, 3).flatten(0, 1) + return x, scale + def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor: # naive implementation: dq -> op -> q x_fp32 = x.to(torch.float) / x_scale @@ -814,6 +828,17 @@ def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor: return out_fp32.to(out_dtype) +def mm_float8_emulated_block(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor: + x = x.unflatten(1, (x_scale.shape[1], -1)).unflatten(0, (x_scale.shape[0], -1)) + y = y.unflatten(1, (y_scale.shape[1], -1)).unflatten(0, (y_scale.shape[0], -1)) + x_fp32 = x.to(torch.float) / x_scale[:, None, :, None] + y_fp32 = y.to(torch.float) / y_scale[:, None, :, None] + x_fp32 = x_fp32.flatten(2, 3).flatten(0, 1) + y_fp32 = y_fp32.flatten(2, 3).flatten(0, 1) + out_fp32 = torch.mm(x_fp32, y_fp32) + + return out_fp32.to(out_dtype) + def addmm_float8_unwrapped( a_data: torch.Tensor, a_scale: torch.Tensor, @@ -1237,11 +1262,7 @@ def test_float8_error_messages(self, device) -> None: y_fp8 = y.to(e4m3_type).t() with self.assertRaisesRegex( - RuntimeError, - re.escape( - "For RowWise scaling, scale_a should be (1024, 1) and scale_b " - "should be (1, 2048). Got scale_a.size()=(1, 1) and scale_b.size()=(1, 2)" - ), + RuntimeError, re.escape("Invalid scaling configuration") ): torch._scaled_mm( x_fp8, @@ -1252,11 +1273,7 @@ def test_float8_error_messages(self, device) -> None: ) with self.assertRaisesRegex( - RuntimeError, - re.escape( - " For RowWise scaling, scale_a should be (1024, 1) and scale_b " - "should be (1, 2048). Got scale_a.size()=(1024, 1) and scale_b.size()=(1, 2049)" - ), + RuntimeError, re.escape("Invalid scaling configuration") ): torch._scaled_mm( x_fp8, @@ -1266,22 +1283,18 @@ def test_float8_error_messages(self, device) -> None: out_dtype=torch.bfloat16, ) with self.assertRaisesRegex( - RuntimeError, - re.escape("For non-TensorWise scaling, scale tensors must be 2-dimensional"), + RuntimeError, re.escape("Invalid scaling configuration") ): torch._scaled_mm( x_fp8, y_fp8, scale_a=torch.ones((M), device="cuda"), - scale_b=torch.ones((N, N), device="cuda"), + scale_b=torch.ones((N, N, 1), device="cuda"), out_dtype=torch.bfloat16, ) with self.assertRaisesRegex( - RuntimeError, - re.escape( - "Both scale_a and scale_b must be contiguous for RowWise scaling." - ), + RuntimeError, re.escape("Invalid scaling configuration") ): torch._scaled_mm( x_fp8, @@ -1346,6 +1359,58 @@ def test_scaled_mm_vs_emulated_row_wise(self, base_dtype): torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol) + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg) + @unittest.skipIf(not SM90OrLater, "cuBLAS blockwise scaling requires sm90+") + @unittest.skipIf( + _get_torch_cuda_version() < (12, 9), + "cuBLAS blockwise scaling added in CUDA 12.9", + ) + @parametrize("output_dtype", [torch.bfloat16, torch.float32]) + @parametrize("lhs_block,rhs_block", [(1, 1), (128, 1), (1, 128)]) + def test_scaled_mm_vs_emulated_block_wise(self, output_dtype, lhs_block, rhs_block): + torch.manual_seed(42) + + x = torch.randn(256, 512, device="cuda", dtype=output_dtype).pow(3) + y = torch.randn(768, 512, device="cuda", dtype=output_dtype).pow(3) + + x_fp8, x_scales = tensor_to_scale_block(x, e4m3_type, lhs_block, 128) + y_fp8, y_scales = tensor_to_scale_block(y, e4m3_type, rhs_block, 128) + + # 1x128 blocks need scales to be outer-dim-major + if lhs_block == 1: + x_scales = x_scales.t().contiguous().t() + if rhs_block == 1: + y_scales = y_scales.t().contiguous().t() + + # Calculate actual F8 mm + out_scaled_mm = mm_float8( + x_fp8, y_fp8.t(), a_scale=x_scales, b_scale=y_scales.t(), output_dtype=output_dtype + ) + + # Calculate emulated F8 mm + out_emulated = mm_float8_emulated_block( + x_fp8, x_scales, y_fp8.t(), y_scales.t(), output_dtype + ) + + cosine_sim = torch.nn.functional.cosine_similarity( + out_scaled_mm.flatten().float(), out_emulated.flatten().float(), dim=0 + ) + self.assertGreaterEqual(float(cosine_sim), 0.999) + + if output_dtype in {torch.bfloat16, torch.float16}: + atol, rtol = 6e-1, 7e-2 + else: + atol, rtol = 7e-1, 2e-3 + + self.assertEqual(out_scaled_mm, out_emulated, atol=atol, rtol=rtol) + + # One last check against the full-precision reference, to ensure we + # didn't mess up the scaling itself and made the test trivial. + cosine_sim = torch.nn.functional.cosine_similarity( + out_scaled_mm.flatten().float(), (x @ y.t()).flatten().float(), dim=0 + ) + self.assertGreaterEqual(float(cosine_sim), 0.999) + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @parametrize("which_dim_zero", [0, 1, 2]) @parametrize("use_torch_compile", [False, True]) From f4d8bc46c7706f872abcb4ec41f0b32207d5d826 Mon Sep 17 00:00:00 2001 From: "Jiang, Yanbing" Date: Thu, 17 Jul 2025 01:27:38 +0000 Subject: [PATCH 175/457] Enable TF32 as fp32 internal precision for matmul/linear/conv (#157520) ### Description This PR is to enable TF32 as fp32 internal precision for matmul/linear/conv in `mkldnn backend`. Since we have refined fp32 precision API in https://github.com/pytorch/pytorch/pull/125888, we can easily extend the API to support TF32 for `mkldnn backend`. ``` torch.backends.mkldnn.matmul.fp32_precision = 'tf32' torch.backends.mkldnn.conv.fp32_precision = "tf32" ``` Related kernel update and UTs update are done. And the wrapper `bf32_on_and _off` is updated to `reduced_f32_on_and_off`, and it can run tests 3 times, one is reduced_f32 OFF, the other two are reduced_f32 ON (including `bf32 ON` and `tf32 ON`). Pull Request resolved: https://github.com/pytorch/pytorch/pull/157520 Approved by: https://github.com/mingfeima, https://github.com/jansel --- aten/src/ATen/Context.cpp | 9 +- aten/src/ATen/native/CPUBlas.cpp | 2 +- aten/src/ATen/native/mkldnn/Conv.cpp | 39 +++++++ aten/src/ATen/native/mkldnn/Linear.cpp | 12 ++ aten/src/ATen/native/mkldnn/Matmul.cpp | 114 ++++++++++--------- aten/src/ATen/native/mkldnn/Matmul.h | 7 +- docs/source/notes/mkldnn.rst | 7 ++ test/inductor/test_mkldnn_pattern_matcher.py | 50 +++++--- test/test_linalg.py | 28 ++--- test/test_mkldnn.py | 36 +++--- test/test_nn.py | 12 +- test/test_torch.py | 10 +- torch/_inductor/fx_passes/mkldnn_fusion.py | 21 ++-- torch/testing/_internal/common_mkldnn.py | 60 +++++++--- 14 files changed, 266 insertions(+), 141 deletions(-) diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp index 08a834e0a8d4a..8c84779f472d7 100644 --- a/aten/src/ATen/Context.cpp +++ b/aten/src/ATen/Context.cpp @@ -27,7 +27,7 @@ namespace { These const variables defined the fp32 precisions for different backend We have "generic", "cuda", "mkldnn" backend now and we can choose fp32 prevision from "ieee", "tf32", "bf16" and "none". The "ieee" precision means - IEEE standard floating point format "tf32" and "bf16" means we are allowed to + IEEE standard floating point format, "tf32" and "bf16" means we are allowed to use "tf32" or "bf16" as internal computation data types for fp32 computations. And "none" means it is override-able by parent's node @@ -40,7 +40,7 @@ namespace { */ const std::map> _fp32_precisions = { {"generic", {{"ieee", "tf32", "bf16", "none"}}}, - {"mkldnn", {{"ieee", "bf16", "none"}}}, + {"mkldnn", {{"ieee", "tf32", "bf16", "none"}}}, {"cuda", {{"ieee", "tf32", "none"}}}}; // Check whether the backend and op are legal @@ -370,6 +370,9 @@ Float32MatmulPrecision Context::float32MatmulPrecision() const { invalid = invalid || (float32Precision("mkldnn", "matmul") == "bf16" && float32_matmul_precision != at::Float32MatmulPrecision::MEDIUM); + invalid = invalid || + (float32Precision("mkldnn", "matmul") == "tf32" && + float32_matmul_precision != at::Float32MatmulPrecision::HIGH); TORCH_CHECK( !invalid, "PyTorch is checking the matmul precision without a specific backend name,", @@ -403,7 +406,7 @@ void Context::setFloat32MatmulPrecision(const std::string &s) { } else if (s_ == "high") { float32_matmul_precision = at::Float32MatmulPrecision::HIGH; setFloat32Precision("cuda", "matmul", "tf32"); - setFloat32Precision("mkldnn", "matmul", "ieee"); + setFloat32Precision("mkldnn", "matmul", "tf32"); return true; } else if (s_ == "medium") { float32_matmul_precision = at::Float32MatmulPrecision::MEDIUM; diff --git a/aten/src/ATen/native/CPUBlas.cpp b/aten/src/ATen/native/CPUBlas.cpp index 1e2e664fc030d..79dbe7353e159 100644 --- a/aten/src/ATen/native/CPUBlas.cpp +++ b/aten/src/ATen/native/CPUBlas.cpp @@ -202,7 +202,7 @@ void gemm( float *c, int64_t ldc) { internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc); #if AT_MKLDNN_ENABLED() - if (mkldnn_bf32_gemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)) { + if (mkldnn_reduced_f32_gemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)) { return; } #endif diff --git a/aten/src/ATen/native/mkldnn/Conv.cpp b/aten/src/ATen/native/mkldnn/Conv.cpp index 1e2993e79f4d7..8222304e6d072 100644 --- a/aten/src/ATen/native/mkldnn/Conv.cpp +++ b/aten/src/ATen/native/mkldnn/Conv.cpp @@ -160,6 +160,10 @@ static bool mkldnn_conv_enabled_fpmath_mode_bf16(){ mkldnn_bf16_device_check(); } +static bool mkldnn_conv_enabled_fpmath_mode_tf32(){ + return at::globalContext().float32Precision("mkldnn", "conv") == "tf32" && + cpuinfo_has_x86_amx_fp16(); +} static inline at::MemoryFormat mkldnn_convolution_memory_format(int64_t dims, bool is_channels_last) { auto memory_format = at::MemoryFormat::Contiguous; @@ -271,6 +275,10 @@ static Tensor _mkldnn_convolution( input_t.scalar_type() == at::kFloat) { op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); } + if (mkldnn_conv_enabled_fpmath_mode_tf32() && + input_t.scalar_type() == at::kFloat) { + op_attr.set_fpmath_mode(dnnl_fpmath_mode_tf32); + } _mkldnn_convolution_out( input_t, weight_t, @@ -455,6 +463,9 @@ Tensor mkldnn_convolution_pointwise_binary( if (mkldnn_conv_enabled_fpmath_mode_bf16() && input_t.scalar_type() ==at::kFloat){ op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); } + if (mkldnn_conv_enabled_fpmath_mode_tf32() && input_t.scalar_type() ==at::kFloat){ + op_attr.set_fpmath_mode(dnnl_fpmath_mode_tf32); + } if (bias.defined()) { const ideep::tensor b = itensor_from_tensor(bias); @@ -597,6 +608,10 @@ Tensor& mkldnn_convolution_pointwise_binary_( input_t.scalar_type() == at::kFloat) { op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); } + if (mkldnn_conv_enabled_fpmath_mode_tf32() && + input_t.scalar_type() == at::kFloat) { + op_attr.set_fpmath_mode(dnnl_fpmath_mode_tf32); + } _mkldnn_convolution_out( input_t, weight_t, @@ -718,6 +733,9 @@ Tensor _mkldnn_convolution_transpose( if (mkldnn_conv_enabled_fpmath_mode_bf16() && input_t.scalar_type() ==at::kFloat){ op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); } + if (mkldnn_conv_enabled_fpmath_mode_tf32() && input_t.scalar_type() ==at::kFloat){ + op_attr.set_fpmath_mode(dnnl_fpmath_mode_tf32); + } if (bias.defined()) { const ideep::tensor b = itensor_from_tensor(bias, /*from_const_data_ptr*/true); @@ -808,6 +826,10 @@ Tensor mkldnn_convolution_backward_input( weight.scalar_type() == at::kFloat) { op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); } + if (mkldnn_conv_enabled_fpmath_mode_tf32() && + weight.scalar_type() == at::kFloat) { + op_attr.set_fpmath_mode(dnnl_fpmath_mode_tf32); + } ideep::convolution_backward_data::compute_v2( grad_y, w, @@ -828,6 +850,11 @@ Tensor mkldnn_convolution_backward_input( TORCH_WARN_ONCE( "Unexpected ideep version to support fpmath_mode_bf16, please update ideep version to align with pytorch main branch"); } + if (mkldnn_conv_enabled_fpmath_mode_tf32() && + weight.scalar_type() == at::kFloat) { + TORCH_WARN_ONCE( + "Unexpected ideep version to support fpmath_mode_tf32, please update ideep version to align with pytorch main branch"); + } #endif if (grad_output.is_mkldnn()) { @@ -858,6 +885,10 @@ std::tuple mkldnn_convolution_backward_weights( input.scalar_type() == at::kFloat) { op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); } + if (mkldnn_conv_enabled_fpmath_mode_tf32() && + input.scalar_type() == at::kFloat) { + op_attr.set_fpmath_mode(dnnl_fpmath_mode_tf32); + } if (bias_defined) { ideep::convolution_backward_weights::compute_v2( x, @@ -1011,6 +1042,10 @@ Tensor mkldnn_convolution_transpose_backward_input( weight.scalar_type() == at::kFloat) { op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); } + if (mkldnn_conv_enabled_fpmath_mode_tf32() && + weight.scalar_type() == at::kFloat) { + op_attr.set_fpmath_mode(dnnl_fpmath_mode_tf32); + } ideep::convolution_transpose_backward_data::compute_v3( grad_y, w, @@ -1053,6 +1088,10 @@ std::tuple mkldnn_convolution_transpose_backward_weights( input.scalar_type() == at::kFloat) { op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); } + if (mkldnn_conv_enabled_fpmath_mode_tf32() && + input.scalar_type() == at::kFloat) { + op_attr.set_fpmath_mode(dnnl_fpmath_mode_tf32); + } if (bias_defined) { ideep::convolution_transpose_backward_weights::compute_v3( x, diff --git a/aten/src/ATen/native/mkldnn/Linear.cpp b/aten/src/ATen/native/mkldnn/Linear.cpp index 8dbb29bb3e01b..8f0b91b3e3f7e 100644 --- a/aten/src/ATen/native/mkldnn/Linear.cpp +++ b/aten/src/ATen/native/mkldnn/Linear.cpp @@ -73,6 +73,11 @@ static bool use_mkldnn_bf32_linear() { mkldnn_bf16_device_check(); } +static bool use_mkldnn_tf32_linear() { + return at::globalContext().float32Precision("mkldnn", "matmul") == "tf32" && + cpuinfo_has_x86_amx_fp16(); +} + Tensor mkldnn_linear( const Tensor& self, const Tensor& weight_t, const std::optional& bias_opt) { @@ -259,6 +264,9 @@ Tensor mkldnn_linear_pointwise( if (use_mkldnn_bf32_linear() && input_t.scalar_type() == at::kFloat){ op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); } + if (use_mkldnn_tf32_linear() && input_t.scalar_type() == at::kFloat){ + op_attr.set_fpmath_mode(dnnl_fpmath_mode_tf32); + } if (mkldnn_bias.has_value()) { ideep::inner_product_forward::compute( mkldnn_input, @@ -352,6 +360,10 @@ Tensor mkldnn_linear_pointwise_binary( op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); } + if (use_mkldnn_tf32_linear() && input_t.scalar_type() == at::kFloat){ + op_attr.set_fpmath_mode(dnnl_fpmath_mode_tf32); + } + if (mkldnn_bias.has_value()) { ideep::inner_product_forward::compute_binary( mkldnn_input, diff --git a/aten/src/ATen/native/mkldnn/Matmul.cpp b/aten/src/ATen/native/mkldnn/Matmul.cpp index a9c094d85989a..5a6e59fad7863 100644 --- a/aten/src/ATen/native/mkldnn/Matmul.cpp +++ b/aten/src/ATen/native/mkldnn/Matmul.cpp @@ -1,7 +1,8 @@ #define TORCH_ASSERT_ONLY_METHOD_OPERATORS -#include #include #include +#include +#include #include #if !AT_MKLDNN_ENABLED() @@ -53,7 +54,7 @@ bool mkldnn_fp16_gemm( c10::Half *c, int64_t ldc) { return false; } -bool mkldnn_bf32_gemm( +bool mkldnn_reduced_f32_gemm( TransposeType transa, TransposeType transb, int64_t m, int64_t n, int64_t k, float alpha, @@ -85,6 +86,13 @@ void mkldnn_matmul_i8i8i32( TORCH_INTERNAL_ASSERT(false, __func__, ": ATen not compiled with MKLDNN support"); } +bool use_mkldnn_tf32_matmul( + const Tensor& mat1, + const Tensor& mat2, + const Tensor& result) { + return false; +} + } // namespace at::native @@ -107,6 +115,10 @@ static bool use_mkldnn_bf32_matmul() { return use_mkldnn_bf16_matmul() && at::globalContext().float32Precision("mkldnn", "matmul") == "bf16"; } +static bool use_mkldnn_tf32_matmul() { + return cpuinfo_has_x86_amx_fp16() && at::globalContext().float32Precision("mkldnn", "matmul") == "tf32"; +} + // returns an ideep::tensor // - dims: shape e.g: {M,N} // - idtype: ideep data type e.g: (f32, bf16, f16) @@ -144,7 +156,8 @@ mkldnn_gemm( bool bf16_usable = std::is_same_v && use_mkldnn_bf16_matmul(); bool fp16_usable = std::is_same_v && use_mkldnn_fp16_matmul(); bool bf32_usable = std::is_same_v && use_mkldnn_bf32_matmul(); - if ( !(bf16_usable || fp16_usable || bf32_usable) || + bool tf32_usable = std::is_same_v && use_mkldnn_tf32_matmul(); + if ( !(bf16_usable || fp16_usable || bf32_usable || tf32_usable) || (m * n * k <= 16 * 16 * 16) || (alpha == 0.0f)) { return false; } @@ -155,6 +168,7 @@ mkldnn_gemm( op_attr = ideep::attr_t::fuse_sum(); } if (bf32_usable) op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); // bf32 path + if (tf32_usable) op_attr.set_fpmath_mode(dnnl_fpmath_mode_tf32); // tf32 path // NOTE: View as c-contiguous to avoid extra reordering in mkldnn // Use identity: C = AB <=> C^T = B^T A^T @@ -281,7 +295,7 @@ bool mkldnn_fp16_gemm( return mkldnn_gemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } -bool mkldnn_bf32_gemm( +bool mkldnn_reduced_f32_gemm( TransposeType transa, TransposeType transb, int64_t m, int64_t n, int64_t k, float alpha, @@ -339,6 +353,7 @@ void mkldnn_matmul( auto mat2_unsqueezed = mat2.dim() == 1 ? mat2.unsqueeze(1) : mat2; auto result_unsqueezed = result.dim() == 1 ? result.unsqueeze(1) : result; bool bf32_usable = mat1.scalar_type() == at::kFloat && use_mkldnn_bf32_matmul(); + bool tf32_usable = mat1.scalar_type() == at::kFloat && use_mkldnn_tf32_matmul(); ideep::attr_t op_attr; // "addmm", "addbmm" "baddbmm" in pytorch allow bias to be 2-D or 3-D tensor @@ -346,6 +361,7 @@ void mkldnn_matmul( // to address their differences, we use mkldnn post ops to perform a fused "add" after matrix multiplication is over if (beta != 0.0f) op_attr = ideep::attr_t::fuse_sum(); if (bf32_usable) op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); // bf32 path + if (tf32_usable) op_attr.set_fpmath_mode(dnnl_fpmath_mode_tf32); // tf32 path // If alpha = 0, dose not need actually do gemm computation if (alpha == 0) return; @@ -412,70 +428,56 @@ static inline bool checksize(const Tensor& mat1, const Tensor& mat2){ } } -bool use_mkldnn_bf16_matmul( +template +bool use_mkldnn_typed_matmul( const Tensor& mat1, const Tensor& mat2, const Tensor& result) { + bool dtype_check = false; + if constexpr (std::is_same_v) { #if defined(__aarch64__) - if (mkldnn_bf16_device_check_arm()) { - //onednn fastmath mode can leverage bf16 HW even for the fp32 input, e.g. Arm Neoverse V1 - //so, don't restrict the mkldnn_matmul only for bf16 inputs, allow it for float as well - return ( - use_mkldnn_bf16_matmul() && - (mat1.scalar_type() == mat2.scalar_type()) && (!result.defined() || (mat1.scalar_type() == result.scalar_type())) && - ((mat1.scalar_type() == kFloat) || (mat1.scalar_type() == kBFloat16)) && - mat1.numel() != 0 && - mat2.numel() != 0 && - checksize(mat1, mat2)); - } else + if (mkldnn_bf16_device_check_arm()) { + // onednn fastmath mode can leverage bf16 HW even for the fp32 input, e.g. + // Arm Neoverse V1 so, don't restrict the mkldnn_matmul only for bf16 + // inputs, allow it for float as well + dtype_check = use_mkldnn_bf16_matmul() && + ((mat1.scalar_type() == kFloat) || (mat1.scalar_type() == kBFloat16)); + } +#else + dtype_check = dtype_check && use_mkldnn_bf16_matmul() && + (mat1.scalar_type() == kBFloat16); #endif - { - return ( - use_mkldnn_bf16_matmul() && - mat1.scalar_type() == kBFloat16 && - mat2.scalar_type() == kBFloat16 && - (!result.defined() || result.scalar_type() == kBFloat16) && - mat1.numel() != 0 && - mat2.numel() != 0 && - checksize(mat1, mat2)); + } else if constexpr (std::is_same_v) { + dtype_check = dtype_check && use_mkldnn_fp16_matmul() && + (mat1.scalar_type() == kHalf); + } else if constexpr (std::is_same_v) { + dtype_check = dtype_check && + (use_mkldnn_bf32_matmul() || use_mkldnn_tf32_matmul()) && + (mat1.scalar_type() == kFloat); } -} - -bool use_mkldnn_fp16_matmul( - const Tensor& mat1, - const Tensor& mat2, - const Tensor& result) { - - return ( - use_mkldnn_fp16_matmul() && - mat1.scalar_type() == kHalf && - mat2.scalar_type() == kHalf && - (!result.defined() || result.scalar_type() == kHalf) && - mat1.numel() != 0 && - mat2.numel() != 0 && - checksize(mat1, mat2)); -} - -bool use_mkldnn_bf32_matmul( - const Tensor& mat1, - const Tensor& mat2, - const Tensor& result) { - - return ( - use_mkldnn_bf32_matmul() && - mat1.scalar_type() == kFloat && - mat2.scalar_type() == kFloat && - (!result.defined() || result.scalar_type() == kFloat) && - mat1.numel() != 0 && - mat2.numel() != 0 && - checksize(mat1, mat2)); + if (!dtype_check) { + return false; + } + bool size_check = + mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2); + dtype_check = (mat1.scalar_type() == mat2.scalar_type()) && + (!result.defined() || result.scalar_type() == mat1.scalar_type()); + return dtype_check && size_check; } bool use_mkldnn_matmul( const Tensor& mat1, const Tensor& mat2, const Tensor& result) { - return (use_mkldnn_bf16_matmul(mat1, mat2, result) || use_mkldnn_fp16_matmul(mat1, mat2, result) || use_mkldnn_bf32_matmul(mat1, mat2, result)); + auto mat1_type = mat1.scalar_type(); + if (mat1_type != kBFloat16 || mat1_type != kHalf || mat1_type != kFloat) { + return false; + } + AT_DISPATCH_FLOATING_TYPES_AND2( + kBFloat16, kHalf, mat1.scalar_type(), "use_mkldnn_matmul", [&] { + return use_mkldnn_typed_matmul(mat1, mat2, result); + }); + return false; } static void _mkldnn_matmul_i8i8i32_with_primitive( diff --git a/aten/src/ATen/native/mkldnn/Matmul.h b/aten/src/ATen/native/mkldnn/Matmul.h index e783d23724030..80247497d58f0 100644 --- a/aten/src/ATen/native/mkldnn/Matmul.h +++ b/aten/src/ATen/native/mkldnn/Matmul.h @@ -29,6 +29,11 @@ bool use_mkldnn_bf32_matmul( const Tensor& mat2, const Tensor& result_opt); +bool use_mkldnn_tf32_matmul( + const Tensor& mat1, + const Tensor& mat2, + const Tensor& result_opt); + // Try running mkldnn optimized gemm, or returns false if naive gemm would be faster bool mkldnn_bf16_gemm( TransposeType transa, TransposeType transb, @@ -62,7 +67,7 @@ oneDNN implicit reduced precision arithmetic feature https://github.com/mgouicem/oneDNN/tree/mgouicem/rfcs/implicit_downconvert/rfcs/20210301-computation-datatype to allow implicitly cast data type from FP32 to BF16 in onednn compute primitives */ -bool mkldnn_bf32_gemm( +bool mkldnn_reduced_f32_gemm( TransposeType transa, TransposeType transb, int64_t m, int64_t n, int64_t k, float alpha, diff --git a/docs/source/notes/mkldnn.rst b/docs/source/notes/mkldnn.rst index 366c2f99cd6f2..8e4a26e50bc55 100644 --- a/docs/source/notes/mkldnn.rst +++ b/docs/source/notes/mkldnn.rst @@ -65,6 +65,13 @@ To get an idea of the precision and speed, see the example code and benchmark da relative_error = error / mean # 0.0170 print(error, relative_error) + # Do matmul at TF32 mode. + torch.backends.mkldnn.matmul.fp32_precision = 'tf32' + ab_tf32 = a @ b # expected speedup with TF32 dot-product acceleration + error = (ab_tf32 - ab_full).abs().max() # 0.0004 + relative_error = error / mean # 0.00000552 + print(error, relative_error) + # Do matmul FP32 mode. torch.backends.mkldnn.matmul.fp32_precision = 'ieee' ab_fp32 = a @ b diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py index bccc0e6e42fda..79ca002f7f5bf 100644 --- a/test/inductor/test_mkldnn_pattern_matcher.py +++ b/test/inductor/test_mkldnn_pattern_matcher.py @@ -18,7 +18,7 @@ from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer from torch.nn import functional as F from torch.testing._internal.common_device_type import instantiate_device_type_tests -from torch.testing._internal.common_mkldnn import bf32_on_and_off +from torch.testing._internal.common_mkldnn import reduced_f32_on_and_off from torch.testing._internal.common_quantization import ( _generate_qdq_quantized_model, skipIfNoDynamoSupport, @@ -312,6 +312,11 @@ def forward(self, x): memory_format, dtype, ) in options: + if ( + dtype != torch.float32 + and torch.backends.mkldnn.matmul.fp32_precision == "tf32" + ): + continue metrics.reset() if dim == 4: x_shape = (1, 3, 56, 56) @@ -350,7 +355,7 @@ def matcher_check_fn(): @skipIfNoDynamoSupport @skipIfNoONEDNN @skipIfRocm - @bf32_on_and_off() + @reduced_f32_on_and_off() def test_conv2d_unary(self, device): self.device = device self._test_conv_unary_base(dim=4) @@ -358,7 +363,7 @@ def test_conv2d_unary(self, device): @skipIfNoDynamoSupport @skipIfNoONEDNN @skipIfRocm - @bf32_on_and_off() + @reduced_f32_on_and_off() def test_conv3d_unary(self, device): self.device = device self._test_conv_unary_base(dim=5) @@ -442,7 +447,7 @@ def matcher_check_fn(): @skipIfXpu( msg="The operator 'mkldnn::_convolution_transpose_pointwise' is not currently implemented for the XPU device." ) - @bf32_on_and_off() + @reduced_f32_on_and_off() def test_conv_transpose2d_unary(self, device): self.device = device self._test_conv_transpose_unary_base(dim=4) @@ -453,7 +458,7 @@ def test_conv_transpose2d_unary(self, device): @skipIfXpu( msg="The operator 'mkldnn::_convolution_transpose_pointwise' is not currently implemented for the XPU device." ) - @bf32_on_and_off() + @reduced_f32_on_and_off() def test_conv_transpose3d_unary(self, device): self.device = device self._test_conv_transpose_unary_base(dim=5) @@ -508,6 +513,11 @@ def forward(self, x): memory_format, dtype, ) in options: + if ( + dtype != torch.float32 + and torch.backends.mkldnn.matmul.fp32_precision == "tf32" + ): + continue metrics.reset() if dim == 4: x_shape = (1, 3, 56, 56) @@ -543,7 +553,7 @@ def matcher_check_fn(): @skipIfNoDynamoSupport @skipIfNoONEDNN @skipIfRocm - @bf32_on_and_off(0.02) + @reduced_f32_on_and_off(0.02) def test_conv2d_binary(self, device): self.device = device self._test_conv_binary_base(dim=4) @@ -551,7 +561,7 @@ def test_conv2d_binary(self, device): @skipIfNoDynamoSupport @skipIfNoONEDNN @skipIfRocm - @bf32_on_and_off(0.02) + @reduced_f32_on_and_off(0.02) def test_conv3d_binary(self, device): self.device = device self._test_conv_binary_base(dim=5) @@ -650,7 +660,7 @@ def matcher_check_fn(): @skipIfNoDynamoSupport @skipIfNoONEDNN @skipIfRocm - @bf32_on_and_off() + @reduced_f32_on_and_off() def test_conv2d_binary_broadcast_shapes(self, device): self.device = device self._test_conv_binary_broadcast_shapes_base(dim=4) @@ -658,7 +668,7 @@ def test_conv2d_binary_broadcast_shapes(self, device): @skipIfNoDynamoSupport @skipIfNoONEDNN @skipIfRocm - @bf32_on_and_off() + @reduced_f32_on_and_off() def test_conv3d_binary_broadcast_shapes(self, device): self.device = device self._test_conv_binary_broadcast_shapes_base(dim=5) @@ -667,7 +677,7 @@ def test_conv3d_binary_broadcast_shapes(self, device): @skipIfNoONEDNN @skipIfRocm @unittest.skipIf(IS_FBCODE, "Failing in fbcode") - @bf32_on_and_off() + @reduced_f32_on_and_off() def test_conv2d_linear_add_broadcast_shapes(self, device): self.device = device @@ -699,7 +709,7 @@ def matcher_check_fn(): class TestPatternMatcher(TestPatternMatcherBase): - @bf32_on_and_off() + @reduced_f32_on_and_off() def test_linear_unary(self, device="cpu"): self.device = device @@ -730,10 +740,15 @@ def forward(self, x): dtypes.append(torch.bfloat16) if is_mkldnn_fp16_supported(self.device): dtypes.append(torch.float16) - if torch.backends.mkldnn.matmul.fp32_precision == "bf16": + if torch.backends.mkldnn.matmul.fp32_precision in ["bf16", "tf32"]: dtypes.append(torch.float32) options = itertools.product(unary_list, [True, False], dtypes) for unary_fn, bias, dtype in options: + if ( + dtype != torch.float32 + and torch.backends.mkldnn.matmul.fp32_precision == "tf32" + ): + continue metrics.reset() mod = M(unary_fn, 10, 30, bias=bias).eval() # only fuse for linear when the dtype is bf16 @@ -761,7 +776,7 @@ def matcher_check_fn(): expected_kernel_count -= 1 self.assertEqual(metrics.generated_kernel_count, expected_kernel_count) - @bf32_on_and_off() + @reduced_f32_on_and_off() @unittest.skipIf(not TEST_MKL, "Test requires MKL") def test_linear_fp32(self, device="cpu"): self.device = device @@ -909,7 +924,7 @@ def matcher_check_fn(): # 1 kernel for "to_lowp", 2 kernels for unary ops self.assertEqual(metrics.generated_kernel_count, 3) - @bf32_on_and_off() + @reduced_f32_on_and_off() def test_linear_binary(self, device="cpu"): self.device = device @@ -931,7 +946,7 @@ def forward(self, x, y): dtypes.append(torch.bfloat16) if is_mkldnn_fp16_supported(self.device): dtypes.append(torch.float16) - if torch.backends.mkldnn.matmul.fp32_precision == "bf16": + if torch.backends.mkldnn.matmul.fp32_precision in ["bf16", "tf32"]: dtypes.append(torch.float32) options = itertools.product( binary_list, [[2, 3, 10], [2, 10]], [True, False], dtypes @@ -940,6 +955,11 @@ def forward(self, x, y): for binary_fn, input_shape, bias, dtype in options: metrics.reset() + if ( + dtype != torch.float32 + and torch.backends.mkldnn.matmul.fp32_precision == "tf32" + ): + continue def matcher_check_fn(): self.assertEqual( diff --git a/test/test_linalg.py b/test/test_linalg.py index abbf7d6f6e9e9..8712d65bb493c 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -40,7 +40,7 @@ _get_torch_cuda_version, CDNA2OrLater, TEST_MULTIGPU from torch.testing._internal.common_quantization import _group_quantize_tensor, _dynamically_quantize_per_channel, \ _group_quantize_tensor_symmetric -from torch.testing._internal.common_mkldnn import bf32_on_and_off +from torch.testing._internal.common_mkldnn import reduced_f32_on_and_off from torch.distributions.binomial import Binomial import torch.backends.opt_einsum as opt_einsum import operator @@ -231,7 +231,7 @@ def _compare_untuned_tuned_entries(self, untuned_filename=None, tuned_filename=N @dtypes(torch.float, torch.cfloat) @precisionOverride({torch.float: 1e-06, torch.cfloat: 1e-06}) @tf32_on_and_off(5e-3) - @bf32_on_and_off(5e-3) + @reduced_f32_on_and_off(5e-3) def test_inner(self, device, dtype): def check(a_sizes_, b_sizes_): for a_sizes, b_sizes in ((a_sizes_, b_sizes_), (b_sizes_, a_sizes_)): @@ -785,7 +785,7 @@ def cholesky_test_helper(n, batch_dims, upper): @skipCPUIfNoLapack @dtypes(*floating_and_complex_types()) @tf32_on_and_off(0.1 if TEST_WITH_ROCM else 0.01) - @bf32_on_and_off(0.01) + @reduced_f32_on_and_off(0.01) def test_old_cholesky(self, device, dtype): from torch.testing._internal.common_utils import random_hermitian_pd_matrix @@ -7199,7 +7199,7 @@ def maybe_transpose(cond, m): *[torch.bfloat16] if TEST_WITH_ROCM or SM53OrLater else [])) @dtypes(*floating_and_complex_types_and(torch.bfloat16, torch.half)) @tf32_on_and_off(0.05) - @bf32_on_and_off(0.05) + @reduced_f32_on_and_off(0.05) def test_addmm(self, device, dtype): self._test_addmm_impl(torch.addmm, None, device, dtype) @@ -7209,7 +7209,7 @@ def test_addmm(self, device, dtype): *[torch.bfloat16, torch.half] if TEST_WITH_ROCM or SM53OrLater else [])) @dtypes(*floating_types_and(torch.bfloat16)) @tf32_on_and_off(0.05) - @bf32_on_and_off(0.05) + @reduced_f32_on_and_off(0.05) def test_addmm_relu(self, device, dtype): self._test_addmm_impl(torch._addmm_activation, "relu", device, dtype) @@ -7221,7 +7221,7 @@ def test_addmm_relu(self, device, dtype): *[torch.bfloat16, torch.half] if TEST_WITH_ROCM or SM53OrLater else [])) @dtypes(*floating_types_and(torch.bfloat16)) @tf32_on_and_off(0.05) - @bf32_on_and_off(0.05) + @reduced_f32_on_and_off(0.05) def test_addmm_relu_tunableop_rocm(self, device, dtype): with self._tunableop_ctx(): torch.cuda.tunable.set_rotating_buffer_size(0) @@ -7235,14 +7235,14 @@ def test_addmm_relu_tunableop_rocm(self, device, dtype): *[torch.bfloat16, torch.half] if TEST_WITH_ROCM or SM53OrLater else [])) @dtypes(*floating_types_and(torch.bfloat16)) @tf32_on_and_off(0.05) - @bf32_on_and_off(0.05) + @reduced_f32_on_and_off(0.05) def test_addmm_gelu(self, device, dtype): self._test_addmm_impl(torch._addmm_activation, "gelu", device, dtype) @dtypes(torch.float, torch.double) @dtypesIfCUDA(*floating_and_complex_types()) @tf32_on_and_off(0.05 if TEST_WITH_ROCM else 0.005) - @bf32_on_and_off(0.005) + @reduced_f32_on_and_off(0.005) def test_addmm_sizes(self, device, dtype): for m in [0, 1, 25]: for n in [0, 1, 10]: @@ -7840,7 +7840,7 @@ def test_fp16_mv_transposed_first_argument_arm_cpu(self, device, m, k): @dtypes(torch.half, torch.float32, torch.float64, torch.int32, torch.int64, torch.cfloat, torch.cdouble) @dtypesIfCUDA(torch.float32, torch.float64, torch.cfloat, torch.cdouble) @tf32_on_and_off(0.01) - @bf32_on_and_off(0.01) + @reduced_f32_on_and_off(0.01) def test_mm(self, device, dtype): def _test_mm(n, m, p, dtype, genf): # helper function @@ -8020,7 +8020,7 @@ def test_strided_mm_bmm(self, device, dtype): @onlyNativeDeviceTypes @dtypes(*floating_and_complex_types_and(torch.bfloat16, torch.half)) @tf32_on_and_off(0.05) - @bf32_on_and_off(0.05) + @reduced_f32_on_and_off(0.05) def test_bmm(self, device, dtype): if self.device_type == 'cuda' and dtype is torch.bfloat16 and not SM53OrLater: # cuBLAS does not guarantee BFloat16 support on SM < 53. @@ -8133,7 +8133,7 @@ def _test_addbmm_baddbmm(self, func, b1, b2, ref, out_tensor): @onlyNativeDeviceTypes @dtypes(*floating_and_complex_types_and(torch.bfloat16, torch.half)) @tf32_on_and_off(0.05) - @bf32_on_and_off(0.05) + @reduced_f32_on_and_off(0.05) def test_addbmm(self, device, dtype): if self.device_type == 'cuda' and dtype is torch.bfloat16 and not SM53OrLater: # cuBLAS does not guarantee BFloat16 support on SM < 53. @@ -8207,7 +8207,7 @@ def generate_tensor(): @onlyNativeDeviceTypes @dtypes(*floating_and_complex_types_and(torch.bfloat16, torch.half)) @tf32_on_and_off(0.05) - @bf32_on_and_off(0.05) + @reduced_f32_on_and_off(0.05) def test_baddbmm(self, device, dtype): if self.device_type == 'cuda' and dtype is torch.bfloat16 and not SM53OrLater: # cuBLAS does not guarantee BFloat16 support on SM < 53. @@ -9167,7 +9167,7 @@ def dims_full_for_fn(): # ROCm 6.4 passes with tf32=on, but 6.4.1 needed tolerance reduced slightly @tf32_on_and_off(0.002 if torch.version.hip else 0.001) - @bf32_on_and_off(0.001) + @reduced_f32_on_and_off(0.001) def test_broadcast_batched_matmul(self, device): n_dim = random.randint(1, 8) m_dim = random.randint(1, 8) @@ -9504,7 +9504,7 @@ def fn(torchfn, *args): fn(torch.slogdet, (0, 0))) @tf32_on_and_off(0.05 if TEST_WITH_ROCM else 0.005) - @bf32_on_and_off(0.07) + @reduced_f32_on_and_off(0.07, 0.005) def test_tensordot(self, device): a = torch.arange(60., device=device).reshape(3, 4, 5) b = torch.arange(24., device=device).reshape(4, 3, 2) diff --git a/test/test_mkldnn.py b/test/test_mkldnn.py index 23788653cc6cd..e2ec92fc8dada 100644 --- a/test/test_mkldnn.py +++ b/test/test_mkldnn.py @@ -27,7 +27,7 @@ instantiate_device_type_tests, dtypes, ) -from torch.testing._internal.common_mkldnn import bf32_on_and_off +from torch.testing._internal.common_mkldnn import reduced_f32_on_and_off # batched grad doesn't support mkldnn gradcheck = functools.partial(gradcheck, check_batched_grad=False) @@ -284,15 +284,15 @@ def _test_conv_base(self, dim): if bias: self.assertEqual(conv.bias.grad, mkldnn_conv.bias.grad) - @bf32_on_and_off() + @reduced_f32_on_and_off() def test_conv1d(self): self._test_conv_base(dim=1) - @bf32_on_and_off() + @reduced_f32_on_and_off() def test_conv2d(self): self._test_conv_base(dim=2) - @bf32_on_and_off() + @reduced_f32_on_and_off() def test_conv3d(self): self._test_conv_base(dim=3) @@ -407,7 +407,7 @@ def _test_conv_deconv_nhwc_base(self, conv_module, weight_memory_format, dtype, self.assertEqual(conv1.bias.grad, conv2.bias.grad, atol=prec, rtol=prec) self.assertEqual(x1.grad, x2.grad, atol=prec, rtol=prec) - @bf32_on_and_off() + @reduced_f32_on_and_off() def test_conv_nhwc_fp32(self): self._test_conv_deconv_nhwc_base(torch.nn.Conv2d, torch.contiguous_format, dtype=torch.float32) self._test_conv_deconv_nhwc_base(torch.nn.Conv2d, torch.channels_last, dtype=torch.float32) @@ -443,7 +443,7 @@ def test_conv_nhwc_lower_precision(self, dtype): self._test_conv_deconv_nhwc_base(torch.nn.Conv3d, torch.channels_last_3d, dtype=dtype, prec=prec) - @bf32_on_and_off() + @reduced_f32_on_and_off() def test_conv_transpose_nhwc_fp32(self): self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose2d, torch.contiguous_format, dtype=torch.float32) self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose2d, torch.channels_last, dtype=torch.float32) @@ -532,15 +532,15 @@ def _test_conv_transpose_base(self, dim): if bias: self.assertEqual(conv.bias.grad, conv_ref.bias.grad) - @bf32_on_and_off() + @reduced_f32_on_and_off() def test_conv_transpose1d(self): self._test_conv_transpose_base(dim=1) - @bf32_on_and_off() + @reduced_f32_on_and_off() def test_conv_transpose2d(self): self._test_conv_transpose_base(dim=2) - @bf32_on_and_off() + @reduced_f32_on_and_off() def test_conv_transpose3d(self): self._test_conv_transpose_base(dim=3) @@ -1680,21 +1680,29 @@ def test_mlkdnn_get_set(self): # get/set mkldnn ops with torch.backends.mkldnn.flags(enabled=None, fp32_precision="bf16"): self.assertEqual(torch.backends.mkldnn.fp32_precision, "bf16") + with torch.backends.mkldnn.flags(enabled=None, fp32_precision="tf32"): + self.assertEqual(torch.backends.mkldnn.fp32_precision, "tf32") with torch.backends.mkldnn.flags(enabled=None, fp32_precision="none"): self.assertEqual(torch.backends.mkldnn.fp32_precision, "none") # get/set matmul torch.backends.mkldnn.matmul.fp32_precision = "bf16" self.assertEqual(torch.backends.mkldnn.matmul.fp32_precision, "bf16") + torch.backends.mkldnn.matmul.fp32_precision = "tf32" + self.assertEqual(torch.backends.mkldnn.matmul.fp32_precision, "tf32") torch.backends.mkldnn.matmul.fp32_precision = "none" self.assertEqual(torch.backends.mkldnn.matmul.fp32_precision, "none") # get/set conv torch.backends.mkldnn.conv.fp32_precision = "bf16" self.assertEqual(torch.backends.mkldnn.conv.fp32_precision, "bf16") + torch.backends.mkldnn.conv.fp32_precision = "tf32" + self.assertEqual(torch.backends.mkldnn.conv.fp32_precision, "tf32") torch.backends.mkldnn.conv.fp32_precision = "none" self.assertEqual(torch.backends.mkldnn.conv.fp32_precision, "none") # get/set rnn torch.backends.mkldnn.rnn.fp32_precision = "bf16" self.assertEqual(torch.backends.mkldnn.rnn.fp32_precision, "bf16") + torch.backends.mkldnn.rnn.fp32_precision = "tf32" + self.assertEqual(torch.backends.mkldnn.rnn.fp32_precision, "tf32") torch.backends.mkldnn.rnn.fp32_precision = "none" self.assertEqual(torch.backends.mkldnn.rnn.fp32_precision, "none") @@ -1710,18 +1718,14 @@ def test_default_use_parent(self): torch.backends.mkldnn.matmul.fp32_precision = "none" with torch.backends.mkldnn.flags(enabled=None, fp32_precision="bf16"): self.assertEqual(torch.backends.mkldnn.matmul.fp32_precision, "bf16") + with torch.backends.mkldnn.flags(enabled=None, fp32_precision="tf32"): + self.assertEqual(torch.backends.mkldnn.matmul.fp32_precision, "tf32") with torch.backends.mkldnn.flags(enabled=None, fp32_precision="none"): with torch.backends.flags(fp32_precision="bf16"): self.assertEqual(torch.backends.mkldnn.matmul.fp32_precision, "bf16") with torch.backends.flags(fp32_precision="tf32"): - # when parent is a not supported precision, use default - self.assertEqual(torch.backends.mkldnn.matmul.fp32_precision, "none") + self.assertEqual(torch.backends.mkldnn.matmul.fp32_precision, "tf32") - @recover_orig_fp32_precision - def test_invalid(self): - # use default if user set a not supported precision - torch.backends.mkldnn.matmul.fp32_precision = "tf32" - self.assertEqual(torch.backends.mkldnn.matmul.fp32_precision, "none") instantiate_device_type_tests(TestMkldnn, globals(), only_for=('cpu',)) diff --git a/test/test_nn.py b/test/test_nn.py index 0323080728b3c..218a65f388f04 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -55,7 +55,7 @@ from torch.testing._internal.common_utils import dtype2prec_DONTUSE from torch.testing._internal.common_cuda import tf32_on_and_off, tf32_off, tf32_on from torch.types import _TensorOrTensors -from torch.testing._internal.common_mkldnn import bf32_on_and_off +from torch.testing._internal.common_mkldnn import reduced_f32_on_and_off AMPERE_OR_ROCM = TEST_WITH_ROCM or torch.cuda.is_tf32_supported() @@ -8278,7 +8278,7 @@ def _test_module_empty_inputs(self, module, inputs): "Scipy v1.0 and/or numpy not found") @expectedFailureMPS # Unsupported Border padding mode https://github.com/pytorch/pytorch/issues/125098 @tf32_on_and_off() - @bf32_on_and_off() + @reduced_f32_on_and_off() def test_affine_2d_rotate0(self, device): # scipy before 1.0.0 do not support homogeneous coordinate # scipy.ndimage.affine_transform, so we need to skip. @@ -8319,7 +8319,7 @@ def test_affine_2d_rotate0(self, device): "Scipy v1.0 and/or numpy not found") @expectedFailureMPS # Unsupported Border padding mode https://github.com/pytorch/pytorch/issues/125098 @tf32_on_and_off(0.01 if TEST_WITH_ROCM else 0.001) - @bf32_on_and_off(0.001) + @reduced_f32_on_and_off(0.001) def test_affine_2d_rotate90(self, device): # scipy before 1.0.0 do not support homogeneous coordinate # scipy.ndimage.affine_transform, so we need to skip. @@ -8369,7 +8369,7 @@ def test_affine_2d_rotate90(self, device): "Scipy v1.0 and/or numpy not found") @expectedFailureMPS # Unsupported Border padding mode https://github.com/pytorch/pytorch/issues/125098 @tf32_on_and_off(0.005) - @bf32_on_and_off(0.005) + @reduced_f32_on_and_off(0.005) def test_affine_2d_rotate45(self, device): # scipy before 1.0.0 do not support homogeneous coordinate # scipy.ndimage.affine_transform, so we need to skip. @@ -8447,7 +8447,7 @@ def test_avg_pool_large_tensor2(self, device): "Scipy v1.0 and/or numpy not found") @expectedFailureMPS # Unsupported Border padding mode https://github.com/pytorch/pytorch/issues/125098 @tf32_on_and_off(0.05 if TEST_WITH_ROCM else 0.005) - @bf32_on_and_off(0.005) + @reduced_f32_on_and_off(0.005) def test_affine_2d_rotateRandom(self, device): # scipy before 1.0.0 do not support homogeneous coordinate # scipy.ndimage.affine_transform, so we need to skip. @@ -8500,7 +8500,7 @@ def test_affine_2d_rotateRandom(self, device): "Scipy v1.0 and/or numpy not found") @expectedFailureMPS # aten::grid_sampler_3d not implemented https://github.com/pytorch/pytorch/issues/77764 @tf32_on_and_off(0.05 if TEST_WITH_ROCM else 0.005) - @bf32_on_and_off(0.005) + @reduced_f32_on_and_off(0.005) def test_affine_3d_rotateRandom(self, device): # scipy before 1.0.0 do not support homogeneous coordinate # scipy.ndimage.affine_transform, so we need to skip. diff --git a/test/test_torch.py b/test/test_torch.py index ba830489a99ba..7af57f23b8fed 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -59,7 +59,7 @@ from torch.testing._internal.common_cuda import ( tf32_on_and_off, TEST_CUDNN, TEST_MULTIGPU, _create_scaling_case, _create_scaling_models_optimizers) -from torch.testing._internal.common_mkldnn import bf32_on_and_off +from torch.testing._internal.common_mkldnn import reduced_f32_on_and_off from torch.testing._internal.common_dtype import ( floating_types_and, get_all_math_dtypes, all_types_and_complex_and, complex_types, all_types_and, floating_types, floating_and_complex_types, integral_types_and, @@ -2557,7 +2557,7 @@ def test_cdist_cuda_backward(self, device): self.assertEqual(y1.grad, y2.grad, rtol=0, atol=0.001) @tf32_on_and_off(0.05 if TEST_WITH_ROCM else 0.005) - @bf32_on_and_off(0.08) + @reduced_f32_on_and_off(0.08) def test_cdist_large(self, device): for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']: x = torch.randn(1000, 10, device=device) @@ -2568,7 +2568,7 @@ def test_cdist_large(self, device): @slowTest @tf32_on_and_off(0.01) - @bf32_on_and_off(0.08) + @reduced_f32_on_and_off(0.08) def test_cdist_large_batch(self, device): for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']: x = torch.randn(4, 3, 1000, 10, device=device) @@ -2578,7 +2578,7 @@ def test_cdist_large_batch(self, device): self.assertEqual(expected, actual) @tf32_on_and_off(0.005) - @bf32_on_and_off(0.04) + @reduced_f32_on_and_off(0.04) def test_cdist_non_contiguous(self, device): for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']: x = torch.randn(5, 7, device=device).mT @@ -2606,7 +2606,7 @@ def test_cdist_non_contiguous(self, device): self.assertEqual(expected, actual) @tf32_on_and_off(0.005) - @bf32_on_and_off(0.04) + @reduced_f32_on_and_off(0.04) def test_cdist_non_contiguous_batch(self, device): for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']: x = torch.randn(4, 3, 2, 5, 7, device=device).mT diff --git a/torch/_inductor/fx_passes/mkldnn_fusion.py b/torch/_inductor/fx_passes/mkldnn_fusion.py index a269b17e3a2a9..e5a0c0dc51c5d 100644 --- a/torch/_inductor/fx_passes/mkldnn_fusion.py +++ b/torch/_inductor/fx_passes/mkldnn_fusion.py @@ -1228,11 +1228,14 @@ def is_const_or_cat_by_const(weight): torch.bfloat16, torch.float16, ) - bf32_matmul_enabled = torch.backends.mkldnn.matmul.fp32_precision == "bf16" # type: ignore[attr-defined] - use_bf16_for_fp32_weight = ( - bf32_matmul_enabled and weight_meta_value.dtype == torch.float32 + reduced_f32_matmul_enabled = torch.backends.mkldnn.matmul.fp32_precision in [ # type: ignore[attr-defined] + "bf16", + "tf32", + ] + use_reduced_f32_for_fp32_weight = ( + reduced_f32_matmul_enabled and weight_meta_value.dtype == torch.float32 ) - compute_with_lp = is_lp_weight or use_bf16_for_fp32_weight + compute_with_lp = is_lp_weight or use_reduced_f32_for_fp32_weight # on x86, for fp32, mkl should be enabled and batch_size should not be a free symbol. # on aarch64, use mkldnn op for fp32 as well if acl is enabled if ( @@ -1449,13 +1452,13 @@ def linear(match, *args, **kwargs): torch.bfloat16, torch.float16, ) - bf32_matmul_enabled = ( - torch.backends.mkldnn.matmul.fp32_precision == "bf16" # type: ignore[attr-defined] + reduced_f32_matmul_enabled = ( + torch.backends.mkldnn.matmul.fp32_precision in ["bf16", "tf32"] # type: ignore[attr-defined] ) - use_bf16_for_fp32_weight = ( - bf32_matmul_enabled and weight_dtype == torch.float32 + use_reduced_f32_for_fp32_weight = ( + reduced_f32_matmul_enabled and weight_dtype == torch.float32 ) - compute_with_lp = is_lp_weight or use_bf16_for_fp32_weight + compute_with_lp = is_lp_weight or use_reduced_f32_for_fp32_weight batch_size = input.meta.get("val").shape[0] if has_free_symbols(batch_size): assert compute_with_lp or mkldnn._is_mkldnn_acl_supported(), ( diff --git a/torch/testing/_internal/common_mkldnn.py b/torch/testing/_internal/common_mkldnn.py index ffaed6c7e009c..44da60a5ad1fe 100644 --- a/torch/testing/_internal/common_mkldnn.py +++ b/torch/testing/_internal/common_mkldnn.py @@ -7,9 +7,6 @@ import torch -# Test whether hardware BF32 math mode enabled. It is enabled only on: -# - MKLDNN is available -# - BF16 is supported by MKLDNN def bf32_is_not_fp32(): if not torch.backends.mkldnn.is_available(): return False @@ -18,8 +15,16 @@ def bf32_is_not_fp32(): return True +def tf32_is_not_fp32(): + if not torch.backends.mkldnn.is_available(): + return False + if not torch._C._cpu._is_amx_fp16_supported(): + return False + return True + + @contextlib.contextmanager -def bf32_off(): +def reduced_f32_off(): old_matmul_precision = torch.backends.mkldnn.matmul.fp32_precision old_conv_precision = torch.backends.mkldnn.conv.fp32_precision try: @@ -47,19 +52,39 @@ def bf32_on(self, bf32_precision=1e-2): self.precision = old_precision -# This is a wrapper that wraps a test to run this test twice, one with -# allow_bf32=True, another with allow_bf32=False. When running with -# allow_bf32=True, it will use reduced precision as specified by the -# argument -def bf32_on_and_off(bf32_precision=1e-2): - def with_bf32_disabled(self, function_call): - with bf32_off(): +@contextlib.contextmanager +def tf32_on(self, tf32_precision=1e-5): + old_matmul_precision = torch.backends.mkldnn.matmul.fp32_precision + old_conv_precision = torch.backends.mkldnn.conv.fp32_precision + old_precision = self.precision + try: + torch.backends.mkldnn.matmul.fp32_precision = "tf32" + torch.backends.mkldnn.conv.fp32_precision = "tf32" + self.precision = tf32_precision + yield + finally: + torch.backends.mkldnn.matmul.fp32_precision = old_matmul_precision + torch.backends.mkldnn.conv.fp32_precision = old_conv_precision + self.precision = old_precision + + +# This is a wrapper that wraps a test to run this test three times, one with +# reduced_f32 OFF, the others with reduced_f32 ON (including bf32 ON and tf32 +# ON). When running with reduced_f32 ON, it will use reduced precision (bf16/ +# tf32) as specified by the argument. +def reduced_f32_on_and_off(bf32_precision=1e-2, tf32_precision=1e-5): + def with_reduced_f32_disabled(self, function_call): + with reduced_f32_off(): function_call() def with_bf32_enabled(self, function_call): with bf32_on(self, bf32_precision): function_call() + def with_tf32_enabled(self, function_call): + with tf32_on(self, tf32_precision): + function_call() + def wrapper(f): params = inspect.signature(f).parameters arg_names = tuple(params.keys()) @@ -67,14 +92,19 @@ def wrapper(f): @functools.wraps(f) def wrapped(*args, **kwargs): kwargs.update(zip(arg_names, args)) - cond = bf32_is_not_fp32() + cond = True if "device" in kwargs: cond = cond and (torch.device(kwargs["device"]).type == "cpu") if "dtype" in kwargs: cond = cond and (kwargs["dtype"] == torch.float) - if cond: - with_bf32_disabled(kwargs["self"], lambda: f(**kwargs)) - with_bf32_enabled(kwargs["self"], lambda: f(**kwargs)) + bf32_cond = cond and bf32_is_not_fp32() + tf32_cond = cond and tf32_is_not_fp32() + if bf32_cond or tf32_cond: + with_reduced_f32_disabled(kwargs["self"], lambda: f(**kwargs)) + if bf32_cond: + with_bf32_enabled(kwargs["self"], lambda: f(**kwargs)) + if tf32_cond: + with_tf32_enabled(kwargs["self"], lambda: f(**kwargs)) else: f(**kwargs) From eeda1a75ace75ce8a6763050fb91d236a6d3287b Mon Sep 17 00:00:00 2001 From: Luca Wehrstedt Date: Thu, 17 Jul 2025 09:39:47 +0000 Subject: [PATCH 176/457] Forward-fix unused variables warning/error (#158549) Introduced in https://github.com/pytorch/pytorch/pull/158037, didn't seem to trigger on PR, but trunk CI is failing in some `linux-jammy-cpu-py3.12-gcc11-inductor-*` jobs where this warning is turned into an error. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158549 Approved by: https://github.com/danthe3rd --- aten/src/ATen/cuda/CUDABlas.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp index acb1d5ed8b0da..cf403365b2df2 100644 --- a/aten/src/ATen/cuda/CUDABlas.cpp +++ b/aten/src/ATen/cuda/CUDABlas.cpp @@ -1993,8 +1993,8 @@ void scaled_gemm( // The SCALE_MODE attrs only exist in cuBLAS 12.8+ or in recent hipblaslt, // but we must invoke get_scale_mode anyways to trigger the version checks. - int a_scale_mode = get_scale_mode(mat1_scaling_type, mat1_scale_dtype, use_fast_accum); - int b_scale_mode = get_scale_mode(mat2_scaling_type, mat2_scale_dtype, use_fast_accum); + [[maybe_unused]] int a_scale_mode = get_scale_mode(mat1_scaling_type, mat1_scale_dtype, use_fast_accum); + [[maybe_unused]] int b_scale_mode = get_scale_mode(mat2_scaling_type, mat2_scale_dtype, use_fast_accum); #if CUDA_VERSION >= 12080 || (defined(USE_ROCM) && defined(HIPBLASLT_OUTER_VEC)) computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_MODE, a_scale_mode); computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, b_scale_mode); From 3f8e2e91ad05a7a9da0590bf0238b1ae97b11455 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Thu, 17 Jul 2025 14:55:13 +0800 Subject: [PATCH 177/457] [BE][15/16] fix typos in torch/ (torch/distributed/tensor/) (#156605) Pull Request resolved: https://github.com/pytorch/pytorch/pull/156605 Approved by: https://github.com/wanchaol, https://github.com/albanD --- .lintrunner.toml | 1 - test/distributed/tensor/test_utils.py | 2 +- tools/linter/dictionary.txt | 5 ++--- torch/distributed/tensor/_api.py | 4 ++-- torch/distributed/tensor/_collective_utils.py | 2 +- torch/distributed/tensor/_dispatch.py | 2 +- torch/distributed/tensor/_dtensor_spec.py | 2 +- torch/distributed/tensor/_op_schema.py | 2 +- torch/distributed/tensor/_ops/_embedding_ops.py | 4 ++-- torch/distributed/tensor/_ops/_math_ops.py | 2 +- torch/distributed/tensor/_ops/_matrix_ops.py | 2 +- torch/distributed/tensor/_ops/_view_ops.py | 2 +- torch/distributed/tensor/_ops/utils.py | 2 +- torch/distributed/tensor/_shards_wrapper.py | 4 ++-- torch/distributed/tensor/_utils.py | 2 +- torch/distributed/tensor/debug/_comm_mode.py | 2 +- .../examples/comm_mode_features_example.py | 2 +- .../tensor/examples/convnext_example.py | 2 +- .../tensor/examples/torchrec_sharding_example.py | 4 ++-- .../tensor/experimental/_attention.py | 16 ++++++++-------- .../distributed/tensor/experimental/_func_map.py | 2 +- .../tensor/experimental/_register_sharding.py | 2 +- .../tensor/parallel/_data_parallel_utils.py | 2 +- torch/distributed/tensor/parallel/ddp.py | 2 +- torch/distributed/tensor/parallel/fsdp.py | 2 +- torch/distributed/tensor/placement_types.py | 2 +- 26 files changed, 37 insertions(+), 39 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index 5638a441b8e58..be03ced182f91 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1169,7 +1169,6 @@ exclude_patterns = [ 'aten/src/ATen/[a-mA-M]*/**', 'test/**', 'test/[a-hA-h]*/**', - 'torch/distributed/tensor/**', ] init_command = [ 'python3', diff --git a/test/distributed/tensor/test_utils.py b/test/distributed/tensor/test_utils.py index 8629ca5261cf8..dbfbac12223bb 100644 --- a/test/distributed/tensor/test_utils.py +++ b/test/distributed/tensor/test_utils.py @@ -179,7 +179,7 @@ def test_compute_global_tensor_shape_1D_invalid_shape(self): ) with self.assertRaisesRegex( RuntimeError, - "Non-sharded dimentions should have identical size across ranks.", + "Non-sharded dimensions should have identical size across ranks.", ): _ = compute_global_tensor_shape( local_shape, diff --git a/tools/linter/dictionary.txt b/tools/linter/dictionary.txt index 49ae353c7d02e..1817b04567ab7 100644 --- a/tools/linter/dictionary.txt +++ b/tools/linter/dictionary.txt @@ -4,9 +4,8 @@ BU contiguities contiguity coo -Din -Dout -dOut +din +dout ElementE followings fro diff --git a/torch/distributed/tensor/_api.py b/torch/distributed/tensor/_api.py index 19c8739e0581f..b0ee136c135f6 100644 --- a/torch/distributed/tensor/_api.py +++ b/torch/distributed/tensor/_api.py @@ -571,7 +571,7 @@ def full_tensor( """ Return the full tensor of this DTensor. It will perform necessary collectives to gather the local tensors from other ranks in its DeviceMesh and concatenate - them together. It's a syntatic sugar of the following code: + them together. It's a syntactic sugar of the following code: ``dtensor.redistribute(placements=[Replicate()] * mesh.ndim).to_local()`` @@ -1011,7 +1011,7 @@ def _dtensor_init_helper( # type: ignore[no-untyped-def] # set default placements to replicated if not specified placements = placements or tuple(Replicate() for _ in range(device_mesh.ndim)) - # check device_mesh againts placements + # check device_mesh against placements assert device_mesh.ndim == len(placements), ( "mesh dimension does not match the length of placements" ) diff --git a/torch/distributed/tensor/_collective_utils.py b/torch/distributed/tensor/_collective_utils.py index a1e38aec651bf..4fce6fea538a6 100644 --- a/torch/distributed/tensor/_collective_utils.py +++ b/torch/distributed/tensor/_collective_utils.py @@ -316,7 +316,7 @@ def redistribute_cost( NOTE: 1. Only consider communication cost here, since computation costs for redistribute - are quite trival (i.e. we only need to narrow or simple division) + are quite trivial (i.e. we only need to narrow or simple division) 2. Only consider redistribute cost on same mesh, cross mesh communication cost is not quite needed for operator strategy estimation/selection. """ diff --git a/torch/distributed/tensor/_dispatch.py b/torch/distributed/tensor/_dispatch.py index 83270b5a64bb7..1d0f57102aaec 100644 --- a/torch/distributed/tensor/_dispatch.py +++ b/torch/distributed/tensor/_dispatch.py @@ -434,7 +434,7 @@ def _try_replicate_spec_for_scalar_tensor( "Found a non-scalar tensor with numel=1 and ndim!=0, " "we are implicitly creating a replicated DTensor for it. " "However, please consider changing it to a scalar tensor " - "or explicitly create a DTensor under distributed enviroment." + "or explicitly create a DTensor under distributed environment." ) if tensor_arg.numel() == 1 or self._allow_implicit_replication: diff --git a/torch/distributed/tensor/_dtensor_spec.py b/torch/distributed/tensor/_dtensor_spec.py index 48739db536a9b..c450720357ba8 100644 --- a/torch/distributed/tensor/_dtensor_spec.py +++ b/torch/distributed/tensor/_dtensor_spec.py @@ -244,7 +244,7 @@ def from_dim_map( if placement.is_shard(): placement = cast(Shard, placement) raise RuntimeError( - f"DeviceMesh dimension cann't be mapped to two dimension of the same tensor: {i} and {placement.dim}" + f"DeviceMesh dimension can't be mapped to two dimension of the same tensor: {i} and {placement.dim}" ) elif placement.is_partial(): raise RuntimeError( diff --git a/torch/distributed/tensor/_op_schema.py b/torch/distributed/tensor/_op_schema.py index c359f28eb3efc..ccc006e63a83a 100644 --- a/torch/distributed/tensor/_op_schema.py +++ b/torch/distributed/tensor/_op_schema.py @@ -28,7 +28,7 @@ PlacementList = list[Optional[Placement]] -# ATen op schemas could have Tensor, Tuple[Tensor] and List[Tensor], so output type sould +# ATen op schemas could have Tensor, Tuple[Tensor] and List[Tensor], so output type should # be the same set of possibilities. OutputSpecType = Optional[Union[DTensorSpec, Sequence[Optional[DTensorSpec]]]] diff --git a/torch/distributed/tensor/_ops/_embedding_ops.py b/torch/distributed/tensor/_ops/_embedding_ops.py index 9d316aff4ed80..1b8e47895ce59 100644 --- a/torch/distributed/tensor/_ops/_embedding_ops.py +++ b/torch/distributed/tensor/_ops/_embedding_ops.py @@ -113,7 +113,7 @@ def _partition_value( def _reduce_value( self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int ) -> torch.Tensor: - # by the time we ned reduction, we should have already saved the mask + # by the time we need reduction, we should have already saved the mask assert self.mask_buffer.data is not None # apply the mask to the tensor that pending reduction @@ -134,7 +134,7 @@ def _reduce_shard_value( mesh_dim: int, shard_spec: Placement, ) -> torch.Tensor: - # by the time we ned reduction, we should have already saved the mask + # by the time we need reduction, we should have already saved the mask assert self.mask_buffer.data is not None # apply the mask to the tensor that pending reduction diff --git a/torch/distributed/tensor/_ops/_math_ops.py b/torch/distributed/tensor/_ops/_math_ops.py index c1bb96d9c319b..59ca7aed9bdf8 100644 --- a/torch/distributed/tensor/_ops/_math_ops.py +++ b/torch/distributed/tensor/_ops/_math_ops.py @@ -1085,7 +1085,7 @@ def topk_strategy(op_schema: OpSchema) -> OpStrategy: if dim != topk_dim: dim_shardings: PlacementList = [Shard(dim)] * 3 single_mesh_dim_strategies.append(dim_shardings) - # TODO: topk on sharded dim requries non-trival reduction, address it later + # TODO: topk on sharded dim requires non-trival reduction, address it later return expand_to_full_mesh_op_strategy( input_strategy.mesh, op_schema, single_mesh_dim_strategies, input_index=2 diff --git a/torch/distributed/tensor/_ops/_matrix_ops.py b/torch/distributed/tensor/_ops/_matrix_ops.py index 6b662aca4912a..fa4446a2d15eb 100644 --- a/torch/distributed/tensor/_ops/_matrix_ops.py +++ b/torch/distributed/tensor/_ops/_matrix_ops.py @@ -704,7 +704,7 @@ def scaled_dot_product_cudnn_attention_strategy(op_schema: OpSchema) -> OpStrate None, # max_k None, # philox_seed None, # philox_offset - # NOTE: debug_attn_mask is not supproted by pytorch and is always an empty tensor + # NOTE: debug_attn_mask is not supported by pytorch and is always an empty tensor # https://github.com/pytorch/pytorch/blob/60205b0eb2602317856312a66d955c88334ade0b/aten/src/ATen/native/transformers/cuda/attention.cu#L839-L840 debug_attn_mask_sharding, # debug_attn_mask Replicate(), # q diff --git a/torch/distributed/tensor/_ops/_view_ops.py b/torch/distributed/tensor/_ops/_view_ops.py index 8fe213f39846e..c942da67cd8a1 100644 --- a/torch/distributed/tensor/_ops/_view_ops.py +++ b/torch/distributed/tensor/_ops/_view_ops.py @@ -300,7 +300,7 @@ def view_groups(from_size: Shape, to_size: Shape) -> DimMap: Flatten((InputDim(1), InputDim(2))) ) - - ouptut dimension 0 maps to input dimension 0 + - output dimension 0 maps to input dimension 0 - output dimension 1 maps to a flattened input dimensions 1 and 2 diff --git a/torch/distributed/tensor/_ops/utils.py b/torch/distributed/tensor/_ops/utils.py index f6dd44cdfb08e..f120b6c39b022 100644 --- a/torch/distributed/tensor/_ops/utils.py +++ b/torch/distributed/tensor/_ops/utils.py @@ -216,7 +216,7 @@ def map_placements_after_broadcast( # the input shape shard dim before broadcasting, # in this case it means implicit broadcasting happen # in this dim, so we can just mark it as replicate - # and implict broadcast will broadcast automatically + # and implicit broadcast will broadcast automatically # to the sharded shape new_placements.append(Replicate()) diff --git a/torch/distributed/tensor/_shards_wrapper.py b/torch/distributed/tensor/_shards_wrapper.py index 30cc25ae89a66..a3798eac4ae0d 100644 --- a/torch/distributed/tensor/_shards_wrapper.py +++ b/torch/distributed/tensor/_shards_wrapper.py @@ -27,7 +27,7 @@ class LocalShardsWrapper(torch.Tensor): """ A wrapper class to hold local shards of a DTensor. - This class is used largely for checkpointing purposes and implicity subtypes + This class is used largely for checkpointing purposes and implicitly subtypes the _Checkpointable protocol. """ @@ -159,7 +159,7 @@ def handle_view(args, kwargs) -> "LocalShardsWrapper": ] elif args[0].local_shards()[0].ndim == 1: assert args[0].storage_metadata().size[0] == view_shape[0] - # This case is for optimizer sharding as regardles of sharding type, optimizer state is row wise sharded + # This case is for optimizer sharding as regardless of sharding type, optimizer state is row wise sharded res_shards_list = [ aten.view.default(shard, shard.shape, **kwargs) for shard in args[0].local_shards() diff --git a/torch/distributed/tensor/_utils.py b/torch/distributed/tensor/_utils.py index 92ea70eb16a85..6521eeac9b3ea 100644 --- a/torch/distributed/tensor/_utils.py +++ b/torch/distributed/tensor/_utils.py @@ -296,7 +296,7 @@ def compute_global_tensor_shape( for shape_tensor in gathered_shaped_tensors: if not torch.equal(local_shape[other_dims], shape_tensor[other_dims]): raise RuntimeError( - "Non-sharded dimentions should have identical size across ranks." + "Non-sharded dimensions should have identical size across ranks." ) shape_tensor_list = shape_tensor.tolist() sharded_dim_sum += shape_tensor_list[shard_dim] diff --git a/torch/distributed/tensor/debug/_comm_mode.py b/torch/distributed/tensor/debug/_comm_mode.py index 570161b676823..99978f9cc6b5e 100644 --- a/torch/distributed/tensor/debug/_comm_mode.py +++ b/torch/distributed/tensor/debug/_comm_mode.py @@ -395,7 +395,7 @@ def add_json_information(json_dict, fqn): json_dict: dict[str, Any] = {} add_json_information(json_dict, "Global") - # converts dictonary into json file + # converts dictionary into json file with open(file_name, "w") as json_file: json.dump(json_dict, json_file, indent=4) diff --git a/torch/distributed/tensor/examples/comm_mode_features_example.py b/torch/distributed/tensor/examples/comm_mode_features_example.py index da004aef4071f..3a8ca45b8aaff 100644 --- a/torch/distributed/tensor/examples/comm_mode_features_example.py +++ b/torch/distributed/tensor/examples/comm_mode_features_example.py @@ -711,7 +711,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def run_example(world_size: int, rank: int, example_name: str) -> None: # set manual seed - # intializing class with all of the functions + # initializing class with all of the functions instantiated_example = CommDebugModeExample(world_size, rank) # dict that stores example code function names name_to_example_code: dict[str, Callable[[], None]] = { diff --git a/torch/distributed/tensor/examples/convnext_example.py b/torch/distributed/tensor/examples/convnext_example.py index 9a3c2bbabd9ee..994f2ee10f69b 100644 --- a/torch/distributed/tensor/examples/convnext_example.py +++ b/torch/distributed/tensor/examples/convnext_example.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs """ The following example demonstrates how to train a ConvNeXt model -with intermediate activations sharded across mutliple GPUs via DTensor +with intermediate activations sharded across multiple GPUs via DTensor To run the example, use the following command: torchrun --standalone --nnodes=1 --nproc-per-node=4 convnext_example.py diff --git a/torch/distributed/tensor/examples/torchrec_sharding_example.py b/torch/distributed/tensor/examples/torchrec_sharding_example.py index f66ea658daf4b..2c5d104136102 100644 --- a/torch/distributed/tensor/examples/torchrec_sharding_example.py +++ b/torch/distributed/tensor/examples/torchrec_sharding_example.py @@ -231,7 +231,7 @@ def run_torchrec_row_wise_uneven_sharding_example(rank, world_size): # note: for uneven sharding, we need to specify the shape and stride because # DTensor would assume even sharding and compute shape/stride based on the - # assumption. Torchrec needs to pass in this information explicitely. + # assumption. Torchrec needs to pass in this information explicitly. # shape/stride are global tensor's shape and stride dtensor = DTensor.from_local( local_shards_wrapper, # a torch.Tensor subclass @@ -324,7 +324,7 @@ def run_torchrec_table_wise_sharding_example(rank, world_size): # create a DTensor from the local shard for the current table # note: for uneven sharding, we need to specify the shape and stride because # DTensor would assume even sharding and compute shape/stride based on the - # assumption. Torchrec needs to pass in this information explicitely. + # assumption. Torchrec needs to pass in this information explicitly. dtensor = DTensor.from_local( local_shards, device_submesh, diff --git a/torch/distributed/tensor/experimental/_attention.py b/torch/distributed/tensor/experimental/_attention.py index 73b53f051421d..457624bd6a674 100644 --- a/torch/distributed/tensor/experimental/_attention.py +++ b/torch/distributed/tensor/experimental/_attention.py @@ -239,7 +239,7 @@ def next_buffer(self) -> torch.Tensor: class _AllGatherRotater(_RingRotater): """ - Allgather the kv and return the only the requried kv. + Allgather the kv and return the only the required kv. Only one communication will be done. """ @@ -277,7 +277,7 @@ def _create_rotater( elif method == _RotateMethod.ALL_GATHER: return _AllGatherRotater(pg, seq_dim) else: - raise NotImplementedError(f"Unkonwn method {method}") + raise NotImplementedError(f"Unknown method {method}") def _templated_ring_attention( @@ -339,12 +339,12 @@ def _templated_ring_attention( First Iteration: Both ranks perform SDPA with their local qkv pairs, similar to the no-load-balance case. This iteration corresponds to the `if` of the - (`if, `elif`, `else`) in the implemementation. + (`if, `elif`, `else`) in the implementation. Second Iteration: Rank0 now has (q0, q3) and (k1, k2); rank1 has (q1, q2) and (k0, k3). For rank0, no computation is needed for q0. However, computations for q3k1 and q3k2 are required, so only q3 is used for SDPA. This corresponds to the - `else` of the (`if`, `elif`, `else`) in the implemementation. + `else` of the (`if`, `elif`, `else`) in the implementation. For rank1, k0 is not needed for q1 and q2, so only k3 is used for SDPA. This corresponds to the `elif` of (`if`, `elif`, `else`) in the implementation. @@ -916,7 +916,7 @@ def _distribute_function( the inputs and outputs of a function. Similar to ``distribute_module``, this API installs hooks to the ``fn`` to convert the inputs and outputs. There are two major differences between ``distribute_function`` and ``distribute_module``. - First, a function does not have parammeters and buffers, as a result, + First, a function does not have parameters and buffers, as a result, ``distribute_function`` itself won't convert any parameters/buffers but simply install the input and output hooks. The tensor conversion will happen in the hooks. Another difference is an nn.Module subclass can have several instances and each @@ -932,9 +932,9 @@ def _distribute_function( ``fn_module`` is ``torch.nn.functional``. device_mesh (:class:`DeviceMesh`): the device mesh that will be used by the input and output hooks to distribute the tensors. - input_fn (Optioinal[Callable]): the hook to distribute or convert the input + input_fn (Optional[Callable]): the hook to distribute or convert the input arguments of ``fn``. - output_fn (Optioinal[Callable]): the hook to distribute or convert the output + output_fn (Optional[Callable]): the hook to distribute or convert the output arguments of ``fn``. """ @@ -989,7 +989,7 @@ class _AttentionContextParallel(ParallelStyle): Applies context parallel optimizations to the attention layer. This will work for nn.MultiHeadedAttention and custom attention layers that - call F.scaled_dotproduct_attention with a simliar signature. + call F.scaled_dotproduct_attention with a similar signature. This expects the `forward` method consumes either: diff --git a/torch/distributed/tensor/experimental/_func_map.py b/torch/distributed/tensor/experimental/_func_map.py index 7eb2e72343e21..fd91328c0b379 100644 --- a/torch/distributed/tensor/experimental/_func_map.py +++ b/torch/distributed/tensor/experimental/_func_map.py @@ -112,7 +112,7 @@ def local_map( >>> row_wise = [Shard(0)] # row-wise sharding placements on 1-d mesh >>> col_wise = [Shard(1)] # col-wise sharding placements on 1-d mesh >>> - >>> # local_mm_allreduce_forward is the function wrapped with DTensor/Tensor convertion + >>> # local_mm_allreduce_forward is the function wrapped with DTensor/Tensor conversion >>> local_mm_allreduce_forward = local_map( >>> mm_allreduce_forward, >>> out_placements=[Replicate()], diff --git a/torch/distributed/tensor/experimental/_register_sharding.py b/torch/distributed/tensor/experimental/_register_sharding.py index f91fae4580bcc..b286b151efed5 100644 --- a/torch/distributed/tensor/experimental/_register_sharding.py +++ b/torch/distributed/tensor/experimental/_register_sharding.py @@ -41,7 +41,7 @@ def register_sharding(op: Union[OpOverload, list[OpOverload]]): as the original op (except that if an arg is a :class:`torch.Tensor`, it will be replaced by a tensor-like object that DTensor uses internally). The function should return a sequence of 2-tuples, each specifying acceptable output placements and its - corresponding intput placements. + corresponding input placements. Example: >>> # xdoctest: +SKIP("distributed") diff --git a/torch/distributed/tensor/parallel/_data_parallel_utils.py b/torch/distributed/tensor/parallel/_data_parallel_utils.py index 6513123e24628..c41da260a02f9 100644 --- a/torch/distributed/tensor/parallel/_data_parallel_utils.py +++ b/torch/distributed/tensor/parallel/_data_parallel_utils.py @@ -30,7 +30,7 @@ def _flatten_tensor( @no_type_check def _unflatten_tensor(tensor, spec, *, device_handle=None, compute_stream=None): - # unflatten would mainly be called everytime FSDP allgather parameters. + # unflatten would mainly be called every time FSDP allgather parameters. result = DTensor.from_local( tensor, spec.mesh, diff --git a/torch/distributed/tensor/parallel/ddp.py b/torch/distributed/tensor/parallel/ddp.py index 39ab299b4f79f..7b19f97675197 100644 --- a/torch/distributed/tensor/parallel/ddp.py +++ b/torch/distributed/tensor/parallel/ddp.py @@ -36,7 +36,7 @@ def _update_module_param(param_list: list[tuple[nn.Module, str, nn.Parameter]]): def _reconstruct_dtensor(module: nn.Module, _input: Any): """ - Recontruct DTensor parameters from local tensors + Reconstruct DTensor parameters from local tensors """ param_list = [] # TODO: To add perf optimizations to this iterations diff --git a/torch/distributed/tensor/parallel/fsdp.py b/torch/distributed/tensor/parallel/fsdp.py index 5282542950c4d..1b0b8cac7c760 100644 --- a/torch/distributed/tensor/parallel/fsdp.py +++ b/torch/distributed/tensor/parallel/fsdp.py @@ -326,7 +326,7 @@ def __init__(self, device_handle) -> None: super().__init__() self.compute_stream = None self.device_handle = device_handle - # we have to use the dynamo disable this way to disable dynamo as the decorater way would + # we have to use the dynamo disable this way to disable dynamo as the decorator way would # trigger build failure with torch deploy... self.post_unflatten_transform = torch._dynamo.disable( # type: ignore[method-assign] self.post_unflatten_transform diff --git a/torch/distributed/tensor/placement_types.py b/torch/distributed/tensor/placement_types.py index a8fdd7bec1ac9..b37d49bd30744 100644 --- a/torch/distributed/tensor/placement_types.py +++ b/torch/distributed/tensor/placement_types.py @@ -701,7 +701,7 @@ def _partition_value( # _partition_value: partition the value of a replicated tensor on the mesh dimension # _partition_value is the conjugate operation of _reduce_value - # - i.e. _partition_value on a sum reduce op is just a divison operation + # - i.e. _partition_value on a sum reduce op is just a division operation # - the _reduce_value on a sum reduce op would just be a sum(allreduce) operation # TODO: if the reduce_op is min/max, etc. the _partition_value should be a # different operation From c8d43cbc6e2178c8971be46ea020107136b35355 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Thu, 17 Jul 2025 14:55:14 +0800 Subject: [PATCH 178/457] [BE][3/6] fix typos in test/ (#157637) Pull Request resolved: https://github.com/pytorch/pytorch/pull/157637 Approved by: https://github.com/yewentao256, https://github.com/albanD ghstack dependencies: #156605 --- .lintrunner.toml | 1 - test/ao/sparsity/test_activation_sparsifier.py | 2 +- test/ao/sparsity/test_data_sparsifier.py | 4 ++-- test/ao/sparsity/test_scheduler.py | 2 +- test/benchmark_utils/test_benchmark_utils.py | 6 ++++-- test/cpp/api/transformer.cpp | 2 +- test/cpp/jit/CMakeLists.txt | 2 +- test/cpp/jit/test_backend.cpp | 2 +- test/cpp/jit/test_backend_compiler_lib.cpp | 2 +- test/cpp/jit/test_lite_trainer.cpp | 2 +- .../test_lite_interpreter_runtime.cpp | 2 +- test/cpp/tensorexpr/test_cuda.cpp | 2 +- test/cpp/tensorexpr/test_kernel.cpp | 10 +++++----- test/cpp/tensorexpr/test_memdependency.cpp | 8 ++++---- test/cpp/tensorexpr/test_reductions.cpp | 2 +- test/cpp/tensorexpr/test_registerizer.cpp | 8 ++++---- test/cpp/tensorexpr/test_simplify.cpp | 16 ++++++++-------- test/export/test_converter.py | 6 +++--- test/export/test_export.py | 8 ++++---- test/export/test_torchbind.py | 2 +- test/functorch/attn_ft.py | 2 +- test/functorch/test_aotdispatch.py | 6 +++--- test/functorch/test_control_flow.py | 10 +++++----- test/functorch/test_vmap.py | 4 ++-- test/fx/test_lazy_graph_module.py | 6 +++--- test/fx/test_partitioner_order.py | 6 +++--- test/fx/test_pass_infra.py | 2 +- test/higher_order_ops/test_invoke_quant.py | 2 +- test/higher_order_ops/test_invoke_subgraph.py | 4 ++-- tools/linter/dictionary.txt | 9 +++++++++ torch/csrc/jit/tensorexpr/loopnest.cpp | 4 ++-- 31 files changed, 77 insertions(+), 67 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index be03ced182f91..513d7dd2d00fc 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1168,7 +1168,6 @@ exclude_patterns = [ 'aten/src/ATen/native/[a-pA-P]*/**', 'aten/src/ATen/[a-mA-M]*/**', 'test/**', - 'test/[a-hA-h]*/**', ] init_command = [ 'python3', diff --git a/test/ao/sparsity/test_activation_sparsifier.py b/test/ao/sparsity/test_activation_sparsifier.py index 9c2a10d3355a1..8e1525b858795 100644 --- a/test/ao/sparsity/test_activation_sparsifier.py +++ b/test/ao/sparsity/test_activation_sparsifier.py @@ -50,7 +50,7 @@ def _check_constructor(self, activation_sparsifier, model, defaults, sparse_conf sparsifier_defaults = activation_sparsifier.defaults combined_defaults = {**defaults, "sparse_config": sparse_config} - # more keys are populated in activation sparsifier (eventhough they may be None) + # more keys are populated in activation sparsifier (even though they may be None) assert len(combined_defaults) <= len(activation_sparsifier.defaults) for key, config in sparsifier_defaults.items(): diff --git a/test/ao/sparsity/test_data_sparsifier.py b/test/ao/sparsity/test_data_sparsifier.py index 9994382157435..5217049aafdfd 100644 --- a/test/ao/sparsity/test_data_sparsifier.py +++ b/test/ao/sparsity/test_data_sparsifier.py @@ -265,7 +265,7 @@ def check_memory_reference(self, data_list, data_with_config, defaults, **kwargs class _NormDataSparsifierTestCase(_BaseDataSparsiferTestCase): r"""This helper test class takes in any supported type of and runs some tests. This inherits the TestBaseDataSparsifierRuner wherein some functions are - over-ridden to take accomodate the specific sparsifier. + over-ridden to take accommodate the specific sparsifier. TODO: Change the structure by creating a separate test case class for each member function """ @@ -770,7 +770,7 @@ def test_ptq_quantize_first(self): # higher threshold as quantization occurs before sparsity threshold = ( - 1 # zero points seem to have higher magnitude with sparsity occuring after + 1 # zero points seem to have higher magnitude with sparsity occurring after ) sl_emb1 = (torch.abs(dequant_emb1) < threshold).float().mean() diff --git a/test/ao/sparsity/test_scheduler.py b/test/ao/sparsity/test_scheduler.py index 38e8fca4cdd84..b563efac73bd7 100644 --- a/test/ao/sparsity/test_scheduler.py +++ b/test/ao/sparsity/test_scheduler.py @@ -188,7 +188,7 @@ def test_step(self): self.assertEqual( self._get_sparsity_levels(sparsifier), self.sorted_sparse_levels, - msg="Sparsity level is not reaching the target level afer delta_t * n steps ", + msg="Sparsity level is not reaching the target level after delta_t * n steps ", ) diff --git a/test/benchmark_utils/test_benchmark_utils.py b/test/benchmark_utils/test_benchmark_utils.py index 1d8d8e7e35948..f9120c26a132f 100644 --- a/test/benchmark_utils/test_benchmark_utils.py +++ b/test/benchmark_utils/test_benchmark_utils.py @@ -699,14 +699,16 @@ def custom_transforms(fn: str): 8959166 /tmp/build/80754af9/python_15996 ... a3/envs/throwaway/bin/python3.6] ... 92821 /tmp/build/80754af9/python_15996 ... a3/envs/throwaway/bin/python3.6] - 91000 build/../torch/csrc/tensor/pytho ... ch/torch/lib/libtorch_python.so] + 91000 build/../torch/csrc/tensor/pytho ... ch/torch/lib/libtorch_python.so] # codespell:ignore 91000 /data/users/test_user/repos/pyto ... nsors::get_default_scalar_type() 90090 ???:pthread_mutex_lock [/usr/lib64/libpthread-2.28.so] 90000 build/../c10/core/TensorImpl.h:c ... ch/torch/lib/libtorch_python.so] 90000 build/../aten/src/ATen/record_fu ... torch/torch/lib/libtorch_cpu.so] 90000 /data/users/test_user/repos/pyto ... uard(std::optional) 90000 /data/users/test_user/repos/pyto ... ersionCounter::~VersionCounter() - 88000 /data/users/test_user/repos/pyto ... ratorKernel*, at::Tensor const&)""", + 88000 /data/users/test_user/repos/pyto ... ratorKernel*, at::Tensor const&)""".replace( + " # codespell:ignore", "" + ), ) self.regularizeAndAssertExpectedInline( diff --git a/test/cpp/api/transformer.cpp b/test/cpp/api/transformer.cpp index 6062c77f5917d..fc4832d30157a 100644 --- a/test/cpp/api/transformer.cpp +++ b/test/cpp/api/transformer.cpp @@ -73,7 +73,7 @@ void transformer_encoder_layer_test_helper( ASSERT_TRUE( torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true)); - // all 0 values are NOT masked. This should't mask anything + // all 0 values are NOT masked. This shouldn't mask anything torch::Tensor mask = torch::tensor({{0}}, tensor_options) == 1; result = model( encoder_input, diff --git a/test/cpp/jit/CMakeLists.txt b/test/cpp/jit/CMakeLists.txt index d192d8a6c5d35..f58d81ed008ab 100644 --- a/test/cpp/jit/CMakeLists.txt +++ b/test/cpp/jit/CMakeLists.txt @@ -17,7 +17,7 @@ set(BACKEND_WITH_COMPILER_SRCS ) if(USE_KINETO) # Testing edge profiler for backend use - # profiler_edge should only be aded when USE_KINETO flag is on + # profiler_edge should only be added when USE_KINETO flag is on list(APPEND BACKEND_WITH_COMPILER_SRCS ${TORCH_SRC_DIR}/csrc/jit/mobile/profiler_edge.cpp) endif() diff --git a/test/cpp/jit/test_backend.cpp b/test/cpp/jit/test_backend.cpp index dd4df40d9c138..4a060e436f2b0 100644 --- a/test/cpp/jit/test_backend.cpp +++ b/test/cpp/jit/test_backend.cpp @@ -789,7 +789,7 @@ TEST( c._save_for_mobile(ss, ExtraFilesMap(), true); auto c_loaded = _load_for_mobile(ss); /* - * Erro stack trace will look like this: + * Error stack trace will look like this: * Module hierarchy:top(C).A0(backend_with_compiler_demoLoweredModule).AA0(AA) * Traceback of TorchScript (most recent call last): * File "", line 3, in FunctionName_UNKNOWN diff --git a/test/cpp/jit/test_backend_compiler_lib.cpp b/test/cpp/jit/test_backend_compiler_lib.cpp index 33262efd1e2b1..55511c3e684a6 100644 --- a/test/cpp/jit/test_backend_compiler_lib.cpp +++ b/test/cpp/jit/test_backend_compiler_lib.cpp @@ -79,7 +79,7 @@ class BackendWithCompiler : public PyTorchBackendInterface { // forwards everything along. In a non toy setup this could grab information // from that runtime that might be relevant to execute, such as build flags // the resolution of the devices camera, or basically any runtime specific - // information that wouldnt be available server side where preprocess is + // information that wouldn't be available server side where preprocess is // called. c10::impl::GenericDict compile( c10::IValue processed, diff --git a/test/cpp/jit/test_lite_trainer.cpp b/test/cpp/jit/test_lite_trainer.cpp index a09374065306b..950d0c524ad3a 100644 --- a/test/cpp/jit/test_lite_trainer.cpp +++ b/test/cpp/jit/test_lite_trainer.cpp @@ -78,7 +78,7 @@ TEST(LiteTrainerTest, Params) { AT_ASSERT(parameters[0].item() == bc_parameters[0].item()); } -// TODO Renable these tests after parameters are correctly loaded on mobile +// TODO Re-enable these tests after parameters are correctly loaded on mobile /* TEST(MobileTest, NamedParameters) { Module m("m"); diff --git a/test/cpp/lite_interpreter_runtime/test_lite_interpreter_runtime.cpp b/test/cpp/lite_interpreter_runtime/test_lite_interpreter_runtime.cpp index 088a4eb04c996..b6467b7c5b490 100644 --- a/test/cpp/lite_interpreter_runtime/test_lite_interpreter_runtime.cpp +++ b/test/cpp/lite_interpreter_runtime/test_lite_interpreter_runtime.cpp @@ -106,7 +106,7 @@ TEST(RunTimeTest, DelegateException) { * inputs.emplace_back(torch::rand({2, 4})); * inputs.emplace_back(torch::rand({13, 9})); * Run with inputs and expect exception - * Erro stack trace will look like this: + * Error stack trace will look like this: * Module hierarchy:top(C).A0(backend_with_compiler_demoLoweredModule).AA0(AA) * Traceback of TorchScript (most recent call last): * File "", line 3, in FunctionName_UNKNOWN diff --git a/test/cpp/tensorexpr/test_cuda.cpp b/test/cpp/tensorexpr/test_cuda.cpp index 8a96c68dc75e4..2e1e84e758db3 100644 --- a/test/cpp/tensorexpr/test_cuda.cpp +++ b/test/cpp/tensorexpr/test_cuda.cpp @@ -1681,7 +1681,7 @@ TEST(Cuda, MaskMultiDim_CUDA) { // Tests the case where loop extents are symbolic and not known at compile time. // In this case both stores must be masked against the extent of the other loop, -// incase it is larger. +// in case it is larger. TEST(Cuda, MaskMultiDimSymbolic_CUDA) { VarHandle OUTER_SIZE("OUTER_SIZE", kLong); VarHandle A_SIZE("A_SIZE", kLong); diff --git a/test/cpp/tensorexpr/test_kernel.cpp b/test/cpp/tensorexpr/test_kernel.cpp index 22f6b64efe1a8..f9cd82ff95d04 100644 --- a/test/cpp/tensorexpr/test_kernel.cpp +++ b/test/cpp/tensorexpr/test_kernel.cpp @@ -1113,7 +1113,7 @@ TEST_F(Kernel, Softmax2D) { const auto verification_pattern = format(verification_template, ver_env); - // verification sting temporarily disabled until + // verification string temporarily disabled until // inlining of exp() is benchmarked and determined // torch::jit::testing::FileCheck().run(verification_pattern, // oss.str()); @@ -1192,7 +1192,7 @@ TEST_F(Kernel, Softmax3D) { ver_env.d("softmax_dim_size", softmax_dim_size); const auto verification_pattern = format(verification_template, ver_env); - // verification sting temporarily disabled until + // verification string temporarily disabled until // inlining of exp() is benchmarked and determined // torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); @@ -1275,7 +1275,7 @@ TEST_F(Kernel, Softmax4D) { ver_env.d("softmax_dim_size", softmax_dim_size); const auto verification_pattern = format(verification_template, ver_env); - // verification sting temporarily disabled until + // verification string temporarily disabled until // inlining of exp() is benchmarked and determined // torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); @@ -1887,7 +1887,7 @@ graph(%x : int, auto rt = at::zeros({3}, TensorOptions(kCPU).dtype(at::kLong)); // Verify that TEK::runFast works correctly with mixed scalar and tensor - // inputs/utputs + // inputs/outputs std::vector inputs = {&x, xt.data_ptr(), &y, yt.data_ptr()}; std::vector outputs = {&r, rt.data_ptr(), &z, zt.data_ptr()}; k.runFast(inputs, outputs); @@ -1897,7 +1897,7 @@ graph(%x : int, ASSERT_TRUE(at::equal(rt, zt * xt)); // Verify that TEK::run works correctly with mixed scalar and tensor - // inputs/utputs + // inputs/outputs std::vector stack = {x, xt, y, yt}; k.run(stack); TORCH_CHECK_EQ(stack[0], x * y * x); diff --git a/test/cpp/tensorexpr/test_memdependency.cpp b/test/cpp/tensorexpr/test_memdependency.cpp index cac7283f2bebe..5db84eab1f509 100644 --- a/test/cpp/tensorexpr/test_memdependency.cpp +++ b/test/cpp/tensorexpr/test_memdependency.cpp @@ -437,12 +437,12 @@ TEST(MemDependency, BoundSubtractMultiDim) { ASSERT_TRUE(EQ( subtractIndicesBounds({CB(0, 9), CB(0, 2)}, {CB(0, 10), CB(0, 20)}), {})); - // Mutli dim one way partial in dim 1. + // Multi dim one way partial in dim 1. ASSERT_TRUE( EQ(subtractIndicesBounds({CB(0, 9), CB(0, 2)}, {CB(0, 3), CB(0, 2)}), {{CB(4, 9), CB(0, 2)}})); - // Mutli dim one way partial in dim 2. + // Multi dim one way partial in dim 2. ASSERT_TRUE( EQ(subtractIndicesBounds({CB(0, 9), CB(0, 20)}, {CB(0, 9), CB(0, 10)}), {{CB(0, 9), CB(11, 20)}})); @@ -939,7 +939,7 @@ TEST(MemDependency, MemDependencyCheckerLoopBounds) { */ // Now let's look at the bounds of each access. - // There are 9 accesses in this Stmt, so this is exhaustive, we wont do this + // There are 9 accesses in this Stmt, so this is exhaustive, we won't do this // much. auto history = analyzer.getHistory(); ASSERT_EQ(history.size(), 10); @@ -1134,7 +1134,7 @@ TEST(MemDependency, MemDependencyCheckerLoopBoundsIndexShift) { // this case -1. ASSERT_TRUE(EQ(history[1]->bounds(), {CB(0, 8)})); // It depends on the input, but also the store in the same loop, since - // different interations of the loop depend on each other. + // different iterations of the loop depend on each other. ASSERT_EQ(history[1]->dependencies().size(), 2); ASSERT_TRUE(history[1]->hasDependency(history[0])); ASSERT_TRUE(history[1]->hasDependency(history[2])); diff --git a/test/cpp/tensorexpr/test_reductions.cpp b/test/cpp/tensorexpr/test_reductions.cpp index bdc744ae4e033..fb83ab85b71ed 100644 --- a/test/cpp/tensorexpr/test_reductions.cpp +++ b/test/cpp/tensorexpr/test_reductions.cpp @@ -333,7 +333,7 @@ TEST(Reductions, ReduceMinCustomInitializer) { cg.call({in, out, std::numeric_limits::max()}); ASSERT_EQ(out[0], 10); - // With an initalizer lower than the min, that's the min. + // With an initializer lower than the min, that's the min. cg.call({in, out, 5.f}); ASSERT_EQ(out[0], 5); } diff --git a/test/cpp/tensorexpr/test_registerizer.cpp b/test/cpp/tensorexpr/test_registerizer.cpp index d6f5977789a9e..6cbd04264c321 100644 --- a/test/cpp/tensorexpr/test_registerizer.cpp +++ b/test/cpp/tensorexpr/test_registerizer.cpp @@ -1254,7 +1254,7 @@ TEST(Registerizer, RegisterizerConditionInsideOverlap2) { * A[0] = 3; * A[x] = (A[x]) + 1; * } - * int A_2 = A[x]; // A_2 initialier + * int A_2 = A[x]; // A_2 initializer * B[x] = A_2; // * B[x + 1] = A_2; // * A_2 = C[x]; // @@ -3064,7 +3064,7 @@ TEST(Registerizer, RegisterizerHiddenAccessNo) { } // In this case the conditional access must be hoisted by two loops, there are -// two accesses here one is unhidden and the other isnt. A[0] can be +// two accesses here one is unhidden and the other isn't. A[0] can be // registerized but B[0] cannot. TEST(Registerizer, RegisterizerHiddenAccessMultiLoop) { BufHandle a("A", {10}, kInt); @@ -3422,8 +3422,8 @@ TEST(Registerizer, RegisterizerMultiDim) { torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); } -// Wont registerize if only some dims match, but will still registerize distinct -// elements. +// Won't registerize if only some dims match, but will still registerize +// distinct elements. TEST(Registerizer, RegisterizerMultiDimPartial) { BufHandle a("A", {3, 4, 5}, kInt); VarHandle x("x", kInt); diff --git a/test/cpp/tensorexpr/test_simplify.cpp b/test/cpp/tensorexpr/test_simplify.cpp index 99a00d0d62c11..7ca2b74eaa766 100644 --- a/test/cpp/tensorexpr/test_simplify.cpp +++ b/test/cpp/tensorexpr/test_simplify.cpp @@ -2643,7 +2643,7 @@ TEST(Simplify, SimplifyWontReorderFloat) { VarHandle x("x", kFloat); VarHandle y("y", kFloat); // x%y - (x%y - 1) => x%y - (x%y - 1). - // We wont reorder opaque ops if they are FP. + // We won't reorder opaque ops if they are FP. ExprHandle body = (x % y) - ((x % y) - 1); ExprHandle simplified = IRSimplifier::simplify(body); @@ -2794,7 +2794,7 @@ TEST(Simplify, SimplifyRoundModPattern) { } { - // Sanity checking we wont do the optimization on floats. + // Sanity checking we won't do the optimization on floats. VarHandle x("x", kFloat); VarHandle y("y", kFloat); ExprHandle body = ((x / y) * y) + (x % y); @@ -2811,7 +2811,7 @@ TEST(Simplify, SimplifyRoundModPattern) { } { - // Sanity check we wont do it if the mod term doesn't match. + // Sanity check we won't do it if the mod term doesn't match. VarHandle x("x", kInt); VarHandle y("y", kInt); VarHandle z("z", kInt); @@ -2821,7 +2821,7 @@ TEST(Simplify, SimplifyRoundModPattern) { } { - // Sanity check we wont do it if the div term doesn't match. + // Sanity check we won't do it if the div term doesn't match. VarHandle x("x", kInt); VarHandle y("y", kInt); VarHandle z("z", kInt); @@ -2831,7 +2831,7 @@ TEST(Simplify, SimplifyRoundModPattern) { } { - // Sanity check we wont do it if the mul term doesn't match. + // Sanity check we won't do it if the mul term doesn't match. VarHandle x("x", kInt); VarHandle y("y", kInt); VarHandle z("z", kInt); @@ -3013,7 +3013,7 @@ TEST(Simplify, SimplifyModRoundModPattern) { } { - // Sanity checking we wont do the optimization on floats. + // Sanity checking we won't do the optimization on floats. VarHandle x("x", kFloat); VarHandle y("y", kFloat); VarHandle z("z", kFloat); @@ -4264,7 +4264,7 @@ TEST(Simplify, SimplifyReorderForCond) { { // Condition uses distinct region of Tensor. - // We could reorder here wih better analysis, but we don't. Included for + // We could reorder here with better analysis, but we don't. Included for // completeness. auto body = For::make( i, @@ -4643,7 +4643,7 @@ TEST(Simplify, SimplifyFuseConditions) { } { - // Sanity check wont fuse different non-CompareSelects. + // Sanity check won't fuse different non-CompareSelects. auto body = Block::make( {Cond::make(i, Store::make(a, {0}, i), nullptr), Cond::make(j, Store::make(a, {1}, i), nullptr)}); diff --git a/test/export/test_converter.py b/test/export/test_converter.py index 953246be7a7bc..d5611ad2d5795 100644 --- a/test/export/test_converter.py +++ b/test/export/test_converter.py @@ -700,7 +700,7 @@ def forward(self, x: torch.Tensor): else: return self.w + self.m2(x) - # Super nested, parameters neeed to lifted + # Super nested, parameters need to be lifted # multiple times. class SuperNestedM(torch.nn.Module): def __init__(self) -> None: @@ -755,7 +755,7 @@ def forward(self, x: torch.Tensor): else: return self.linear(self.m2(x)) - # Super nested, parameters neeed to lifted + # Super nested, parameters need to be lifted # multiple times. class SuperNestedM1(torch.nn.Module): def __init__(self, dim: int) -> None: @@ -771,7 +771,7 @@ def forward(self, x: torch.Tensor): return self.linear(self.m2(x)) # Super nested, even the input needs to be - # lifted recursively due to value propogation optimiztaion. + # lifted recursively due to value propagation optimization. class SuperNestedM2(torch.nn.Module): def __init__(self, dim: int) -> None: super().__init__() diff --git a/test/export/test_export.py b/test/export/test_export.py index 1c0279d565268..d1cecb55329c4 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -1871,7 +1871,7 @@ def annotate_split_points(mod: torch.nn.Module, spec): for problem in [Problem1, Problem2]: m = problem() m(torch.rand(64, 64)) - # simpified torch.distributed.pipeline code + # simplified torch.distributed.pipeline code annotate_split_points(m, {"blocks.1": 1, "blocks.3": 1}) gm = export(m, (torch.rand(64, 64),)) torch.export.unflatten(gm) @@ -8096,7 +8096,7 @@ def false_fn(x): str(schema), """cond(SymBool pred, GraphModule true_fn, GraphModule false_fn, Tensor[2] operands) -> Tensor[1]""", ) - # serdes deserailizes tuple as list + # serdes deserializes tuple as list if need_serdes_test(self._testMethodName): self.assertExpectedInline( ep.graph_module.code.strip(), @@ -9232,7 +9232,7 @@ def forward(self, x): x = torch.rand(5, 2, 2) model = Model() - # Manualy set the fake_device of fake tensors. + # Manually set the fake_device of fake tensors. x.fake_device = torch.device("cuda:0") for n, p in model.named_parameters(): p.fake_device = torch.device("cuda:0") @@ -13562,7 +13562,7 @@ def forward(self, x): self.assertTrue(torch.allclose(m(x2), ep.module()(x2))) self.assertTrue(torch.allclose(m(x1), ep.module()(x1))) - @testing.expectedFailureSerDerNonStrict # construtor is not serialized today + @testing.expectedFailureSerDerNonStrict # constructor is not serialized today @testing.expectedFailureSerDer # constructor is not serialized today @testing.expectedFailureRetraceability # dynamo doesn't work with FlatApply op def test_capture_subclass_constructor(self): diff --git a/test/export/test_torchbind.py b/test/export/test_torchbind.py index 3f8f11aca0e52..214f3ce2fdfa4 100644 --- a/test/export/test_torchbind.py +++ b/test/export/test_torchbind.py @@ -1274,7 +1274,7 @@ def forward(self, tq, x): self.assertEqual(cnt.frame_count, 1) tq2 = _empty_tensor_queue() - # make first tensor's secon dim dynamic + # make first tensor's second dim dynamic tq2.push(torch.randn(2, 4, requires_grad=False)) torch.compile(mod, backend=cnt)(tq2, x) self.assertEqual(cnt.frame_count, 2) diff --git a/test/functorch/attn_ft.py b/test/functorch/attn_ft.py index ee4656631964a..7038ded094904 100644 --- a/test/functorch/attn_ft.py +++ b/test/functorch/attn_ft.py @@ -126,7 +126,7 @@ def forward( if self.position_embedding_type == "relative_key": # these were einsum ops in the positional code because they are not easy to fit to existing matmul operators - # eventhough they are degenerate matmuls + # even though they are degenerate matmuls relative_position_scores = (q * positional_embedding).sum(features) attention_scores = attention_scores + relative_position_scores elif self.position_embedding_type == "relative_key_query": diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index f1d1c92d52f47..698bab89935fb 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -2372,7 +2372,7 @@ def f(a, b): return a.mul(3), b.mul(4) inp = [ - # First inp doesnt require grad, but we switch it on + # First inp doesn't require grad, but we switch it on torch.ones(3, 3, requires_grad=False), torch.ones(3, 3, requires_grad=True), ] @@ -5670,7 +5670,7 @@ def f(a, b, c, d): _, fw_graph_out_nodes = get_ins_outs(fw_graph) self.assertEqual( # fw outputs include b.size() which expands to 2 symints, - # then 4 tensors (transposes of matricies used for mm) are saved + # then 4 tensors (transposes of matrices used for mm) are saved # finally 3 symints are saved [False, True, True, False, False] + [False] * 4 + [True] * 3, [is_sym_node(n) for n in fw_graph_out_nodes], @@ -6000,7 +6000,7 @@ def f(a, b): self.assertEqual(b_test.a, b_ref.a) self.assertEqual(b_test.b, b_ref.b) - # NOTE: we need to use b in our gradient compute. Otherwise we will need to recompile teh backward. + # NOTE: we need to use b in our gradient compute. Otherwise we will need to recompile the backward. (b_ref * out_ref).sum().backward() (b_test * out_test).sum().backward() # Both grad_inputs are TwoTensor diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py index 1508997384d2f..fbcb8fc2e19b9 100644 --- a/test/functorch/test_control_flow.py +++ b/test/functorch/test_control_flow.py @@ -2013,7 +2013,7 @@ def test_scan_complex_pytree(self, reverse, compile_mode, device, autograd): if autograd: self.check_autograd(result, expected_result, (init, inp)) - # TODO: Does not work because of the usage of vmap witin associative_scan + # TODO: Does not work because of the usage of vmap within associative_scan # The paT206899919 rameterization is commented out for the moment and the test is marked with expected fail # Fails with: AssertionError: scan is not an OpOverload @skipIfRocm(msg="Unsupported on ROCM yet") @@ -4143,7 +4143,7 @@ def second_chain_fct(scan_fct, inp, **kwargs): inputs=inp, ) - # TODO: Does not work because of the usage of vmap witin associative_scan + # TODO: Does not work because of the usage of vmap within associative_scan # TODO: Re-enable additional parameters again once this issues has been resolved @unittest.skipIf(not SM70OrLater, "triton") @requires_cuda @@ -4241,7 +4241,7 @@ def body_fn(ind, loop_val): inputs=inp, ) - # TODO: Does not work because of the usage of vmap witin associative_scan + # TODO: Does not work because of the usage of vmap within associative_scan # TODO: Re-enable additional parameters again once this issues has been resolved @unittest.skipIf(not SM70OrLater, "triton") @requires_cuda @@ -4314,7 +4314,7 @@ def combine_fn(x, y): inputs=inp, ) - # TODO: Does not work because of the usage of vmap witin associative_scan + # TODO: Does not work because of the usage of vmap within associative_scan # TODO: Re-enable additional parameters again once this issues has been resolved @unittest.skipIf(not SM70OrLater, "triton") @requires_cuda @@ -5315,7 +5315,7 @@ def forward(self, arg0_1): ) @parametrize("func_type", ["no", "cpp", "python", "functorch"]) - # - "simple_with_linear" and "nested_with_linear" doesn't work becaue parameters and buffers + # - "simple_with_linear" and "nested_with_linear" doesn't work because parameters and buffers # are not inputs so they're not wrapped by functionalization and tracing. # # - make_fx tracing mode "real" fails for "int_carry", "pytree_int_carry" and "const_and_symint_output" diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py index 6ba61a6c1d0d3..0f893201733d3 100644 --- a/test/functorch/test_vmap.py +++ b/test/functorch/test_vmap.py @@ -4152,7 +4152,7 @@ def test(): with subtest_ctx(self), skip_xfail_ctx(self): args = (sample_input.input,) + sample_input.args if not any(isinstance(arg, torch.Tensor) for arg in args): - # Atleast one tensor required for vmap. + # At least one tensor required for vmap. continue kwargs = sample_input.kwargs is_batch_norm_and_training = is_batch_norm_training(op.name, kwargs) @@ -4230,7 +4230,7 @@ def sample_vmap_out_dim_numpy_split_copy_with_int( xfail("as_strided_copy"), xfail( "as_strided_scatter" - ), # no batching rule implemented, default doesnt work + ), # no batching rule implemented, default doesn't work skip( "new_empty_strided" ), # empty tensor data is garbage so it's hard to make comparisons with it diff --git a/test/fx/test_lazy_graph_module.py b/test/fx/test_lazy_graph_module.py index 6404b587d8707..a17bcb9151def 100644 --- a/test/fx/test_lazy_graph_module.py +++ b/test/fx/test_lazy_graph_module.py @@ -69,7 +69,7 @@ def f(x): def test_needs_recompile(self): """ - Make sure needs_recompile() return the corrent state. + Make sure needs_recompile() return the correct state. """ def f(x): @@ -141,7 +141,7 @@ def f(x): self.assertTrue(isinstance(gm2, _LazyGraphModule)) self.assertTrue(gm2._needs_recompile()) - # make_fx will cal foward method of gm. That clears the _needs_recompile() + # make_fx will cal forward method of gm. That clears the _needs_recompile() # flag. self.assertFalse(gm._needs_recompile()) @@ -175,7 +175,7 @@ def f(x): def test_save_lazy_foward(self): """ - Save the lazy forward method and call it repeatly. Make sure we + Save the lazy forward method and call it repeatedly. Make sure we don't recompile for each such call. """ diff --git a/test/fx/test_partitioner_order.py b/test/fx/test_partitioner_order.py index ebe40f471e62b..ab50b59fb96b7 100644 --- a/test/fx/test_partitioner_order.py +++ b/test/fx/test_partitioner_order.py @@ -36,9 +36,9 @@ class TestPartitionerOrder(TestCase): def test_partitioner_order(self): m = AddModule() traced_m = torch.fx.symbolic_trace(m) - partions = DummyPartitioner(traced_m).propose_partitions() - partion_nodes = [list(partition.nodes) for partition in partions] - node_order = [n.name for n in partion_nodes[0]] + partitions = DummyPartitioner(traced_m).propose_partitions() + partition_nodes = [list(partition.nodes) for partition in partitions] + node_order = [n.name for n in partition_nodes[0]] for _ in range(10): traced_m = torch.fx.symbolic_trace(m) new_partion = DummyPartitioner(traced_m).propose_partitions() diff --git a/test/fx/test_pass_infra.py b/test/fx/test_pass_infra.py index 195a4fad2ba33..47531e15040eb 100644 --- a/test/fx/test_pass_infra.py +++ b/test/fx/test_pass_infra.py @@ -131,7 +131,7 @@ def check_bad_args(graph_module, i): def test_topological_sort(self): """ - Tests that passes are correctly ordered based on contraints. + Tests that passes are correctly ordered based on constraints. """ def pass0(x): diff --git a/test/higher_order_ops/test_invoke_quant.py b/test/higher_order_ops/test_invoke_quant.py index 55dbad003db56..7796a9e4a1685 100644 --- a/test/higher_order_ops/test_invoke_quant.py +++ b/test/higher_order_ops/test_invoke_quant.py @@ -186,7 +186,7 @@ def quant_matching(match: Match, *args, **kwargs): @skipIfXpu( msg="MM Triton template fusion for XPU not work because the fusion" - " can not speedup, unskip untill #146568 fixed." + " can not speedup, unskip until #146568 fixed." ) @requires_gpu() @config.patch(prologue_fusion=True) diff --git a/test/higher_order_ops/test_invoke_subgraph.py b/test/higher_order_ops/test_invoke_subgraph.py index 052baebce337e..72daebb5f4f3f 100644 --- a/test/higher_order_ops/test_invoke_subgraph.py +++ b/test/higher_order_ops/test_invoke_subgraph.py @@ -1079,7 +1079,7 @@ def fn(x, y): fake_prop_count = 0 - def _mock_invoke_subgraph(mode, subgraph, identifer, *operands): + def _mock_invoke_subgraph(mode, subgraph, identifier, *operands): nonlocal fake_prop_count fake_prop_count += 1 return (operands[0].clone(),) @@ -2077,7 +2077,7 @@ def fn(x, y): # NOTE THAT THIS TEST DOES NOT REALLY WORK # We wanted one invoke_subgraph called twice, but because of - # constant_args_idx changing in the grpah, the graph equivalence fails + # constant_args_idx changing in the graph, the graph equivalence fails if not TEST_WITH_CROSSREF: self.assertExpectedInline( diff --git a/tools/linter/dictionary.txt b/tools/linter/dictionary.txt index 1817b04567ab7..64fa3f14f406a 100644 --- a/tools/linter/dictionary.txt +++ b/tools/linter/dictionary.txt @@ -1,9 +1,18 @@ +aLoad +aLoads ans +aStore +aStores belows +bLoad +bLoads +bStore +bStores BU contiguities contiguity coo +deser din dout ElementE diff --git a/torch/csrc/jit/tensorexpr/loopnest.cpp b/torch/csrc/jit/tensorexpr/loopnest.cpp index 646801fa9a19d..7f0888666d3af 100644 --- a/torch/csrc/jit/tensorexpr/loopnest.cpp +++ b/torch/csrc/jit/tensorexpr/loopnest.cpp @@ -1843,11 +1843,11 @@ bool LoopNest::hasLoopCarriedDependence(const ForPtr& loop) { auto bLoads = NodeFinder::find(*it2); // ReadAfterWrite for (auto& aStore : aStores) { - for (auto& bLoad : bLoads) { // codespell:ignore + for (auto& bLoad : bLoads) { if (aStore->buf() == bLoad->buf()) { if (!areIndicesLoopIndependent( aStore->indices(), bLoad->indices(), outer_loop_vars)) { - if (isOverlapping(analyzer, aStore, bLoad)) { // codespell:ignore + if (isOverlapping(analyzer, aStore, bLoad)) { return true; } } From 4c8b408d164290de6602478a7acecf91aee50329 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Thu, 17 Jul 2025 14:55:15 +0800 Subject: [PATCH 179/457] [BE][1/5] fix typos in aten/ (#157550) Pull Request resolved: https://github.com/pytorch/pytorch/pull/157550 Approved by: https://github.com/albanD ghstack dependencies: #156605, #157637 --- .lintrunner.toml | 1 - aten/src/ATen/CMakeLists.txt | 2 +- aten/src/ATen/DLConvertor.h | 2 +- aten/src/ATen/FunctionalInverses.cpp | 4 ++-- aten/src/ATen/FunctionalTensorWrapper.cpp | 10 +++++----- aten/src/ATen/LegacyBatchedFallback.cpp | 2 +- aten/src/ATen/LegacyVmapTransforms.h | 2 +- aten/src/ATen/MapAllocator.cpp | 2 +- aten/src/ATen/NestedTensorImpl.cpp | 2 +- aten/src/ATen/NestedTensorImpl.h | 2 +- aten/src/ATen/Parallel.h | 4 ++-- aten/src/ATen/TensorIndexing.h | 2 +- aten/src/ATen/TensorIterator.cpp | 4 ++-- aten/src/ATen/TensorIterator.h | 6 +++--- aten/src/ATen/TensorSubclassLikeUtils.h | 2 +- aten/src/ATen/TensorUtils.cpp | 4 ++-- aten/src/ATen/TracerMode.h | 2 +- aten/src/ATen/ZeroTensorFallback.cpp | 2 +- aten/src/ATen/autocast_mode.h | 2 +- aten/src/ATen/dlpack.h | 2 +- aten/src/ATen/nnapi/nnapi_bind.cpp | 2 +- aten/src/ATen/record_function.h | 2 +- aten/src/ATen/templates/RegisterDispatchKey.cpp | 2 +- aten/src/ATen/templates/TensorBody.h | 2 +- .../src/ATen/test/cpu_profiling_allocator_test.cpp | 2 +- aten/src/ATen/test/half_test.cpp | 2 +- aten/src/ATen/test/undefined_tensor_test.cpp | 2 +- aten/src/ATen/test/vec_test_all_types.cpp | 12 ++++++------ aten/src/ATen/test/vec_test_all_types.h | 6 +++--- aten/src/ATen/test/vulkan_api_test.cpp | 14 +++++++------- aten/src/ATen/test/vulkan_quantized_api_test.cpp | 14 +++++++------- aten/src/ATen/xpu/XPUEvent.h | 2 +- aten/src/README.md | 2 +- 33 files changed, 61 insertions(+), 62 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index 513d7dd2d00fc..4c3e05942ce02 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1162,7 +1162,6 @@ exclude_patterns = [ # These files are all grandfathered in, feel free to remove from this list # as necessary # NOTE: remove the patterns in the order they are listed - 'aten/**', 'aten/src/ATen/native/**', 'aten/src/ATen/native/q*/**', 'aten/src/ATen/native/[a-pA-P]*/**', diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index af8fea2529477..3355d45eafa50 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -458,7 +458,7 @@ if(LAPACK_FOUND) # would not need this at all), some of our libraries (magma in particular) # backend to CPU BLAS/LAPACK implementations, and so it is very important # we get the *right* implementation, because even if the symbols are the - # same, LAPACK implementions may have different calling conventions. + # same, LAPACK implementations may have different calling conventions. # This caused https://github.com/pytorch/pytorch/issues/7353 # # We do NOT do this on Linux, since we just rely on torch_cpu to diff --git a/aten/src/ATen/DLConvertor.h b/aten/src/ATen/DLConvertor.h index abc996db5ab46..e9cbd94dfd724 100644 --- a/aten/src/ATen/DLConvertor.h +++ b/aten/src/ATen/DLConvertor.h @@ -4,7 +4,7 @@ #include #include -// this convertor will: +// this converter will: // 1) take a Tensor object and wrap it in the DLPack tensor // 2) take a dlpack tensor and convert it to the ATen Tensor diff --git a/aten/src/ATen/FunctionalInverses.cpp b/aten/src/ATen/FunctionalInverses.cpp index 117a9eef6eb6d..123d87b304148 100644 --- a/aten/src/ATen/FunctionalInverses.cpp +++ b/aten/src/ATen/FunctionalInverses.cpp @@ -233,8 +233,8 @@ Tensor FunctionalInverses::slice_Tensor_inverse(const Tensor& base, const Tensor // NOLINTNEXTLINE(performance-unnecessary-value-param) Tensor FunctionalInverses::split_Tensor_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, int64_t mutated_view_idx, c10::SymInt split_size, int64_t dim) { - // It would be nice if this logic could be re-used from autograd's split_backward(), but I don't think it can. - // For functionalization, we have only have one of the tensors from the TensorList outputed by split(), and we want to layer i + // It would be nice if this logic could be reused from autograd's split_backward(), but I don't think it can. + // For functionalization, we have only have one of the tensors from the TensorList outputted by split(), and we want to layer i // on top of the base tensor. // For autograd, we have all of the tensors outputted by split() and we just want to stack them. dim = at::maybe_wrap_dim(dim, base.dim()); diff --git a/aten/src/ATen/FunctionalTensorWrapper.cpp b/aten/src/ATen/FunctionalTensorWrapper.cpp index ff4e2b562278b..7d5e4e84e861d 100644 --- a/aten/src/ATen/FunctionalTensorWrapper.cpp +++ b/aten/src/ATen/FunctionalTensorWrapper.cpp @@ -286,11 +286,11 @@ void FunctionalTensorWrapper::storage_resize_(const c10::SymInt& new_size) { // storage resizing is severely limited: we only support resizing either to zero, or from zero bytes. TORCH_CHECK(new_size == 0 || curr_storage_size == 0, "new_size: ", new_size, ". curr_storage_size: ", curr_storage_size); // The "functionalization rule" for storage resizing is a giant no-op, mainly because we don't want - // resize_() calls to actualy emit any ops in the functional graph. + // resize_() calls to actually emit any ops in the functional graph. // How does it work? // Resizing up (old size == 0): // We do nothing in this case. - // The expection is that for the user code to be valid, the next op that should run against the current tensor "x" + // The expectation is that for the user code to be valid, the next op that should run against the current tensor "x" // will be a x.copy_(y) (or similar), that will fully overwrite the data of x. // If there are any outstanding aliases of x, we expect them not to be used until after the copy_() call // (otherwise the eager code would be invalid), @@ -327,7 +327,7 @@ void FunctionalTensorWrapper::maybe_replace_storage(const Tensor& other) { // We're also no longer re-generate "b" fully from "a" anymore, since "a" refers to a slice of "b"'s data. // // This is probably fixable in theory, but: - // - the fix would likey complicated the functionalization logic quite a bit. + // - the fix would likely complicated the functionalization logic quite a bit. // - the primary use case for resize_() today is resizing zero-sized tensors in out= variants of operators // - resize_() also can give you weird results today if you try to resize_() a weirdly strided tensor. // @@ -344,7 +344,7 @@ void FunctionalTensorWrapper::maybe_replace_storage(const Tensor& other) { set_sizes_and_strides(value_.sizes(), value_.strides()); refresh_numel(); // (Technically we should be guaranteed that the tensor was already contiguous, - // since it's guaranteed not to have been a view. Doesnt hurt to run though) + // since it's guaranteed not to have been a view. Doesn't hurt to run though) refresh_contiguous(); // Swapping out the storage of a tensor (aka from a resize_() call) will update the sizes and strides of the tensor, // so we need to record the fact that metadata was mutated. @@ -819,7 +819,7 @@ void setFunctionalizationReapplyViewsTLS(bool reapply_views) { // This function will "functionalize" it. // That is, it will call the operator, but removing any intermediate views/mutations // that are performed inside of it. -// This is useful for LTC/XLA, which would like to re-use some of our composite kernels +// This is useful for LTC/XLA, which would like to reuse some of our composite kernels // from pytorch core but not have to worry about the view ops that they might call. // e.g. at::block_diag void functionalize_op_helper(const c10::OperatorHandle& op, torch::jit::Stack* stack) { diff --git a/aten/src/ATen/LegacyBatchedFallback.cpp b/aten/src/ATen/LegacyBatchedFallback.cpp index d44d92c239f22..f2b527302a97b 100644 --- a/aten/src/ATen/LegacyBatchedFallback.cpp +++ b/aten/src/ATen/LegacyBatchedFallback.cpp @@ -218,7 +218,7 @@ static Tensor safeStack(TensorList tensors) { // is possible for the backward function to return an undefined grad for some // grad_input for each example. In that case, we return an undefined grad. // - // It is theoretically posssible for *some* of the examples to produce an + // It is theoretically possible for *some* of the examples to produce an // undefined grad (a kernel could peek at the gradient values and return an // undefined tensor if it determines the gradient is full of zeros). We // could handle this by treating the undefined grad as a zero-filled tensor diff --git a/aten/src/ATen/LegacyVmapTransforms.h b/aten/src/ATen/LegacyVmapTransforms.h index 97729b3254e74..be6cf1b697a22 100644 --- a/aten/src/ATen/LegacyVmapTransforms.h +++ b/aten/src/ATen/LegacyVmapTransforms.h @@ -140,7 +140,7 @@ struct TORCH_API VmapPhysicalView { // mapping a physical tensor to a new logical tensor (BatchedTensor) VmapPhysicalToLogicalMap getPhysicalToLogicalMap() const; - // Maps a logical shape to a physical shape by pre-pending the batch + // Maps a logical shape to a physical shape by prepending the batch // sizes to the logical shape. VmapDimVector getPhysicalShape(IntArrayRef logical_shape) const; diff --git a/aten/src/ATen/MapAllocator.cpp b/aten/src/ATen/MapAllocator.cpp index be10641aa2714..63a278050e8a7 100644 --- a/aten/src/ATen/MapAllocator.cpp +++ b/aten/src/ATen/MapAllocator.cpp @@ -299,7 +299,7 @@ MapAllocator::MapAllocator(WithFd, std::string_view filename, int fd, int flags, ::close(fd); TORCH_CHECK(false, "unable to stretch file <", filename_, "> to the right size: ", c10::utils::str_error(last_err), " (", last_err, ")"); } -/* on macOS write returns with errno 45 (Opperation not supported) when used +/* on macOS write returns with errno 45 (Operation not supported) when used * with a file descriptor obtained via shm_open */ #ifndef __APPLE__ diff --git a/aten/src/ATen/NestedTensorImpl.cpp b/aten/src/ATen/NestedTensorImpl.cpp index 647b2f1685d17..63bd867f90220 100644 --- a/aten/src/ATen/NestedTensorImpl.cpp +++ b/aten/src/ATen/NestedTensorImpl.cpp @@ -211,7 +211,7 @@ NestedTensorImpl::NestedTensorImpl( } // assume contiguous, `nested_strides` and `offsets` -// can be infered from `nested_sizes` +// can be inferred from `nested_sizes` NestedTensorImpl::NestedTensorImpl( const at::Tensor& buffer, const at::Tensor& nested_sizes) diff --git a/aten/src/ATen/NestedTensorImpl.h b/aten/src/ATen/NestedTensorImpl.h index f40684ce0ba26..cddf37df34a52 100644 --- a/aten/src/ATen/NestedTensorImpl.h +++ b/aten/src/ATen/NestedTensorImpl.h @@ -32,7 +32,7 @@ struct TORCH_API NestedTensorImpl : public c10::TensorImpl { at::Tensor nested_strides, at::Tensor storage_offsets); // assume contiguous, `nested_strides` and `offsets` - // can be infered from `nested_sizes` + // can be inferred from `nested_sizes` explicit NestedTensorImpl( const at::Tensor& buffer, const at::Tensor& nested_sizes); diff --git a/aten/src/ATen/Parallel.h b/aten/src/ATen/Parallel.h index 917524419f9a7..b55dad02f347e 100644 --- a/aten/src/ATen/Parallel.h +++ b/aten/src/ATen/Parallel.h @@ -93,12 +93,12 @@ ident: identity for binary combination function sf. sf(ident, x) needs to return x. f: function for reduction over a chunk. f needs to be of signature scalar_t -f(int64_t partial_begin, int64_t partial_end, scalar_t identifiy) +f(int64_t partial_begin, int64_t partial_end, scalar_t identify) sf: function to combine two partial results. sf needs to be of signature scalar_t sf(scalar_t x, scalar_t y) -For example, you might have a tensor of 10000 entires and want to sum together +For example, you might have a tensor of 10000 entries and want to sum together all the elements. Parallel_reduce with a grain_size of 2500 will then allocate an intermediate result tensor with 4 elements. Then it will execute the function "f" you provide and pass the beginning and end index of these chunks, so diff --git a/aten/src/ATen/TensorIndexing.h b/aten/src/ATen/TensorIndexing.h index 3648862c12241..a487589833e8c 100644 --- a/aten/src/ATen/TensorIndexing.h +++ b/aten/src/ATen/TensorIndexing.h @@ -252,7 +252,7 @@ inline Tensor applySelect( // Note: `size >= -index` is not equivalent to `size > -1 - index` if index // is INT64_MIN For std::numeric_limits::min() result of unary // minus is undefined by the standard but in practice is equal to self. On - // the other hand, indexing wraping is valid for all negative int64_t + // the other hand, indexing wrapping is valid for all negative int64_t // values, as x[INT64_MIN] is the same as x[INT64_MAX] TORCH_CHECK_INDEX( size.sym_gt(-1 - index) diff --git a/aten/src/ATen/TensorIterator.cpp b/aten/src/ATen/TensorIterator.cpp index 32f0f1e2defeb..9096cbfc68eb6 100644 --- a/aten/src/ATen/TensorIterator.cpp +++ b/aten/src/ATen/TensorIterator.cpp @@ -208,7 +208,7 @@ bool TensorIteratorConfig::is_tensor_const(size_t idx) { // same strides are increasing. If dimensions are non-increasing, we move on to the next input to break the tie. // // Instead of applying rule 4 for tie breaking, we could move on to the next tensor directly. This would result in possibly -// losing the correct permuation of the first tensor if there are permuted trivial dimensions, but could potentially +// losing the correct permutation of the first tensor if there are permuted trivial dimensions, but could potentially // improve traversal order of the second tensor. We chose the former option to better propagate channels last layout // for example for a tensor with the sizes N1H1 // These rules result in the intuitive behavior that in most cases recovers permutation of either the first argument (if all @@ -244,7 +244,7 @@ void TensorIteratorBase::reorder_dimensions() { // initialize perm with n-1, n-2, ..., 1, 0 std::iota(perm_.rbegin(), perm_.rend(), 0); - // Reordering dimensions changes iteraton order + // Reordering dimensions changes iteration order if (enforce_linear_iteration_) { permute_dimensions(perm_); return; diff --git a/aten/src/ATen/TensorIterator.h b/aten/src/ATen/TensorIterator.h index 0e49151969bd7..d8eebd4c06a42 100644 --- a/aten/src/ATen/TensorIterator.h +++ b/aten/src/ATen/TensorIterator.h @@ -388,7 +388,7 @@ struct TORCH_API TensorIteratorBase : public impl::MetaBase { /// Return scalar value from original_tensor_base if it is defined. When /// common_dtype is Half, casting scalar input to common_dtype might overflow. - /// If the scalar is aleady given in the type of Half, then return scalar + /// If the scalar is already given in the type of Half, then return scalar /// value from tensor_base. template T original_scalar_value(int64_t arg) { @@ -502,7 +502,7 @@ struct TORCH_API TensorIteratorBase : public impl::MetaBase { /// kernels bool can_use_32bit_indexing() const; - /// An "iteratable" object that recursively splits this iterator into + /// An "iterable" object that recursively splits this iterator into /// sub-iterators that can use 32-bit indexing. SplitUntil32Bit with_32bit_indexing() const; @@ -878,7 +878,7 @@ class TORCH_API TensorIteratorConfig final { // Sets the enforce_linear_iteration_ flag, which is false by default. // If true, iteration goes in the same order as a C-contiguous tensor - // is layed out in memory. i.e. last dimension iterates fastest. + // is laid out in memory. i.e. last dimension iterates fastest. // // This iteration order can be less efficient and may even prevent // vectorization. So only use if the correctness of your kernel depends on it. diff --git a/aten/src/ATen/TensorSubclassLikeUtils.h b/aten/src/ATen/TensorSubclassLikeUtils.h index 49d430f6d3e41..515642a0c51d2 100644 --- a/aten/src/ATen/TensorSubclassLikeUtils.h +++ b/aten/src/ATen/TensorSubclassLikeUtils.h @@ -78,7 +78,7 @@ inline bool areAnyOptionalTensorSubclassLike( // NOTE: This function expects a scalar tensor of boolean dtype. // Eg. // Non-Composite Compliant Pattern : (t == 0).all().item() -// Composite Compliant Patter : is_salar_tensor_true((t == 0).all()) +// Composite Compliant Pattern : is_salar_tensor_true((t == 0).all()) inline bool is_scalar_tensor_true(const Tensor& t) { TORCH_INTERNAL_ASSERT(t.dim() == 0) TORCH_INTERNAL_ASSERT(t.scalar_type() == kBool) diff --git a/aten/src/ATen/TensorUtils.cpp b/aten/src/ATen/TensorUtils.cpp index 1636bbcb6f75b..34cb5329de6a3 100644 --- a/aten/src/ATen/TensorUtils.cpp +++ b/aten/src/ATen/TensorUtils.cpp @@ -378,9 +378,9 @@ inline static std::optional computeStride_impl( (TORCH_GUARD_OR_TRUE(sym_ne(oldshape[tensor_d - 1], 1)) && TORCH_GUARD_OR_TRUE(sym_ne(oldstride[tensor_d - 1], tensor_numel * chunk_base_stride)))) { // We want to accumulate stuff in view_numel until view_numel == tensor_numel, if we do not - // know if that is satisfied we keep accumalating. For example if view_numel = 1 and tensor_numel = u1, + // know if that is satisfied we keep accumulating. For example if view_numel = 1 and tensor_numel = u1, // we want to take that path, view_numel will become u0. Next iteration if u0==u1 we want to stop. - // Thats why we use TORCH_GUARD_OR_TRUE below. + // That's why we use TORCH_GUARD_OR_TRUE below. // we use TORCH_GUARD_OR_FALSE and not TORCH_GUARD_OR_TRUE when comparing newshape[view_d] ==1 because // if we know view_numel < tensor_numel is false, we want to stop. Unless we know for sure newshape[view_d]==1 diff --git a/aten/src/ATen/TracerMode.h b/aten/src/ATen/TracerMode.h index 8ba62640fe650..d0d4c93a84f53 100644 --- a/aten/src/ATen/TracerMode.h +++ b/aten/src/ATen/TracerMode.h @@ -27,7 +27,7 @@ // ops (ops being called by other ops). After the intermediate op call // finishes it's set back to the original `TracingState` object. // -// The `TracingState` obect in TLS can also be read/written via its Python +// The `TracingState` object in TLS can also be read/written via its Python // binding in `python_tracer.cpp`, and `get/setTracingState()` C++ APIs, // which are also exposed as `TORCH_API`. // diff --git a/aten/src/ATen/ZeroTensorFallback.cpp b/aten/src/ATen/ZeroTensorFallback.cpp index 329216cf3789f..06ab82accaf27 100644 --- a/aten/src/ATen/ZeroTensorFallback.cpp +++ b/aten/src/ATen/ZeroTensorFallback.cpp @@ -95,7 +95,7 @@ namespace at { m.impl("clone", torch::CppFunction::makeFallthrough()); m.impl("dot", torch::CppFunction::makeFallthrough()); m.impl("vdot", torch::CppFunction::makeFallthrough()); - // The functions in the list below have a specific registeration in native_functions.yaml and + // The functions in the list below have a specific registration in native_functions.yaml and // do not use the fallback. // m.impl("mul.Tensor", torch::CppFunction::makeFallthrough()); // m.impl("add.Tensor", torch::CppFunction::makeFallthrough()); diff --git a/aten/src/ATen/autocast_mode.h b/aten/src/ATen/autocast_mode.h index a222b8924bac2..655b2343d5d5c 100644 --- a/aten/src/ATen/autocast_mode.h +++ b/aten/src/ATen/autocast_mode.h @@ -377,7 +377,7 @@ Keep it simple for now by assuming only one such flag is present in the argument list. If I ever need a function with more than flag I'll figure out something else. The policy is: -If the user has explicity specified a dtype, respect it. +If the user has explicitly specified a dtype, respect it. Otherwise, set it to the autocast type. ********************************************************/ diff --git a/aten/src/ATen/dlpack.h b/aten/src/ATen/dlpack.h index 5d0234b5653e7..82c0668211188 100644 --- a/aten/src/ATen/dlpack.h +++ b/aten/src/ATen/dlpack.h @@ -199,7 +199,7 @@ typedef struct { * `byte_offset` field should be used to point to the beginning of the data. * * Note that as of Nov 2021, multiply libraries (CuPy, PyTorch, TensorFlow, - * TVM, perhaps others) do not adhere to this 256 byte aligment requirement + * TVM, perhaps others) do not adhere to this 256 byte alignment requirement * on CPU/CUDA/ROCm, and always use `byte_offset=0`. This must be fixed * (after which this note will be updated); at the moment it is recommended * to not rely on the data pointer being correctly aligned. diff --git a/aten/src/ATen/nnapi/nnapi_bind.cpp b/aten/src/ATen/nnapi/nnapi_bind.cpp index 120c62cd4ab93..8f40ee4045681 100644 --- a/aten/src/ATen/nnapi/nnapi_bind.cpp +++ b/aten/src/ATen/nnapi/nnapi_bind.cpp @@ -26,7 +26,7 @@ static void load_platform_library() { (void)run_once; } -// NnapiCompilation functon definitions: +// NnapiCompilation function definitions: // Could possibly call load_platform_library in constructor, but error reporting // can be complicated if the constructor is called during model loading. diff --git a/aten/src/ATen/record_function.h b/aten/src/ATen/record_function.h index 29fbc8270a451..8ec70a1682f37 100644 --- a/aten/src/ATen/record_function.h +++ b/aten/src/ATen/record_function.h @@ -666,7 +666,7 @@ void record_function_with_scope_and_debug_handle( guard, fn, debug_handle, inputs, ##__VA_ARGS__); \ } -// Helper macros to record LITE INTERPETER scope events with debug handles +// Helper macros to record LITE INTERPRETER scope events with debug handles #define RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS( \ fn, debug_handle, inputs) \ RECORD_WITH_SCOPE_DEBUG_HANDLE_AND_INPUTS( \ diff --git a/aten/src/ATen/templates/RegisterDispatchKey.cpp b/aten/src/ATen/templates/RegisterDispatchKey.cpp index 158277dd5d53b..39c85b00d7a1b 100644 --- a/aten/src/ATen/templates/RegisterDispatchKey.cpp +++ b/aten/src/ATen/templates/RegisterDispatchKey.cpp @@ -5,7 +5,7 @@ // NOTE: This condition is true for all PyTorch internal libraries, it // just excludes external projects such as torch_xla which -// re-use some of the PyTorch codegen machinery. +// reuse some of the PyTorch codegen machinery. #if defined(CAFFE2_BUILD_MAIN_LIB) || \ defined(TORCH_CUDA_BUILD_MAIN_LIB) || \ defined(TORCH_HIP_BUILD_MAIN_LIB) || \ diff --git a/aten/src/ATen/templates/TensorBody.h b/aten/src/ATen/templates/TensorBody.h index 050d882f42bfc..8ae2dee1ce50c 100644 --- a/aten/src/ATen/templates/TensorBody.h +++ b/aten/src/ATen/templates/TensorBody.h @@ -491,7 +491,7 @@ class TORCH_API Tensor: public TensorBase { "attribute won't be populated during autograd.backward(). If you indeed want the .grad " "field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. " "If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor " - "instead. See github.com/pytorch/pytorch/pull/30531 for more informations."); + "instead. See github.com/pytorch/pytorch/pull/30531 for more information."); } return maybe_grad; } diff --git a/aten/src/ATen/test/cpu_profiling_allocator_test.cpp b/aten/src/ATen/test/cpu_profiling_allocator_test.cpp index c390305e2051c..15220e58e2485 100644 --- a/aten/src/ATen/test/cpu_profiling_allocator_test.cpp +++ b/aten/src/ATen/test/cpu_profiling_allocator_test.cpp @@ -199,7 +199,7 @@ int main(int argc, char* argv[]) { #ifdef C10_MOBILE // Need to disable mkldnn for this test since it allocated memory - // via raw_allocate inteface which requires context pointer and raw + // via raw_allocate interface which requires context pointer and raw // pointer to be the same. Tis is not true for mobile allocator. at::globalContext().setUserEnabledMkldnn(false); #endif diff --git a/aten/src/ATen/test/half_test.cpp b/aten/src/ATen/test/half_test.cpp index 900758233432d..9e594196c6925 100644 --- a/aten/src/ATen/test/half_test.cpp +++ b/aten/src/ATen/test/half_test.cpp @@ -25,7 +25,7 @@ TEST(TestHalf, Arithmetic) { ASSERT_EQ(one + one, 2); } -TEST(TestHalf, Comparisions) { +TEST(TestHalf, Comparisons) { Half zero = 0; Half one = 1; ASSERT_LT(zero, one); diff --git a/aten/src/ATen/test/undefined_tensor_test.cpp b/aten/src/ATen/test/undefined_tensor_test.cpp index 91777c3a05c7c..ec6997fae9b05 100644 --- a/aten/src/ATen/test/undefined_tensor_test.cpp +++ b/aten/src/ATen/test/undefined_tensor_test.cpp @@ -9,7 +9,7 @@ using namespace at; TEST(TestUndefined, UndefinedTest) { manual_seed(123); - // mainly test ops on undefined tensors don't segfault and give a reasonable errror message. + // mainly test ops on undefined tensors don't segfault and give a reasonable error message. Tensor und; Tensor ft = ones({1}, CPU(kFloat)); diff --git a/aten/src/ATen/test/vec_test_all_types.cpp b/aten/src/ATen/test/vec_test_all_types.cpp index a9b5a70f1de91..b7b756f74ba1f 100644 --- a/aten/src/ATen/test/vec_test_all_types.cpp +++ b/aten/src/ATen/test/vec_test_all_types.cpp @@ -5,7 +5,7 @@ namespace { template class Memory : public ::testing::Test {}; template - class Arithmetics : public ::testing::Test {}; + class Arithmetic : public ::testing::Test {}; template class Comparison : public ::testing::Test {}; template @@ -92,7 +92,7 @@ namespace { using ComplexTypes = ::testing::Types; using ReducedFloatTestedTypes = ::testing::Types; TYPED_TEST_SUITE(Memory, ALLTestedTypes); - TYPED_TEST_SUITE(Arithmetics, FloatIntTestedTypes); + TYPED_TEST_SUITE(Arithmetic, FloatIntTestedTypes); TYPED_TEST_SUITE(Comparison, RealFloatIntReducedFloatTestedTypes); TYPED_TEST_SUITE(Bitwise, FloatIntTestedTypes); TYPED_TEST_SUITE(MinMax, RealFloatIntTestedTypes); @@ -691,7 +691,7 @@ namespace { AssertVectorized(NAME_INFO(DeInterleave FirstHalf), std::get<0>(cc), vec::loadu(vals)).check(true); AssertVectorized(NAME_INFO(DeInterleave SecondHalf), std::get<1>(cc), vec::loadu(vals + vec::size())).check(true); } - TYPED_TEST(Arithmetics, Plus) { + TYPED_TEST(Arithmetic, Plus) { using vec = TypeParam; using VT = ValueType; test_binary( @@ -703,7 +703,7 @@ namespace { createDefaultBinaryTestCase(TestSeed()), RESOLVE_OVERLOAD(filter_add_overflow)); } - TYPED_TEST(Arithmetics, Minus) { + TYPED_TEST(Arithmetic, Minus) { using vec = TypeParam; using VT = ValueType; test_binary( @@ -715,7 +715,7 @@ namespace { createDefaultBinaryTestCase(TestSeed()), RESOLVE_OVERLOAD(filter_sub_overflow)); } - TYPED_TEST(Arithmetics, Multiplication) { + TYPED_TEST(Arithmetic, Multiplication) { using vec = TypeParam; test_binary( NAME_INFO(mult), @@ -724,7 +724,7 @@ namespace { createDefaultBinaryTestCase(TestSeed(), false, true), RESOLVE_OVERLOAD(filter_mult_overflow)); } - TYPED_TEST(Arithmetics, Division) { + TYPED_TEST(Arithmetic, Division) { using vec = TypeParam; TestSeed seed; test_binary( diff --git a/aten/src/ATen/test/vec_test_all_types.h b/aten/src/ATen/test/vec_test_all_types.h index f7062a3048dfc..f7206cc340973 100644 --- a/aten/src/ATen/test/vec_test_all_types.h +++ b/aten/src/ATen/test/vec_test_all_types.h @@ -531,7 +531,7 @@ template std::enable_if_t::value, void> filter_div_ub(T& val1, T& val2) { //missing - //at least consdier zero division + //at least consider zero division auto ret = std::abs(val2); if (ret == 0) { val2 = T(1, 2); @@ -1291,7 +1291,7 @@ std::enable_if_t>::value, Complex> local_multiply(Compl T y_real = y.real(); T y_imag = y.imag(); #if defined(CPU_CAPABILITY_VSX) || defined(CPU_CAPABILITY_ZVECTOR) - //check multiplication considerin swap and fma + //check multiplication considering swap and fma T rr = x_real * y_real; T ii = x_imag * y_real; T neg_imag = -y_imag; @@ -1362,7 +1362,7 @@ std::enable_if_t>::value, Complex> local_division(Compl return Complex(rr, ii); #else /* defined(CPU_CAPABILITY_ZVECTOR) */ #if defined(CPU_CAPABILITY_VSX) - //check multiplication considerin swap and fma + //check multiplication considering swap and fma T rr = x_real * y_real; T ii = x_imag * y_real; T neg_imag = -y_imag; diff --git a/aten/src/ATen/test/vulkan_api_test.cpp b/aten/src/ATen/test/vulkan_api_test.cpp index 1b4750b6c41e6..263918af2662c 100644 --- a/aten/src/ATen/test/vulkan_api_test.cpp +++ b/aten/src/ATen/test/vulkan_api_test.cpp @@ -1232,7 +1232,7 @@ void test_matmul( } TEST_F(VulkanAPITest, DISABLED_matmul_3d_weight_vulkan) { - // This will call at::bmm. Will crash for unknow reason. + // This will call at::bmm. Will crash for unknown reason. const auto m1_cpu = at::rand({13, 23, 45}, at::device(at::kCPU).dtype(at::kFloat)); const auto m2_cpu = @@ -1241,7 +1241,7 @@ TEST_F(VulkanAPITest, DISABLED_matmul_3d_weight_vulkan) { } TEST_F(VulkanAPITest, DISABLED_matmul_3d_weight_cpu) { - // This will call at::bmm. Will crash for unknow reason. + // This will call at::bmm. Will crash for unknown reason. const auto m1_cpu = at::rand({13, 23, 45}, at::device(at::kCPU).dtype(at::kFloat)); const auto m2_cpu = @@ -2004,7 +2004,7 @@ TEST_F(VulkanAPITest, conv2d_pw_prepack_bc_medium) { 1); // groups } -// The followin 2 tests failed on Meta's CI when all tests are executed. Output +// The following 2 tests failed on Meta's CI when all tests are executed. Output // has lots of nan. Cause unknown. // When this test is run alone (with gtest_filter), it passes. // The test also passes with smaller planes, see "conv2d_pw_prepack_medium". @@ -5664,7 +5664,7 @@ TEST_F(VulkanAPITest, var_2d_unbiased) { test_var({3, 5}, {1}, true, true); test_var({3, 5}, {1}, true, false); - // inpu.dim() == dim_list.size(), only keepdim == true is supported + // input.dim() == dim_list.size(), only keepdim == true is supported test_var({3, 5}, {0, 1}, true, true); } @@ -5672,7 +5672,7 @@ TEST_F(VulkanAPITest, var_2d_biased) { test_var({3, 5}, {1}, false, true); test_var({3, 5}, {1}, false, false); - // inpu.dim() == dim_list.size(), only keepdim == true is supported + // input.dim() == dim_list.size(), only keepdim == true is supported test_var({3, 5}, {0, 1}, false, true); } @@ -7142,12 +7142,12 @@ TEST_F(VulkanAPITest, clone_success) { } TEST_F(VulkanAPITest, clone_invalidinputs_exceptions) { - // Act: Vulkan supports Preserve and Contiguous memory foramts + // Act: Vulkan supports Preserve and Contiguous memory formats EXPECT_THROW({ clone_test({2, 3, 5, 161}, c10::MemoryFormat::ChannelsLast); }, ::std::exception); - // Act: Vulkan supports Preserve and Contiguous memory foramts + // Act: Vulkan supports Preserve and Contiguous memory formats EXPECT_THROW({ clone_test({2, 3, 5, 161}, c10::MemoryFormat::ChannelsLast3d); }, ::std::exception); diff --git a/aten/src/ATen/test/vulkan_quantized_api_test.cpp b/aten/src/ATen/test/vulkan_quantized_api_test.cpp index 650afceb887cc..2829aed94def9 100644 --- a/aten/src/ATen/test/vulkan_quantized_api_test.cpp +++ b/aten/src/ATen/test/vulkan_quantized_api_test.cpp @@ -2116,7 +2116,7 @@ std::tuple produce_inputs_for_binary_op( input2_cpu = produce_random_tensor(input2_shape); if (compute_quantization_params) { - // compute appropiate scale and zero point for inputs + // compute appropriate scale and zero point for inputs const auto in1_quant_params = compute_quant_params(input1_cpu); in1_scale = std::get<0>(in1_quant_params); in1_zero_point = std::get<1>(in1_quant_params); @@ -2287,7 +2287,7 @@ void test_quantized_binary_op( apply_cpu_quantized_binary_op(op_name, input1_cpu_deq, input2_cpu_deq); if (compute_quantization_params || random_quantization_params) { - // compute appropiate scale and zero point for output + // compute appropriate scale and zero point for output const auto out_quant_params = compute_quant_params(output_cpu); out_scale = std::get<0>(out_quant_params); out_zero_point = std::get<1>(out_quant_params); @@ -2540,7 +2540,7 @@ void test_quantized_conv2d( bias_cpu = produce_random_tensor(bias_shape, 1.26, 5.97, 0.59); if (compute_quantization_params) { - // compute appropiate scale and zero point for input, weight and bias + // compute appropriate scale and zero point for input, weight and bias const auto in_quant_params = compute_quant_params(input_cpu, in_dtype); in_scale = std::get<0>(in_quant_params); in_zero_point = std::get<1>(in_quant_params); @@ -2624,7 +2624,7 @@ void test_quantized_conv2d( groups); if (compute_quantization_params || random_quantization_params) { - // compute appropiate scale and zero point for output + // compute appropriate scale and zero point for output const auto out_quant_params = compute_quant_params(output_cpu, out_dtype); out_scale = std::get<0>(out_quant_params); out_zero_point = std::get<1>(out_quant_params); @@ -3524,7 +3524,7 @@ TEST_F(VulkanAPITest, linear_4d_large) { test_quantized_linear({9, 13, 11, 17}, {23, 17}, {23}); } -// The following code is not directly releated to quantization. We put it here +// The following code is not directly related to quantization. We put it here // since we are not able to run this test on GH's CI: for some unknown reason, // we are not able to reference symbols in the vulkan directory, hence the build // on GH fails. Moving the test here so we are still able to run it on @@ -3566,7 +3566,7 @@ TEST_F(VulkanAPITest, extract_texel_test) { // is the channel count. // We always start a new batch on a new z. Hence, when c cannot be divided by // 4, there are some undefined values in the padding area. We use -1 to - // indicate that we are not performing comparsion on those values. + // indicate that we are not performing comparison on those values. std::tuple test_cases[]{ {{0, 0, 0}, {0, hw, 2 * hw, 3 * hw}}, {{1, 0, 0}, {1, hw + 1, 2 * hw + 1, 3 * hw + 1}}, @@ -3672,7 +3672,7 @@ TEST_F(VulkanAPITest, channel_to_width_packing_test) { at::Tensor output = at::native::vulkan::ops::convert(v_output); // This tensor will be width-packed. Meaning that each texel represent - // consecutive elements along the width dimension. The differece between + // consecutive elements along the width dimension. The difference between // consecutive texels is 1. std::tuple test_cases[]{ {{0, 0, 0}, {0, 1, 2, 3}}, diff --git a/aten/src/ATen/xpu/XPUEvent.h b/aten/src/ATen/xpu/XPUEvent.h index ededd6ebf4f15..19d42aae080f1 100644 --- a/aten/src/ATen/xpu/XPUEvent.h +++ b/aten/src/ATen/xpu/XPUEvent.h @@ -12,7 +12,7 @@ namespace at::xpu { * must match the same device. * * Currently, XPUEvent does NOT support to export an inter-process event from - * another process via inter-process comunication(IPC). So it means that + * another process via inter-process communication(IPC). So it means that * inter-process communication for event handles between different processes is * not available. This could impact some applications that rely on cross-process * synchronization and communication. diff --git a/aten/src/README.md b/aten/src/README.md index 3127ed5c8c399..fa279c89d26ca 100644 --- a/aten/src/README.md +++ b/aten/src/README.md @@ -8,7 +8,7 @@ multiple variants of the library, summarized here: * THC = TorcH Cuda * THCS = TorcH Cuda Sparse (now defunct) * THNN = TorcH Neural Network (now defunct) -* THS = TorcH Sparse (now defunct) +* THS = TorcH Sparse (now defunct) (You'll also see these abbreviations show up in symbol names.) From f57ef62ebcd60dfd47149936d730f9b69128a88b Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Thu, 17 Jul 2025 14:55:17 +0800 Subject: [PATCH 180/457] [BE][2/5] fix typos in aten/ (aten/src/ATen/native/) (#157551) Pull Request resolved: https://github.com/pytorch/pytorch/pull/157551 Approved by: https://github.com/albanD ghstack dependencies: #156605, #157637, #157550 --- .lintrunner.toml | 1 - aten/src/ATen/native/BatchLinearAlgebra.cpp | 4 ++-- aten/src/ATen/native/DilatedMaxPool2d.cpp | 4 ++-- aten/src/ATen/native/DilatedMaxPool3d.cpp | 4 ++-- aten/src/ATen/native/DistributionTemplates.h | 4 ++-- aten/src/ATen/native/GridSampler.cpp | 10 +++++----- aten/src/ATen/native/Math.h | 6 +++--- aten/src/ATen/native/Pool.h | 4 ++-- aten/src/ATen/native/SegmentReduce.cpp | 2 +- aten/src/ATen/native/TensorAdvancedIndexing.cpp | 2 +- aten/src/ATen/native/TensorConversions.cpp | 2 +- aten/src/ATen/native/TensorFactories.cpp | 4 ++-- aten/src/ATen/native/TensorIteratorReduce.cpp | 2 +- aten/src/ATen/native/TensorShape.cpp | 10 +++++----- .../native/sparse/SparseBinaryOpIntersectionCommon.h | 4 ++-- .../native/sparse/SparseBinaryOpIntersectionKernel.cpp | 2 +- aten/src/ATen/native/sparse/SparseTensor.cpp | 2 +- .../native/sparse/ValidateCompressedIndicesCommon.h | 2 +- aten/src/ATen/native/sparse/cuda/ComputeSparseTile.h | 2 +- aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp | 2 +- aten/src/ATen/native/sparse/cuda/SparseMatMul.cu | 2 +- aten/src/ATen/native/transformers/attention.cpp | 2 +- .../native/transformers/cuda/attention_backward.cu | 10 +++++----- .../native/transformers/cuda/flash_attn/flash_api.cpp | 2 +- .../epilogue/epilogue_rescale_output.h | 4 ++-- .../cuda/mem_eff_attention/gemm/custom_mma_base.h | 2 +- .../iterators/epilogue_predicated_tile_iterator.h | 2 +- .../cuda/mem_eff_attention/kernel_backward.h | 2 +- aten/src/ATen/native/transformers/cuda/sdp_utils.cpp | 2 +- .../ATen/native/transformers/hip/aotriton_adapter.h | 2 +- .../transformers/hip/flash_attn/aot/mha_all_aot.hip | 2 +- .../hip/flash_attn/ck/mha_varlen_fwd_ck.hip | 2 +- aten/src/ATen/native/utils/ParamsHash.h | 2 +- aten/src/ATen/native/vulkan/api/Types.h | 2 +- aten/src/ATen/native/vulkan/glsl/conv2d.glsl | 2 +- aten/src/ATen/native/vulkan/glsl/conv2d_pw.glsl | 2 +- .../ATen/native/vulkan/glsl/image_to_nchw_uint.glsl | 2 +- aten/src/ATen/native/vulkan/glsl/indexing.h | 4 ++-- aten/src/ATen/native/vulkan/glsl/quantized_conv2d.glsl | 2 +- aten/src/ATen/native/vulkan/ops/Tile.cpp | 2 +- tools/linter/dictionary.txt | 5 +++++ 41 files changed, 67 insertions(+), 63 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index 4c3e05942ce02..707d262354313 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1162,7 +1162,6 @@ exclude_patterns = [ # These files are all grandfathered in, feel free to remove from this list # as necessary # NOTE: remove the patterns in the order they are listed - 'aten/src/ATen/native/**', 'aten/src/ATen/native/q*/**', 'aten/src/ATen/native/[a-pA-P]*/**', 'aten/src/ATen/[a-mA-M]*/**', diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp index cfeb67bef3bd9..d323e54a95abe 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp @@ -2453,7 +2453,7 @@ TORCH_IMPL_FUNC(linalg_qr_out)(const Tensor& A, // geqrf requires m x n workspace input that is modified in-place // We try to use Q. If it doesn't fit, we try to use R - // If m > n and compute_q==false, it won't fit into Q or R, so we neet to create an auxiliary tensor + // If m > n and compute_q==false, it won't fit into Q or R, so we need to create an auxiliary tensor Tensor QR; if (compute_q && Q.size(-1) == n) { QR = Q; @@ -4095,7 +4095,7 @@ Tensor linalg_vander_symint( const auto n = N.value_or(shape.back()); TORCH_CHECK(n > 1, "N must be greater than 1."); - // Append cumprod of the oher 0...n-1 powers + // Append cumprod of the other 0...n-1 powers shape.push_back(n - 1); auto result = at::cumprod(x_.unsqueeze(-1).expand_symint(shape), -1); // The row of ones diff --git a/aten/src/ATen/native/DilatedMaxPool2d.cpp b/aten/src/ATen/native/DilatedMaxPool2d.cpp index 218a673d0a34d..641e9f14dd711 100644 --- a/aten/src/ATen/native/DilatedMaxPool2d.cpp +++ b/aten/src/ATen/native/DilatedMaxPool2d.cpp @@ -54,7 +54,7 @@ bool ceil_mode) { TORCH_CHECK((input.ndimension() == 3 || input.ndimension() == 4), "non-empty 3D or 4D (batch mode) tensor expected for input"); } else { - TORCH_CHECK(false, "Unsupport memory format. Supports only ChannelsLast, Contiguous"); + TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous"); } /* sizes */ @@ -130,7 +130,7 @@ const Tensor& indices) { TORCH_CHECK((input.ndimension() == 3 || input.ndimension() == 4), "non-empty 3D or 4D (batch mode) tensor expected for input"); } else { - TORCH_CHECK(false, "Unsupport memory format. Supports only ChannelsLast, Contiguous"); + TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous"); } /* sizes */ diff --git a/aten/src/ATen/native/DilatedMaxPool3d.cpp b/aten/src/ATen/native/DilatedMaxPool3d.cpp index 458e2c032b094..23d77cb210720 100644 --- a/aten/src/ATen/native/DilatedMaxPool3d.cpp +++ b/aten/src/ATen/native/DilatedMaxPool3d.cpp @@ -63,7 +63,7 @@ void max_pool3d_with_indices_out_cpu_template( TORCH_CHECK((input.ndimension() == 4 || input.ndimension() == 5), "non-empty 4D or 5D (batch mode) tensor expected for input"); } else { - TORCH_CHECK(false, "Unsupport memory format. Supports only ChannelsLast3d, Contiguous"); + TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast3d, Contiguous"); } const int64_t nslices = input.size(-4); @@ -158,7 +158,7 @@ Tensor& max_pool3d_with_indices_backward_out_cpu_template( TORCH_CHECK((input.ndimension() == 4 || input.ndimension() == 5), "non-empty 4D or 5D (batch mode) tensor expected for input"); } else { - TORCH_CHECK(false, "Unsupport memory format. Supports only ChannelsLast3d, Contiguous"); + TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast3d, Contiguous"); } const int64_t nslices = input.size(-4); diff --git a/aten/src/ATen/native/DistributionTemplates.h b/aten/src/ATen/native/DistributionTemplates.h index c6013b6fbae5f..21a15b80c9c84 100644 --- a/aten/src/ATen/native/DistributionTemplates.h +++ b/aten/src/ATen/native/DistributionTemplates.h @@ -28,13 +28,13 @@ namespace at::native::templates { // ==================================================== Random ======================================================== // The purpose of `update_from` and `update_to` is to find the closest valid int64_t number that can be used as actual `from`. -// The current implementation of `random_` uses uint64_t arithmetics and casts the result to the target dtype(scalar_t). +// The current implementation of `random_` uses uint64_t arithmetic and casts the result to the target dtype(scalar_t). // This casting can result in generating numbers that happen to be greater or equal to `to` value. For instance: // // auto actual = torch::empty({3, 3}, torch::half); // actual.random_(0, 65504); // -// If random's uint64_t arithmetics produces 65503 as a random value after casting to torch::half it becomes 65504 +// If random's uint64_t arithmetic produces 65503 as a random value after casting to torch::half it becomes 65504 // and violates the requirement that random value must be less than `to`. To resolve this issue `update_from` and `update_to` // moves `from` to the right and `to` to the left to the next closest value that won't go outside [from, to) after casting to // the target dtype. For `to` = 65504 it moves left for (1 << (log2(to) - 11 + 1)) = 32 and becomes 65472, which is previous diff --git a/aten/src/ATen/native/GridSampler.cpp b/aten/src/ATen/native/GridSampler.cpp index efdc151bf68e2..0ca8ec2a3a887 100644 --- a/aten/src/ATen/native/GridSampler.cpp +++ b/aten/src/ATen/native/GridSampler.cpp @@ -86,7 +86,7 @@ namespace { for (const auto d : c10::irange(out_D)) { for (const auto h : c10::irange(out_H)) { for (const auto w : c10::irange(out_W)) { - // get the corresponding input x, y, z co-ordinates from grid + // get the corresponding input x, y, z coordinates from grid const scalar_t *grid_ptr_NDHW = grid_ptr_N + d * grid_sD + h * grid_sH + w * grid_sW; scalar_t ix = *grid_ptr_NDHW; scalar_t iy = grid_ptr_NDHW[grid_sCoor]; @@ -285,7 +285,7 @@ namespace { for (const auto d : c10::irange(out_D)) { for (const auto h : c10::irange(out_H)) { for (int64_t w = 0; w < out_W; ++w, gGrid_ptr_NDHW += gGrid_sW /* grad_grid is contiguous */ ) { - // get the corresponding input x, y, z co-ordinates from grid + // get the corresponding input x, y, z coordinates from grid const scalar_t *grid_ptr_NDHW = grid_ptr_N + d * grid_sD + h * grid_sH + w * grid_sW; scalar_t ix = *grid_ptr_NDHW; scalar_t iy = grid_ptr_NDHW[grid_sCoor]; @@ -496,7 +496,7 @@ static Tensor _grid_sampler_2d_cpu_quantized( uint8_t* inp_ptr_N = inp_ptr + n * inp_sN; for (const auto h : c10::irange(out_H)) { for (const auto w : c10::irange(out_W)) { - // get the corresponding input x, y, z co-ordinates from grid + // get the corresponding input x, y, z coordinates from grid float* grid_ptr_NHW = grid_ptr_N + h * grid_sH + w * grid_sW; float x = *grid_ptr_NHW; float y = grid_ptr_NHW[grid_sCoor]; @@ -599,7 +599,7 @@ Tensor _grid_sampler_2d_cpu_fallback(const Tensor& input, const Tensor& grid, const scalar_t *inp_ptr_N = inp_ptr + n * inp_sN; for (const auto h : c10::irange(out_H)) { for (const auto w : c10::irange(out_W)) { - // get the corresponding input x, y, z co-ordinates from grid + // get the corresponding input x, y, z coordinates from grid const scalar_t *grid_ptr_NHW = grid_ptr_N + h * grid_sH + w * grid_sW; scalar_t x = *grid_ptr_NHW; scalar_t y = grid_ptr_NHW[grid_sCoor]; @@ -771,7 +771,7 @@ _grid_sampler_2d_cpu_fallback_backward(const Tensor& grad_output, scalar_t *gGrid_ptr_NHW = gGrid_ptr + n * gGrid_sN; for (const auto h : c10::irange(out_H)) { for (int64_t w = 0; w < out_W; ++w, gGrid_ptr_NHW += gGrid_sW /* grad_grid is contiguous */ ) { - // get the corresponding input x, y co-ordinates from grid + // get the corresponding input x, y coordinates from grid const scalar_t *grid_ptr_NHW = grid_ptr_N + h * grid_sH + w * grid_sW; scalar_t x = *grid_ptr_NHW; scalar_t y = grid_ptr_NHW[grid_sCoor]; diff --git a/aten/src/ATen/native/Math.h b/aten/src/ATen/native/Math.h index a372c5f0c7e54..b261da5fe54ee 100644 --- a/aten/src/ATen/native/Math.h +++ b/aten/src/ATen/native/Math.h @@ -1068,7 +1068,7 @@ inline scalar_t calc_igammac(scalar_t a, scalar_t x) { * result at the boundary * - if a is large and a ~ x, then using Uniform Asymptotic Expansions for * Large Parameter (see DLMF 8.12.4 [igam1]) - * - if x > 1.1 and x < a, using the substraction from the regularized lower + * - if x > 1.1 and x < a, using the subtraction from the regularized lower * incomplete gamma * - otherwise, calculate the series from [igam2] eq (5) */ @@ -1148,7 +1148,7 @@ scalar_t calc_igamma(scalar_t a, scalar_t x) { * result at the boundary * - if a is large and a ~ x, then using Uniform Asymptotic Expansions for * Large Parameter (see DLMF 8.12.3 [igam1]) - * - if x > 1 and x > a, using the substraction from the regularized upper + * - if x > 1 and x > a, using the subtraction from the regularized upper * incomplete gamma * - otherwise, calculate the series from [igam2] eq (4) */ @@ -1730,7 +1730,7 @@ inline C10_HOST_DEVICE T calc_ndtri(T y0) { with the usual checks for overflow etcetera. Performance-wise, it seems to be substantially faster than either - the SLATEC DERFC function [or an erfcx function derived therefrom] + the SLATEC DERFC function [or an erfcx function derived there from] or Cody's CALERF function (from netlib.org/specfun), while retaining near machine precision in accuracy. */ diff --git a/aten/src/ATen/native/Pool.h b/aten/src/ATen/native/Pool.h index 51d19102ad934..7f335de04b90a 100644 --- a/aten/src/ATen/native/Pool.h +++ b/aten/src/ATen/native/Pool.h @@ -17,7 +17,7 @@ using max_pool2d_backward_fn = void(*)(const Tensor& grad_input, const Tensor& g DECLARE_DISPATCH(max_pool2d_fn, max_pool2d_kernel) DECLARE_DISPATCH(max_pool2d_backward_fn, max_pool2d_backward_kernel) -// averge pooling has same signature for forward and backward +// average pooling has same signature for forward and backward using avg_pool2d_fn = void(*)(const Tensor& output, const Tensor& input, int64_t kW, int64_t kH, int64_t dW, int64_t dH, int64_t padW, int64_t padH, bool count_include_pad, std::optional divisor_override); using avg_pool2d_backward_fn = void(*)(const Tensor& output, const Tensor& input, int kW, int kH, @@ -26,7 +26,7 @@ using avg_pool2d_backward_fn = void(*)(const Tensor& output, const Tensor& input DECLARE_DISPATCH(avg_pool2d_fn, avg_pool2d_kernel) DECLARE_DISPATCH(avg_pool2d_backward_fn, avg_pool2d_backward_kernel) -// averge pooling has same signature for forward and backward +// average pooling has same signature for forward and backward using avg_pool3d_fn = void(*)(const Tensor& output, const Tensor& input, int64_t kW, int64_t kH, int64_t kD, int64_t dW, int64_t dH, int64_t dD, int64_t padW, int64_t padH, int64_t padD, bool count_include_pad, diff --git a/aten/src/ATen/native/SegmentReduce.cpp b/aten/src/ATen/native/SegmentReduce.cpp index 2e9df75307583..2b61bcec6a828 100644 --- a/aten/src/ATen/native/SegmentReduce.cpp +++ b/aten/src/ATen/native/SegmentReduce.cpp @@ -480,7 +480,7 @@ REGISTER_ZVECTOR_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets REGISTER_SVE256_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel) // Currently some computation is being duplicated across forward and backward. -// TODO: Cache indices in forward pass to re-use in backward +// TODO: Cache indices in forward pass to reuse in backward Tensor _segment_reduce_backward_kernel( const Tensor& grad, const Tensor& output, diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.cpp b/aten/src/ATen/native/TensorAdvancedIndexing.cpp index 67c0af9212bc7..408faea1b7644 100644 --- a/aten/src/ATen/native/TensorAdvancedIndexing.cpp +++ b/aten/src/ATen/native/TensorAdvancedIndexing.cpp @@ -475,7 +475,7 @@ static void build_index_op( TensorIteratorBase& iter, const at::native::AdvancedIndex& info, const Tensor& result) { - // 'TensorIterator' needs to own the things comming from 'info', since + // 'TensorIterator' needs to own the things coming from 'info', since // 'info' will be destroyed after the META function. TensorIteratorConfig config; // info.src is a restrided view of result diff --git a/aten/src/ATen/native/TensorConversions.cpp b/aten/src/ATen/native/TensorConversions.cpp index 0fba01ee6e4ed..7df7745fc5077 100644 --- a/aten/src/ATen/native/TensorConversions.cpp +++ b/aten/src/ATen/native/TensorConversions.cpp @@ -67,7 +67,7 @@ namespace at::native { namespace { // dense_to_sparse_{csr,bsr,csc,bsc} common helpers -// Preparation fo the N-D dense -> sparse compressed conversion. +// Preparation for the N-D dense -> sparse compressed conversion. // The N-D input is converted to 3-D (single batch dim) where we check that the // product of batch dims is nonzero and for each batch the sparse matrix // contained within has the same number of non-zero elements. diff --git a/aten/src/ATen/native/TensorFactories.cpp b/aten/src/ATen/native/TensorFactories.cpp index 1aab4b11c9634..054cc66cf8eb3 100644 --- a/aten/src/ATen/native/TensorFactories.cpp +++ b/aten/src/ATen/native/TensorFactories.cpp @@ -1367,9 +1367,9 @@ void randperm_cpu(Tensor& result, int64_t n, CPUGeneratorImpl* generator) { for (int64_t i = 0; i < n - 1; i++) { // NOLINTNEXTLINE(clang-analyzer-security.insecureAPI.rand) int64_t z = generator->random() % (n - i); - scalar_t sav = r__data[i * r__stride_0]; + scalar_t save = r__data[i * r__stride_0]; r__data[i * r__stride_0] = r__data[(z + i) * r__stride_0]; - r__data[(z + i) * r__stride_0] = sav; + r__data[(z + i) * r__stride_0] = save; } return; } diff --git a/aten/src/ATen/native/TensorIteratorReduce.cpp b/aten/src/ATen/native/TensorIteratorReduce.cpp index fbd9ff6b2dd7a..ce2987eb251ae 100644 --- a/aten/src/ATen/native/TensorIteratorReduce.cpp +++ b/aten/src/ATen/native/TensorIteratorReduce.cpp @@ -80,7 +80,7 @@ static void two_pass_reduction(TensorIteratorBase& iter, loop2d_t loop) { } /// Chooses a dimension over which to parallelize. Prefers the outer-most -/// dimension thats larger than the number of available threads. +/// dimension that's larger than the number of available threads. static int find_split_dim(TensorIteratorBase& iter) { int num_threads = at::get_num_threads(); auto shape = iter.shape(); diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index 79b253b16a3fe..340ee49bffa8f 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -384,7 +384,7 @@ Tensor& set_storage_cpu_( result.unsafeGetTensorImpl()->set_storage_offset(storage_offset); at::OptionalIntArrayRef stride_opt = stride.data() != nullptr ? at::OptionalIntArrayRef(stride) : std::nullopt; - // We can re-use this kernel for the meta device. + // We can reuse this kernel for the meta device. // We just need to make sure we don't actually try to resize the (null) // storage. at::native::resize_impl_cpu_( @@ -505,7 +505,7 @@ Tensor& set_cpu_(Tensor& result) { return result; } -// We can't re-use the cpu kernel here because we don't want to use the cpu +// We can't reuse the cpu kernel here because we don't want to use the cpu // allocator. Tensor& set_meta_(Tensor& result) { caffe2::TypeMeta dtype = result.dtype(); @@ -1904,7 +1904,7 @@ Tensor repeat(const Tensor& self, IntArrayRef repeats) { } Tensor tile_symint(const Tensor& self, SymIntArrayRef reps) { - // If self.size() > len(reps), reps is promoted to self.size() by pre-pending + // If self.size() > len(reps), reps is promoted to self.size() by prepending // 1’s to it to keep the same behaviour as `numpy.tile`. // Thus for a tensor of shape (2, 3, 4, 5), a dims of (2, 2) is treated // as (1, 1, 2, 2). @@ -2428,7 +2428,7 @@ Tensor index_select_sparse_cpu( const auto dim_indices = indices[dim].contiguous(); // If nnz is smaller than size, then either indices[dim] or index gets - // sorted, then this is followed by a binary search to find interesections. + // sorted, then this is followed by a binary search to find intersections. const auto get_selected_indices_small_nnz_large_size = [&]() -> std::tuple { const auto grain_size = at::internal::GRAIN_SIZE; @@ -3934,7 +3934,7 @@ Tensor squeeze_qtensor(const Tensor& self, c10::OptionalIntArrayRef dims) { quantizer->scalar_type()); } // TODO: quantized Tensor support for SymInt needs to be added but basic - // building blocs are missing for now. + // building blocks are missing for now. auto result = make_qtensor( self, C10_AS_INTARRAYREF_SLOW(sizes), diff --git a/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionCommon.h b/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionCommon.h index d7da40750ba11..805035cdd6263 100644 --- a/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionCommon.h +++ b/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionCommon.h @@ -51,10 +51,10 @@ ForwardIt find_bound(ForwardIt first, ForwardIt last, const T& value) { // Similarly, an upper bound is a value at *it with the smallest index // such that *it > value if such value exists, or last if does not. // Let is_lower = true and *it < value, then we know that *it and values - // preceeding *it cannot contain a lower bound, so we adjust initial iterator range + // preceding *it cannot contain a lower bound, so we adjust initial iterator range // from [first, first + count] to [first + step + 1, first + count - (step + 1)], // where +1 skips the element at which we have just evaluated *it < value. - // Samilar logic holds when is_lower = false. + // Similar logic holds when is_lower = false. if (is_lower ? *it < value : value >= *it) { first = ++it; count -= step + 1; diff --git a/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionKernel.cpp b/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionKernel.cpp index 20a44c8709399..cf854a84e7dad 100644 --- a/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionKernel.cpp +++ b/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionKernel.cpp @@ -79,7 +79,7 @@ struct CPUValueSelectionIntersectionKernel { const auto* ptr_argsort = argsort.const_data_ptr(); for (int64_t i = 0; i < n; ++i) { - // Exctract data + // Extract data auto* ptr_res_values = reinterpret_cast(ptr_res_values_bytes); const auto* ptr_lhs_values = reinterpret_cast(ptr_lhs_values_bytes); const auto lhs_nnz_idx = *reinterpret_cast(ptr_lhs_select_idx_bytes); diff --git a/aten/src/ATen/native/sparse/SparseTensor.cpp b/aten/src/ATen/native/sparse/SparseTensor.cpp index b63d8ae80e50b..752365d545dee 100644 --- a/aten/src/ATen/native/sparse/SparseTensor.cpp +++ b/aten/src/ATen/native/sparse/SparseTensor.cpp @@ -730,7 +730,7 @@ static std::tuple sparse_mask_like_prepare_sparse_inp // is that these primitives might project first argument onto second one or // the other way around depending on which arguments are coalesced and which are // larger. This function prepares inputs for `sparse_mask` such that `t` is - // projected onto `mask` by sorting `t` if uncoalesced and artifically marking it + // projected onto `mask` by sorting `t` if uncoalesced and artificially marking it // as coalesced all while `mask` is set to uncoalesced. // The result of this projectionk is going to be uncoalesced, so it is up to the // user to set the corresponding flag correctly with respect to the operations' diff --git a/aten/src/ATen/native/sparse/ValidateCompressedIndicesCommon.h b/aten/src/ATen/native/sparse/ValidateCompressedIndicesCommon.h index ec4c084a39cc1..267c19561a29d 100644 --- a/aten/src/ATen/native/sparse/ValidateCompressedIndicesCommon.h +++ b/aten/src/ATen/native/sparse/ValidateCompressedIndicesCommon.h @@ -242,7 +242,7 @@ void _validate_compressed_sparse_indices_kernel( // Catch integer overflow from large dimensions. Otherwise, the // invariant checks may fail with bogus exceptions or succeed with // false-positive results when int64_t typed dimensions are cast to - // index dtype that corresponds to smaller interger type such as + // index dtype that corresponds to smaller integer type such as // int32_t. { AT_DISPATCH_INDEX_TYPES(idx.scalar_type(), NAME, [cdim, dim, nnz]() { diff --git a/aten/src/ATen/native/sparse/cuda/ComputeSparseTile.h b/aten/src/ATen/native/sparse/cuda/ComputeSparseTile.h index f902f1e61c5e4..530804099b6fd 100644 --- a/aten/src/ATen/native/sparse/cuda/ComputeSparseTile.h +++ b/aten/src/ATen/native/sparse/cuda/ComputeSparseTile.h @@ -112,7 +112,7 @@ struct LargestValuesGreedy { } }; -// We consider each rows independantly in order +// We consider each rows independently in order // This is to ensure that a row's sparsity pattern is only determined // by its values and the rows before (but never the rows after) // This enforces causality strictly diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp b/aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp index 582778fdc299d..c656dc71a660d 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp +++ b/aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp @@ -65,7 +65,7 @@ void _csrmm2( csrvala, /* values of the sparse matrix, size = nnz */ CUSPARSE_INDEX_32I, /* data type of row offsets index */ CUSPARSE_INDEX_32I, /* data type of col indices */ - CUSPARSE_INDEX_BASE_ZERO, /* base index of row offset and col indes */ + CUSPARSE_INDEX_BASE_ZERO, /* base index of row offset and col index */ cusparse_value_type /* data type of values */ )); diff --git a/aten/src/ATen/native/sparse/cuda/SparseMatMul.cu b/aten/src/ATen/native/sparse/cuda/SparseMatMul.cu index 867f103ba518c..c6e3197a22a8b 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseMatMul.cu +++ b/aten/src/ATen/native/sparse/cuda/SparseMatMul.cu @@ -93,7 +93,7 @@ void create_general_description_(cusparseMatDescr_t& description_) { } // csrMatrixRef is used to have a representation of a raw CSR matrix representation -// comming from `sparse_sparse_matmul_cuda_kernel` function. +// coming from `sparse_sparse_matmul_cuda_kernel` function. // Moreover this implements a RAII guard for a cusparse descriptor template struct csrMatrixRef { diff --git a/aten/src/ATen/native/transformers/attention.cpp b/aten/src/ATen/native/transformers/attention.cpp index 8647a199ad8e9..7aad4309924d4 100644 --- a/aten/src/ATen/native/transformers/attention.cpp +++ b/aten/src/ATen/native/transformers/attention.cpp @@ -207,7 +207,7 @@ Tensor qkv_projection( } else { // encoder-decoder attention // TODO: is there a more efficient way to set this up? - // TODO: can we stay nested insted of using cat? Probably just make a + // TODO: can we stay nested instead of using cat? Probably just make a // NestedTensor out of the matmul results or something? auto q_kv_weight_s = at::native::split_with_sizes(qkv_weight, {embed_dim, embed_dim * 2}, 0); diff --git a/aten/src/ATen/native/transformers/cuda/attention_backward.cu b/aten/src/ATen/native/transformers/cuda/attention_backward.cu index b702d2f8a8e70..50fa4ed523a89 100644 --- a/aten/src/ATen/native/transformers/cuda/attention_backward.cu +++ b/aten/src/ATen/native/transformers/cuda/attention_backward.cu @@ -98,14 +98,14 @@ std::tuple _flash_attention_backward( std::optional dk{std::nullopt}; std::optional dv{std::nullopt}; - // The kernel computes irregardless we will drop for this functions return + // The kernel computes regardless we will drop for this functions return Tensor grad_softmax; // Currently unused args: std::optional alibi_slopes{std::nullopt}; const float softcap = 0.0; - bool determinisitic{false}; + bool deterministic{false}; auto& ctx = at::globalContext(); if (ctx.deterministicAlgorithms()) { if (ctx.deterministicAlgorithmsWarnOnly()) { @@ -113,7 +113,7 @@ std::tuple _flash_attention_backward( "Flash Attention defaults to a non-deterministic algorithm. ", "To explicitly enable determinism call torch.use_deterministic_algorithms(True, warn_only=False)."); } else { - determinisitic = true; + deterministic = true; } } @@ -148,7 +148,7 @@ std::tuple _flash_attention_backward( non_null_window_right, #endif softcap, - determinisitic, + deterministic, philox_seed, philox_offset); return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue)); @@ -176,7 +176,7 @@ std::tuple _flash_attention_backward( non_null_window_right, #endif softcap, - determinisitic, + deterministic, philox_seed, philox_offset); return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue)); diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp index 9eed9b69d8bdf..854c33dec7342 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp @@ -1328,7 +1328,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he const int seqlen_k = !paged_KV ? kcache.size(1) : max_num_blocks_per_seq * page_block_size; const int num_heads_k = kcache.size(2); const int batch_size_c = !paged_KV ? kcache.size(0) : batch_size; - TORCH_CHECK(batch_size > 0, "batch size must be postive"); + TORCH_CHECK(batch_size > 0, "batch size must be positive"); TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_rescale_output.h b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_rescale_output.h index fd7982e5f699e..7115cb07a793e 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_rescale_output.h +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_rescale_output.h @@ -125,7 +125,7 @@ class MemoryEfficientAttentionNormalize { FragmentSource const& source) const { assert(!isFirst); - // Convert source to interal compute numeric type + // Convert source to internal compute numeric type NumericArrayConverter source_converter; NumericArrayConverter @@ -164,7 +164,7 @@ class MemoryEfficientAttentionNormalize { const { assert(isFirst); - // Convert source to interal compute numeric type + // Convert source to internal compute numeric type NumericArrayConverter accumulator_converter; diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma_base.h b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma_base.h index 229c59d68347a..3c3566512b45c 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma_base.h +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma_base.h @@ -88,7 +88,7 @@ class CustomMmaBase { Shape::kN / WarpGemm::kN, Shape::kK / WarpGemm::kK>; - /// Number of warp-level GEMM oeprations + /// Number of warp-level GEMM operations static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK); diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/iterators/epilogue_predicated_tile_iterator.h b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/iterators/epilogue_predicated_tile_iterator.h index c7a9915fed6d8..e75a1b9001e02 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/iterators/epilogue_predicated_tile_iterator.h +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/iterators/epilogue_predicated_tile_iterator.h @@ -68,7 +68,7 @@ namespace threadblock { /// ForwardTileIterator /// template < - typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) + typename ThreadMap_, ///< Thread map (concept: OutputTileThreadMap) typename Element_, ///< Element data type bool ScatterD = false, ///< Scatter D operand or not bool UseCUDAStore = false> diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h index ae649e99c4cd8..20495a05474b0 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h @@ -245,7 +245,7 @@ struct AttentionBackwardKernel { static constexpr int64_t kWarpSize = 32; // If this is true, we store and accumulate dK/dV in RF - // rather than going back to gmem everytime + // rather than going back to gmem every time static constexpr bool kIsHalf = cutlass::sizeof_bits::value <= 16; static constexpr bool kOutputInRF = kIsHalf && kMaxK <= kBlockSizeI; static_assert( diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp index d1101b6597a5b..4b85b2d28753a 100644 --- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp +++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp @@ -395,7 +395,7 @@ bool check_flash_causal_non_square_seqlens(sdp_params const& params, bool debug) bool check_all_tensors_on_device(sdp_params const& params, bool debug) { // Check that all tensors are on the GPU device - // This should be handled by the stub dispatch, but whe call can_use_*_attention + // This should be handled by the stub dispatch, but we call can_use_*_attention // directly from python we need to ensure that the tensors are on cuda if (params.query.device().type() != at::DeviceType::CUDA) { if (debug) { diff --git a/aten/src/ATen/native/transformers/hip/aotriton_adapter.h b/aten/src/ATen/native/transformers/hip/aotriton_adapter.h index b38122248db80..aedb205e57101 100644 --- a/aten/src/ATen/native/transformers/hip/aotriton_adapter.h +++ b/aten/src/ATen/native/transformers/hip/aotriton_adapter.h @@ -82,7 +82,7 @@ aotriton::TensorView mk_aotensor(const at::Tensor& q, std::string_view ten { const auto strides = q.strides(); int real_rank = strides.size(); - if (real_rank != Rank) { // Lazy convertion of tensor_name + if (real_rank != Rank) { // Lazy conversion of tensor_name TORCH_CHECK(false, std::string(tensor_name) + "'s rank should be " + std::to_string(Rank) + " but is " + std::to_string(real_rank)); diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip b/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip index 1908096e2f6fa..05523f75caa42 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip +++ b/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip @@ -401,7 +401,7 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot CHECK_SHAPE(cu_seqlens_k, batch_size + 1); // AOTriton's varlen API needs input shapes be - // (1, num_heads, total sequence lenght, head dimension) + // (1, num_heads, total sequence length, head dimension) at::Tensor q_padded, k_padded, v_padded; at::Tensor out, out_padded; q_padded = q.unsqueeze(0).transpose(1, 2); diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_varlen_fwd_ck.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_varlen_fwd_ck.hip index 20ad315d3025b..ece6f29877abe 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_varlen_fwd_ck.hip +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_varlen_fwd_ck.hip @@ -209,7 +209,7 @@ mha_varlen_fwd_ck(const at::Tensor &q, // total_q x num_heads const int total_q = q.size(0); const int total_k = k.size(0); - TORCH_CHECK(batch_size > 0, "batch size must be postive"); + TORCH_CHECK(batch_size > 0, "batch size must be positive"); TORCH_CHECK(head_size_og <= 256, "CK only supports head dimension at most 256"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); diff --git a/aten/src/ATen/native/utils/ParamsHash.h b/aten/src/ATen/native/utils/ParamsHash.h index 6b7894cb8549f..4c9d97328ad61 100644 --- a/aten/src/ATen/native/utils/ParamsHash.h +++ b/aten/src/ATen/native/utils/ParamsHash.h @@ -41,7 +41,7 @@ struct ParamsEqual { }; // Provide explicit byte-for-byte constructors to avoid uwittingly leaving -// padding bytes unitialized (e.g., when passing Params by value) +// padding bytes uninitialized (e.g., when passing Params by value) template struct ParamsWrapper { T pod; diff --git a/aten/src/ATen/native/vulkan/api/Types.h b/aten/src/ATen/native/vulkan/api/Types.h index 548703aa8a956..1202a3bd73938 100644 --- a/aten/src/ATen/native/vulkan/api/Types.h +++ b/aten/src/ATen/native/vulkan/api/Types.h @@ -71,7 +71,7 @@ inline VkFormat to_vkformat(const ScalarType t) { /* * Given a `VkFormat`, return the `ScalarType` that best represents the data - * type of invidivual elements in an image texture of the `VkFormat`. Note that + * type of individual elements in an image texture of the `VkFormat`. Note that * this mapping is different from the `to_vkformat()` function, since different * `ScalarType`s may use the same `VkFormat`. */ diff --git a/aten/src/ATen/native/vulkan/glsl/conv2d.glsl b/aten/src/ATen/native/vulkan/glsl/conv2d.glsl index d5fe3c232e440..47a2630aaafbe 100644 --- a/aten/src/ATen/native/vulkan/glsl/conv2d.glsl +++ b/aten/src/ATen/native/vulkan/glsl/conv2d.glsl @@ -75,7 +75,7 @@ void main() { // During prepacking, the weight tensor was rearranged in order to optimize // for data access linearity in this shader. Therefore we need to adjust the // canonical coordinates to the corresponding index in the rearranged weight - // tensor. the x coordinate is multipled by 4 since each group of 4 channels + // tensor. the x coordinate is multiplied by 4 since each group of 4 channels // is folded into the X axis. The y coordinate is offset based on the z // coordinate because the 2D planes were stacked atop each other vertically. kstart.x *= 4; diff --git a/aten/src/ATen/native/vulkan/glsl/conv2d_pw.glsl b/aten/src/ATen/native/vulkan/glsl/conv2d_pw.glsl index c4728a9bb94e9..d4188d6580599 100644 --- a/aten/src/ATen/native/vulkan/glsl/conv2d_pw.glsl +++ b/aten/src/ATen/native/vulkan/glsl/conv2d_pw.glsl @@ -39,7 +39,7 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; /* * Computes a 2D pointwise convolution of a 2x2 output tile. Calculating an * output tile for pointwise convolution is more efficient because the kernel - * size is only 1x1, making it much easier to re-use loaded texels from uKernel. + * size is only 1x1, making it much easier to reuse loaded texels from uKernel. */ void main() { const ivec3 gpos = ivec3(gl_GlobalInvocationID); diff --git a/aten/src/ATen/native/vulkan/glsl/image_to_nchw_uint.glsl b/aten/src/ATen/native/vulkan/glsl/image_to_nchw_uint.glsl index 8080db6120aa0..1f66a5fe19151 100644 --- a/aten/src/ATen/native/vulkan/glsl/image_to_nchw_uint.glsl +++ b/aten/src/ATen/native/vulkan/glsl/image_to_nchw_uint.glsl @@ -57,7 +57,7 @@ void main() { // out CxHxW plane. ivec4 c_index = pos_in_batch / uBlock.in_extents.w; - // we devide pos_in_batch by HxW, to compute the channel index + // we divide pos_in_batch by HxW, to compute the channel index ivec4 pos_in_hw = pos_in_batch % uBlock.in_extents.w; // we compute the reminder mod HxW, to find the positions in the flatten diff --git a/aten/src/ATen/native/vulkan/glsl/indexing.h b/aten/src/ATen/native/vulkan/glsl/indexing.h index 2bda5a2362405..c34ce25001ef5 100644 --- a/aten/src/ATen/native/vulkan/glsl/indexing.h +++ b/aten/src/ATen/native/vulkan/glsl/indexing.h @@ -1,12 +1,12 @@ /* - * Computes a 4D tensor co-ordinate from a linearized index + * Computes a 4D tensor coordinate from a linearized index */ uvec4 idx_to_coord(const uint idx, const uvec4 strides, const uvec4 sizes) { return ivec4(mod(idx / strides, sizes)); } /* - * Computes a linearized index from a 4D tensor co-ordinate + * Computes a linearized index from a 4D tensor coordinate */ uint coord_to_idx(const uvec4 coord, const uvec4 strides) { return int(dot(coord * strides, ivec4(1))); diff --git a/aten/src/ATen/native/vulkan/glsl/quantized_conv2d.glsl b/aten/src/ATen/native/vulkan/glsl/quantized_conv2d.glsl index 0b4ee355a0642..bc13655d01e07 100644 --- a/aten/src/ATen/native/vulkan/glsl/quantized_conv2d.glsl +++ b/aten/src/ATen/native/vulkan/glsl/quantized_conv2d.glsl @@ -96,7 +96,7 @@ void main() { // During prepacking, the weight tensor was rearranged in order to optimize // for data access linearity in this shader. Therefore we need to adjust the // canonical coordinates to the corresponding index in the rearranged weight - // tensor. the x coordinate is multipled by 4 since each group of 4 channels + // tensor. the x coordinate is multiplied by 4 since each group of 4 channels // is folded into the X axis. The y coordinate is offset based on the z // coordinate because the 2D planes were stacked atop each other vertically. kstart.x *= 4; diff --git a/aten/src/ATen/native/vulkan/ops/Tile.cpp b/aten/src/ATen/native/vulkan/ops/Tile.cpp index 2ea62e909119c..d39fd951106c6 100644 --- a/aten/src/ATen/native/vulkan/ops/Tile.cpp +++ b/aten/src/ATen/native/vulkan/ops/Tile.cpp @@ -18,7 +18,7 @@ namespace { using namespace api::utils; Tensor tile(const Tensor& self, const IntArrayRef repeats) { - // If self.size() > len(reps), reps is promoted to self.size() by pre-pending + // If self.size() > len(reps), reps is promoted to self.size() by prepending // 1’s to it to keep the same behaviour as `numpy.tile`. // Thus for a tensor of shape (2, 3, 4, 5), a dims of (2, 2) is treated // as (1, 1, 2, 2). diff --git a/tools/linter/dictionary.txt b/tools/linter/dictionary.txt index 64fa3f14f406a..61eaeaf8600d7 100644 --- a/tools/linter/dictionary.txt +++ b/tools/linter/dictionary.txt @@ -21,11 +21,14 @@ fro froms Halfs hsa +indexT inp inps inpt inpts matA +matB +matC nd nin NotIn @@ -38,6 +41,7 @@ ot overrideable oW padD +posIn ptd rebuild rebuilt @@ -53,4 +57,5 @@ strat supercede supercedes te +tne WONT From d5af0eca8def9a4ae1af69638de3983f3bec778c Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Thu, 17 Jul 2025 14:55:18 +0800 Subject: [PATCH 181/457] [BE][3/5] fix typos in aten/ (aten/src/ATen/native/) (#157552) Pull Request resolved: https://github.com/pytorch/pytorch/pull/157552 Approved by: https://github.com/albanD ghstack dependencies: #156605, #157637, #157550, #157551 --- .lintrunner.toml | 1 - .../ATen/native/quantized/cpu/ACLUtils.cpp | 2 +- .../ATen/native/quantized/cpu/QnnpackUtils.h | 2 +- .../native/quantized/cpu/fbgemm_utils.cpp | 2 +- aten/src/ATen/native/quantized/cpu/qconv.cpp | 2 +- .../qnnpack/scripts/build-android-arm64.sh | 2 +- .../qnnpack/scripts/build-android-armv7.sh | 2 +- .../cpu/qnnpack/scripts/build-android-x86.sh | 2 +- .../cpu/qnnpack/scripts/build-ios-arm64.sh | 2 +- .../cpu/qnnpack/scripts/build-ios-arm64e.sh | 2 +- .../cpu/qnnpack/scripts/build-ios-armv7.sh | 2 +- .../cpu/qnnpack/scripts/build-ios-armv7s.sh | 2 +- .../cpu/qnnpack/scripts/build-ios-i386.sh | 2 +- .../cpu/qnnpack/scripts/build-ios-x86_64.sh | 2 +- .../cpu/qnnpack/scripts/build-local.sh | 2 +- .../quantized/cpu/qnnpack/src/convolution.c | 2 +- .../cpu/qnnpack/src/q8conv/4x8-aarch32-neon.S | 36 +++++++++--------- .../cpu/qnnpack/src/q8conv/8x8-aarch64-neon.S | 8 ++-- .../cpu/qnnpack/src/q8gemm/4x8-aarch32-neon.S | 36 +++++++++--------- .../qnnpack/src/q8gemm/4x8-dq-aarch32-neon.S | 38 +++++++++---------- .../cpu/qnnpack/src/q8gemm/8x8-aarch64-neon.S | 8 ++-- .../qnnpack/src/q8gemm/8x8-dq-aarch64-neon.S | 10 ++--- .../q8gemm_sparse/4x4-packA-aarch32-neon.S | 12 +++--- .../4x8c1x4-dq-packedA-aarch32-neon.S | 8 ++-- .../4x8c8x1-dq-packedA-aarch32-neon.S | 4 +- .../q8gemm_sparse/8x4-packA-aarch32-neon.S | 12 +++--- .../q8gemm_sparse/8x4-packA-aarch64-neon.S | 8 ++-- .../src/q8gemm_sparse/8x4-packA-sse2.c | 6 +-- .../q8gemm_sparse/8x4c1x4-dq-packedA-sse2.h | 4 +- .../8x8c1x4-dq-packedA-aarch64-neon.S | 16 ++++---- .../8x8c8x1-dq-packedA-aarch64-neon.S | 10 ++--- .../quantized/cpu/qnnpack/src/qnnpack/pack.h | 4 +- .../qnnpack/src/requantization/q31-scalar.c | 2 +- .../cpu/qnnpack/test/requantization.cc | 8 ++-- .../ATen/native/quantized/cudnn/Linear.cpp | 2 +- .../ATen/native/quantized/cudnn/Pooling.cpp | 2 +- tools/linter/dictionary.txt | 2 + 37 files changed, 134 insertions(+), 133 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index 707d262354313..04664378d8bf8 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1162,7 +1162,6 @@ exclude_patterns = [ # These files are all grandfathered in, feel free to remove from this list # as necessary # NOTE: remove the patterns in the order they are listed - 'aten/src/ATen/native/q*/**', 'aten/src/ATen/native/[a-pA-P]*/**', 'aten/src/ATen/[a-mA-M]*/**', 'test/**', diff --git a/aten/src/ATen/native/quantized/cpu/ACLUtils.cpp b/aten/src/ATen/native/quantized/cpu/ACLUtils.cpp index 7108ecd64cac7..c689132c7692e 100644 --- a/aten/src/ATen/native/quantized/cpu/ACLUtils.cpp +++ b/aten/src/ATen/native/quantized/cpu/ACLUtils.cpp @@ -81,7 +81,7 @@ DynamicQuantMatmul::DynamicQuantMatmul( auto src_q_tensor_info = arm_compute::TensorInfo( arm_compute::TensorShape(weight_dim_0, m), 1, - // ACL dyanamically quantized matmuls only support (signed) int8_t + // ACL dynamically quantized matmuls only support (signed) int8_t arm_compute::DataType::QASYMM8_SIGNED, // TODO: setting the initial offset value to int8_t max instead of zero, // because ACL currently skips MatrixBReduction calculation if the diff --git a/aten/src/ATen/native/quantized/cpu/QnnpackUtils.h b/aten/src/ATen/native/quantized/cpu/QnnpackUtils.h index 36f6140953f6a..764d237e68b4c 100644 --- a/aten/src/ATen/native/quantized/cpu/QnnpackUtils.h +++ b/aten/src/ATen/native/quantized/cpu/QnnpackUtils.h @@ -456,7 +456,7 @@ make_zero_points_and_scales_tensor( uint32_t groups = 1) { const int out_ch_idx = transpose ? 1 : 0; const auto num_output_channels = weight_contig.size(out_ch_idx) * (transpose ? groups : 1); - // Add 8 to account for bufferring needed by QNNPACK. + // Add 8 to account for buffering needed by QNNPACK. const auto num_output_channels_padded = num_output_channels + kPaddingChannels; const auto qtype = weight_contig.qscheme(); std::vector weight_zp(num_output_channels_padded, 0); diff --git a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp index f8651377ddf93..4cf3dfe2dbaa7 100644 --- a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp +++ b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp @@ -366,7 +366,7 @@ Tensor ConvertConvWeightsToChannelLastTensor<3>( #endif // USE_FBGEMM namespace { - // This is really terrible, but couldnt figure out a better way to constexpr convert int to + // This is really terrible, but couldn't figure out a better way to constexpr convert int to // string and then perform string concatenation on/with it constexpr const char* _hack_int_to_class_name(int x) { switch(x) { diff --git a/aten/src/ATen/native/quantized/cpu/qconv.cpp b/aten/src/ATen/native/quantized/cpu/qconv.cpp index cd4f253d09933..8624c9ef03367 100644 --- a/aten/src/ATen/native/quantized/cpu/qconv.cpp +++ b/aten/src/ATen/native/quantized/cpu/qconv.cpp @@ -1277,7 +1277,7 @@ at::Tensor PackedConvWeightsOnednn::apply_impl( float sum_scale = has_accum ? accum.value().q_scale() : 1.0; int32_t sum_zero_point = has_accum ? accum.value().q_zero_point() : 0; if (has_accum) { - // Just tells we have these post op, the actual value such as scale and zero point will be setted later. + // Just tells we have these post op, the actual value such as scale and zero point will be set later. op_attr = kReluFused ? ideep::attr_t::residual_with_sum_zero_point() : ideep::attr_t::fuse_sum(); const ideep::scale_t accum_scale = ideep::scale_t(1, 1.0/sum_scale); const ideep::zero_point_t accum_zero_points = ideep::zero_point_t(1, sum_zero_point); diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-android-arm64.sh b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-android-arm64.sh index 389430b043fe6..5c52f1a020f1e 100755 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-android-arm64.sh +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-android-arm64.sh @@ -53,7 +53,7 @@ CMAKE_ARGS+=("-DANDROID_PIE=ON") CMAKE_ARGS+=("-DANDROID_STL=c++_static") CMAKE_ARGS+=("-DANDROID_CPP_FEATURES=exceptions") -# Use-specified CMake arguments go last to allow overridding defaults +# Use-specified CMake arguments go last to allow overriding defaults CMAKE_ARGS+=($@) cd build/android/arm64-v8a && cmake ../../.. \ diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-android-armv7.sh b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-android-armv7.sh index 6f32950125e0b..81da44097801f 100755 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-android-armv7.sh +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-android-armv7.sh @@ -53,7 +53,7 @@ CMAKE_ARGS+=("-DANDROID_PIE=ON") CMAKE_ARGS+=("-DANDROID_STL=c++_static") CMAKE_ARGS+=("-DANDROID_CPP_FEATURES=exceptions") -# Use-specified CMake arguments go last to allow overridding defaults +# Use-specified CMake arguments go last to allow overriding defaults CMAKE_ARGS+=($@) cd build/android/armeabi-v7a && cmake ../../.. \ diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-android-x86.sh b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-android-x86.sh index 5f19db582fb09..747704f1edfea 100755 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-android-x86.sh +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-android-x86.sh @@ -53,7 +53,7 @@ CMAKE_ARGS+=("-DANDROID_PIE=ON") CMAKE_ARGS+=("-DANDROID_STL=c++_static") CMAKE_ARGS+=("-DANDROID_CPP_FEATURES=exceptions") -# Use-specified CMake arguments go last to allow overridding defaults +# Use-specified CMake arguments go last to allow overriding defaults CMAKE_ARGS+=($@) cd build/android/x86 && cmake ../../.. \ diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-arm64.sh b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-arm64.sh index d155d6f7507df..8e867f18d3f91 100755 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-arm64.sh +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-arm64.sh @@ -40,7 +40,7 @@ CMAKE_ARGS+=("-DIOS_ARCH=arm64") CMAKE_ARGS+=("-DENABLE_BITCODE=OFF") CMAKE_ARGS+=("-DENABLE_ARC=OFF") -# Use-specified CMake arguments go last to allow overridding defaults +# Use-specified CMake arguments go last to allow overriding defaults CMAKE_ARGS+=($@) cd build/ios/arm64 && cmake ../../.. \ diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-arm64e.sh b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-arm64e.sh index 985315f74a667..34a95d1944148 100755 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-arm64e.sh +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-arm64e.sh @@ -40,7 +40,7 @@ CMAKE_ARGS+=("-DIOS_ARCH=arm64e") CMAKE_ARGS+=("-DENABLE_BITCODE=OFF") CMAKE_ARGS+=("-DENABLE_ARC=OFF") -# Use-specified CMake arguments go last to allow overridding defaults +# Use-specified CMake arguments go last to allow overriding defaults CMAKE_ARGS+=($@) cd build/ios/arm64e && cmake ../../.. \ diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-armv7.sh b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-armv7.sh index 0431c090db68f..37e57ab557fcc 100755 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-armv7.sh +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-armv7.sh @@ -40,7 +40,7 @@ CMAKE_ARGS+=("-DIOS_ARCH=armv7") CMAKE_ARGS+=("-DENABLE_BITCODE=OFF") CMAKE_ARGS+=("-DENABLE_ARC=OFF") -# Use-specified CMake arguments go last to allow overridding defaults +# Use-specified CMake arguments go last to allow overriding defaults CMAKE_ARGS+=($@) cd build/ios/armv7 && cmake ../../.. \ diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-armv7s.sh b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-armv7s.sh index e3f3d6b76231d..2fd2732191112 100755 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-armv7s.sh +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-armv7s.sh @@ -40,7 +40,7 @@ CMAKE_ARGS+=("-DIOS_ARCH=armv7s") CMAKE_ARGS+=("-DENABLE_BITCODE=OFF") CMAKE_ARGS+=("-DENABLE_ARC=OFF") -# Use-specified CMake arguments go last to allow overridding defaults +# Use-specified CMake arguments go last to allow overriding defaults CMAKE_ARGS+=($@) cd build/ios/armv7s && cmake ../../.. \ diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-i386.sh b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-i386.sh index e8952148e66ad..b51b574d8136a 100755 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-i386.sh +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-i386.sh @@ -40,7 +40,7 @@ CMAKE_ARGS+=("-DIOS_ARCH=i386") CMAKE_ARGS+=("-DENABLE_BITCODE=OFF") CMAKE_ARGS+=("-DENABLE_ARC=OFF") -# Use-specified CMake arguments go last to allow overridding defaults +# Use-specified CMake arguments go last to allow overriding defaults CMAKE_ARGS+=($@) cd build/ios/i386 && cmake ../../.. \ diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-x86_64.sh b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-x86_64.sh index 10a58b843e2a7..a3430082e3e57 100755 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-x86_64.sh +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-x86_64.sh @@ -45,7 +45,7 @@ CMAKE_ARGS+=("-DIOS_ARCH=x86_64") CMAKE_ARGS+=("-DENABLE_BITCODE=OFF") CMAKE_ARGS+=("-DENABLE_ARC=OFF") -# Use-specified CMake arguments go last to allow overridding defaults +# Use-specified CMake arguments go last to allow overriding defaults CMAKE_ARGS+=($@) cd build/ios/x86_64 && cmake ../../.. \ diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-local.sh b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-local.sh index b429650c21842..ac61a4061b90c 100755 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-local.sh +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-local.sh @@ -27,7 +27,7 @@ CMAKE_ARGS+=("-DPYTORCH_QNNPACK_LIBRARY_TYPE=static") CMAKE_ARGS+=("-DPYTORCH_QNNPACK_BUILD_BENCHMARKS=ON") CMAKE_ARGS+=("-DPYTORCH_QNNPACK_BUILD_TESTS=ON") -# Use-specified CMake arguments go last to allow overridding defaults +# Use-specified CMake arguments go last to allow overriding defaults CMAKE_ARGS+=($@) cd build/local && cmake ../.. \ diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/convolution.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/convolution.c index a37c53a11529e..29f5338f5c734 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/convolution.c +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/convolution.c @@ -368,7 +368,7 @@ static enum pytorch_qnnp_status pytorch_qnnp_create_convolution_ndhwc_q8( case pytorch_qnnp_ukernel_type_xzp_gemm: { // TODO: XZP kernels won't be supporting per channel quantization. // For now we dont use XZP kernels anywhere. Probably deprecate it for now - // and ressurrect later if needed. + // and resurrect later if needed. const uint32_t nr = pytorch_qnnp_params.q8conv_xzp.nr; const uint32_t kr = pytorch_qnnp_params.q8conv_xzp.kr; const uint32_t sr = pytorch_qnnp_params.q8conv_xzp.kc; diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8conv/4x8-aarch32-neon.S b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8conv/4x8-aarch32-neon.S index 75eab4a1c305c..ac06fa5973eca 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8conv/4x8-aarch32-neon.S +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8conv/4x8-aarch32-neon.S @@ -20,28 +20,28 @@ # Args passed via stack. # TOS -# |-----------| -# |a | 0 -# |w | 4 -# |c | 8 -# |c_stride | 12 -# |out ch indx| 16 -# |params | 20 -# |-----------| +# |------------| +# |a | 0 +# |w | 4 +# |c | 8 +# |c_stride | 12 +# |out ch index| 16 +# |params | 20 +# |------------| # # After loading w pointer in ip reg. # And after pushing r4-r8 and d8-d15 on stack -# |-----------| -# |d8 - d15 | 0 -# |r4 - r11 | 64 -# |a | 96 -# |w | 100 -# |c | 104 -# |c_stride | 108 -# |out ch indx| 112 -# |params | 116 -# |-----------| +# |------------| +# |d8 - d15 | 0 +# |r4 - r11 | 64 +# |a | 96 +# |w | 100 +# |c | 104 +# |c_stride | 108 +# |out ch index| 112 +# |params | 116 +# |------------| # # void pytorch_q8conv_ukernel_4x8__aarch32_neon( diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8conv/8x8-aarch64-neon.S b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8conv/8x8-aarch64-neon.S index 95d0a2ca8ebaa..1653b46e2d374 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8conv/8x8-aarch64-neon.S +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8conv/8x8-aarch64-neon.S @@ -23,10 +23,10 @@ # Args passed via stack. # TOS -# |-----------| -# |out ch indx| 0 -# |params | 8 -# |-----------| +# |------------| +# |out ch index| 0 +# |params | 8 +# |------------| # void pytorch_q8conv_ukernel_8x8__aarch64_neon( # size_t mr, diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/4x8-aarch32-neon.S b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/4x8-aarch32-neon.S index 8fbea6498dcef..f18605124356e 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/4x8-aarch32-neon.S +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/4x8-aarch32-neon.S @@ -20,28 +20,28 @@ # Args passed via stack. # TOS -# |-----------| -# |a_stride | 0 -# |w | 4 -# |c | 8 -# |c_stride | 12 -# |out ch indx| 16 -# |params | 20 -# |-----------| +# |------------| +# |a_stride | 0 +# |w | 4 +# |c | 8 +# |c_stride | 12 +# |out ch index| 16 +# |params | 20 +# |------------| # # After loading w pointer in ip reg. # And after pushing r4-r9 and d8-d15 on stack -# |-----------| -# |d8 - d15 | 0 -# |r4 - r9 | 64 -# |a_stride | 88 -# |w | 92 -# |c | 96 -# |c_stride | 100 -# |out ch indx| 104 -# |params | 108 -# |-----------| +# |------------| +# |d8 - d15 | 0 +# |r4 - r9 | 64 +# |a_stride | 88 +# |w | 92 +# |c | 96 +# |c_stride | 100 +# |out ch index| 104 +# |params | 108 +# |------------| # # diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/4x8-dq-aarch32-neon.S b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/4x8-dq-aarch32-neon.S index de564d9d3d5aa..c964bf2be7c44 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/4x8-dq-aarch32-neon.S +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/4x8-dq-aarch32-neon.S @@ -33,29 +33,29 @@ # Args passed via stack. # TOS -# |-----------| -# |a_stride | 0 -# |w | 4 -# |c | 8 -# |c_stride | 12 -# |out ch indx| 16 -# |params | 20 -# |-----------| +# |------------| +# |a_stride | 0 +# |w | 4 +# |c | 8 +# |c_stride | 12 +# |out ch index| 16 +# |params | 20 +# |------------| # # After loading w pointer in ip reg. # And after pushing r4-r8 and d8-d15 on stack -# |-----------| -# |d8 - d15 | 0 -# |r4 - r7 | 64 -# |a_stride | 80 -# |w | 84 -# |b | 88 -# |c | 92 -# |c_stride | 96 -# |out ch indx| 100 -# |params | 104 -# |-----------| +# |------------| +# |d8 - d15 | 0 +# |r4 - r7 | 64 +# |a_stride | 80 +# |w | 84 +# |b | 88 +# |c | 92 +# |c_stride | 96 +# |out ch index| 100 +# |params | 104 +# |------------| # # void pytorch_q8gemm_ukernel_4x8__aarch32_neon( diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/8x8-aarch64-neon.S b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/8x8-aarch64-neon.S index 52913d7528617..51866fd3b1ed1 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/8x8-aarch64-neon.S +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/8x8-aarch64-neon.S @@ -22,10 +22,10 @@ # Args passed via stack. # TOS -# |-----------| -# |out ch indx| 0 -# |params | 8 -# |-----------| +# |------------| +# |out ch index| 0 +# |params | 8 +# |------------| # void pytorch_q8gemm_ukernel_8x8__aarch64_neon( # size_t mr, diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/8x8-dq-aarch64-neon.S b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/8x8-dq-aarch64-neon.S index b8bde0200687a..63f667b04a283 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/8x8-dq-aarch64-neon.S +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/8x8-dq-aarch64-neon.S @@ -14,11 +14,11 @@ # Args passed via stack. # TOS -# |-----------| -# |c_stride | 0 -# |out ch indx| 8 -# |params | 16 -# |-----------| +# |------------| +# |c_stride | 0 +# |out ch index| 8 +# |params | 16 +# |------------| # void pytorch_q8gemm_dq_ukernel_8x8__aarch64_neon( # size_t mr, diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/4x4-packA-aarch32-neon.S b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/4x4-packA-aarch32-neon.S index f1dd0a2cc0523..4583e50046d69 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/4x4-packA-aarch32-neon.S +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/4x4-packA-aarch32-neon.S @@ -32,7 +32,7 @@ # # Packed A format. -# 4kx4m blocks for alls blocks given 4 rows (4m) are placed in contiguous memory. +# 4kx4m blocks for all blocks given 4 rows (4m) are placed in contiguous memory. # Original A # --------- K ----------- -- (K + 4 - 1) / 4 -- # | | | | @@ -53,7 +53,7 @@ # This locality helps in loading 8kx4m blocks of activations # Note when M is not multiple of 4, the rest can contain arbitrary # data in packed A as we will not be writing those out. -# This wil be taken care by just copying the appropriate valid data +# This will be taken care by just copying the appropriate valid data # Also note that this packing is same as taking for 4x1 pattern. # This is because all the adjacent k's are laid next to each other @@ -109,7 +109,7 @@ k_loop: VLD1.8 {d2}, [r6]! VLD1.8 {d3}, [r7]! - # Now we have 4x8 block of values that we will tranpose + # Now we have 4x8 block of values that we will transpose # A matrix # -------------------------------- # | | @@ -155,7 +155,7 @@ k_loop: VTRN.32 d2, d3 VSWP d1, d2 - # Now store the tranposed values + # Now store the transposed values # d0, d1, d2, d3 VST1.8 {q0}, [r2]! VST1.8 {q1}, [r2]! @@ -172,7 +172,7 @@ k_loop: VLD1.32 {d2[]}, [r6] VLD1.32 {d3[]}, [r7] - # Now we have 4x8 block of values that we will tranpose + # Now we have 4x8 block of values that we will transpose # _d{0-3} are arm neon vector registers # va0 = _d0 = a0 a1 a2 a3 # va1 = _d1 = b0 b1 b2 b3 @@ -218,7 +218,7 @@ k_loop: VEXT.8 d0, d0, d1, #4 VEXT.8 d1, d2, d3, #4 - # Now store the tranposed values + # Now store the transposed values # d0, d1, d2, d3 VST1.8 {q0}, [r2] .p2align 4 diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/4x8c1x4-dq-packedA-aarch32-neon.S b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/4x8c1x4-dq-packedA-aarch32-neon.S index 5b796bb2563c8..d7a3aa6eaaf74 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/4x8c1x4-dq-packedA-aarch32-neon.S +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/4x8c1x4-dq-packedA-aarch32-neon.S @@ -46,7 +46,7 @@ # |b | 12 # |c | 16 # |c_stride | 20 -# |out ch indx | 24 +# |out ch index | 24 # |params | 28 # |----------------| # @@ -61,7 +61,7 @@ # |b | 108 # |c | 112 # |c_stride | 116 -# |out ch indx | 120 +# |out ch index | 120 # |params | 124 # |----------------| # @@ -101,7 +101,7 @@ /* Add output_channel_index to the b_zero_point pointer */ ;\ ADD r4, r4, r5 ;\ ;\ - /* We enter the loop if r1 is atleast 1. */ ;\ + /* We enter the loop if r1 is at least 1. */ ;\ /* r1 = r1 - 1 will happen in the epilogue */ ;\ /* of the loop */ ;\ CMP r1, 1 ;\ @@ -222,7 +222,7 @@ /* Thus we will load accumulators back in q0, q1, q2, q3, q4, q5, q6, q7 */ ;\ /* When nr < 4, extra q values will be fetched from stack which may overlap */ ;\ /* with other parts of stack storing local variables. To avoid that we just */ ;\ - /* create a buffer of 128 bytes inbetween to make sure pointer increment */ ;\ + /* create a buffer of 128 bytes in between to make sure pointer increment */ ;\ /* never produces address that is beyond the stack frame of this function. */ ;\ SUB r9, sp, 140 ;\ /* Each iteration produce 4 values each of 4 bytes */ ;\ diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/4x8c8x1-dq-packedA-aarch32-neon.S b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/4x8c8x1-dq-packedA-aarch32-neon.S index dd829f80e3732..37db2adcad069 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/4x8c8x1-dq-packedA-aarch32-neon.S +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/4x8c8x1-dq-packedA-aarch32-neon.S @@ -46,7 +46,7 @@ # |b | 12 # |c | 16 # |c_stride | 20 -# |out ch indx | 24 +# |out ch index | 24 # |params | 28 # |----------------| # @@ -61,7 +61,7 @@ # |b | 108 # |c | 112 # |c_stride | 116 -# |out ch indx | 120 +# |out ch index | 120 # |params | 124 # |----------------| # diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x4-packA-aarch32-neon.S b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x4-packA-aarch32-neon.S index bff19de739b10..a5a91b9cb64f7 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x4-packA-aarch32-neon.S +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x4-packA-aarch32-neon.S @@ -32,7 +32,7 @@ # # Packed A format. -# 8kx4m blocks for alls blocks given 4 rows (4m) are placed in contiguous memory. +# 8kx4m blocks for all blocks given 4 rows (4m) are placed in contiguous memory. # Original A # --------- K ----------- -- (K + 4 - 1) / 4 -- # | | | | @@ -53,7 +53,7 @@ # This locality helps in loading 8kx8m blocks of activations # Note when M is not multiple of 8, the rest can contain arbitrary # data in packed A as we will not be writing those out. -# This wil be taken care by just copying the appropriate valid data +# This will be taken care by just copying the appropriate valid data # void pytorch_q8gemm_sparse_packA_ukernel_8x4__aarch32_neon( # size_t mr, @@ -125,7 +125,7 @@ k_loop: VLD1.8 {d6}, [r10]! VLD1.8 {d7}, [r11]! - # Now we have 8x8 block of values that we will tranpose + # Now we have 8x8 block of values that we will transpose # A matrix # -------------------------------- # | | @@ -189,7 +189,7 @@ k_loop: VTRN.32 q0, q2 VTRN.32 q1, q3 - # Now store the tranposed values + # Now store the transposed values # d0, d1, d2, d3 # then d4, d5, d6, d7 contiguously VST1.8 {q0}, [r2]! @@ -213,7 +213,7 @@ k_loop: VLD1.32 {d6[]}, [r7] VLD1.32 {d7[]}, [r11] - # Now we have 4x8 block of values that we will tranpose + # Now we have 4x8 block of values that we will transpose # _d{0-3} are arm neon vector registers # va04 = _d0 = a0 a1 a2 a3 e0 e1 e2 e3 # va15 = _d1 = b0 b1 b2 b3 f0 f1 f2 f3 @@ -260,7 +260,7 @@ k_loop: VTRN.16 d0, d2 VTRN.16 d1, d3 - # Now store the tranposed values + # Now store the transposed values # d0, d1, d2, d3 # then d4, d5, d6, d7 contiguously VST1.8 {q0}, [r2]! diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x4-packA-aarch64-neon.S b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x4-packA-aarch64-neon.S index 4cd788cf583bb..b1f8fe719ca44 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x4-packA-aarch64-neon.S +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x4-packA-aarch64-neon.S @@ -9,7 +9,7 @@ #include # Packed A format. -# 8kx4m blocks for alls blocks given 4 rows (4m) are placed in contiguous memory. +# 8kx4m blocks for all blocks given 4 rows (4m) are placed in contiguous memory. # Original A # --------- K ----------- -- (K + 4 - 1) / 4 -- # | | | | @@ -30,7 +30,7 @@ # This locality helps in loading 8kx8m blocks of activations # Note when M is not multiple of 8, the rest can contain arbitrary # data in packed A as we will not be writing those out. -# This wil be taken care by just copying the appropriate valid data +# This will be taken care by just copying the appropriate valid data # void pytorch_q8gemm_sparse_packA_ukernel_8x4__aarch32_neon( # size_t mr, @@ -93,7 +93,7 @@ k_loop: LD1 {v3.d}[0], [x7], 8 LD1 {v3.d}[1], [x11], 8 - # Now we have 8x8 block of values that we will tranpose + # Now we have 8x8 block of values that we will transpose # A matrix # ------------------------ # | | @@ -180,7 +180,7 @@ k_loop: LD1 {v3.s}[0], [x7] LD1 {v3.s}[1], [x11] - # Now we have 8x4 block of values that we will tranpose + # Now we have 8x4 block of values that we will transpose # A matrix # ---------------------------- # | | diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x4-packA-sse2.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x4-packA-sse2.c index 4b0dd46fd4cf0..df707d3d800ea 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x4-packA-sse2.c +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x4-packA-sse2.c @@ -14,7 +14,7 @@ #include "8x4c1x4-packed-sse2.h" // This is a super slow kernel in that it does not use intrinsics to -// tranpose. Since this is for x86 we are not optimizing it. +// transpose. Since this is for x86 we are not optimizing it. // For ARM this will be optimized. void pytorch_q8gemm_sparse_packA_ukernel_8x4__sse2( const size_t mr, @@ -24,7 +24,7 @@ void pytorch_q8gemm_sparse_packA_ukernel_8x4__sse2( uint8_t* a_packed) { // Packed A format. - // 8kx4m blocks for alls blocks given 4 rows (4m) are placed in contiguous memory. + // 8kx4m blocks for all blocks given 4 rows (4m) are placed in contiguous memory. // Original A // --------- K ----------- -- (K + 4 - 1) / 4 -- // | | | | @@ -45,7 +45,7 @@ void pytorch_q8gemm_sparse_packA_ukernel_8x4__sse2( // This locality helps in loading 8kx8m blocks of activations // Note when M is not multiple of 8, the rest can contain arbitrary // data in packed A as we will not be writing those out. - // This wil be taken care by just copying the appropriate valid data + // This will be taken care by just copying the appropriate valid data // Note that parts of A that are not filled are: // Remainder of M blocks. So some m values are random. This is ok diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x4c1x4-dq-packedA-sse2.h b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x4c1x4-dq-packedA-sse2.h index 5503d67181722..ef771b4187b82 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x4c1x4-dq-packedA-sse2.h +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x4c1x4-dq-packedA-sse2.h @@ -47,7 +47,7 @@ void KERNEL_NAME( const __m128i vzero = _mm_setzero_si128(); // Packed A format. - // 8kx4m blocks for alls blocks given 4 rows (4m) are placed in contiguous memory. + // 8kx4m blocks for all blocks given 4 rows (4m) are placed in contiguous memory. // Original A // --------- K ----------- -- (K + 4 - 1) / 4 -- // | | | | @@ -68,7 +68,7 @@ void KERNEL_NAME( // This locality helps in loading 8kx8m blocks of activations // Note when M is not multiple of 8, the rest can contain arbitrary // data in packed A as we will not be writing those out. - // This wil be taken care by just copying the appropriate valid data + // This will be taken care by just copying the appropriate valid data __m128i vacc_low[4]; __m128i vacc_high[4]; diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x8c1x4-dq-packedA-aarch64-neon.S b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x8c1x4-dq-packedA-aarch64-neon.S index aca408e89757e..8af5c417da31f 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x8c1x4-dq-packedA-aarch64-neon.S +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x8c1x4-dq-packedA-aarch64-neon.S @@ -42,11 +42,11 @@ # Args passed via stack. # TOS -# |-----------| -# |c_stride | 0 -# |out ch indx| 8 -# |params | 16 -# |-----------| +# |------------| +# |c_stride | 0 +# |out ch index| 8 +# |params | 16 +# |------------| # void pytorch_q8gemm_dq_sparse_1x4_ukernel_8x8_packedA_w##W_INDEX_DTYPE_NUM_BITS##__aarch64_neon( # size_t mr, @@ -234,7 +234,7 @@ /* v16, v17, v18, v19, v20, v21, v22, v23 */ XX\ /* When nr < 8, say nr = 1, extra v values will be fetched from stack which may overlap */ XX\ /* with other parts of stack storing local variables. To avoid that we just */ XX\ - /* create a buffer of 256 bytes inbetween to make sure pointer increment */ XX\ + /* create a buffer of 256 bytes in between to make sure pointer increment */ XX\ /* never produces address that is beyond the stack frame of this function. */ XX\ SUB x9, sp, 320 XX\ /* Each iteration produce 8 values each of 4 bytes */ XX\ @@ -287,7 +287,7 @@ LD1 {v22.4s}, [x9], 16 XX\ LD1 {v23.4s}, [x9] XX\ XX\ - /* We can tranpose one 4x4 block using macro */ XX\ + /* We can transpose one 4x4 block using macro */ XX\ /* TRANSPOSE_4X4_S32 v8, v10, v12, v14, v0, v1, v2, v3 */ XX\ /* After this we have */ XX\ /* v8 : x00, x01, x02, x03 */ XX\ @@ -302,7 +302,7 @@ /* v20 : x24, x25, x26, x27 */ XX\ /* v22 : x34, x35, x36, x37 */ XX\ /* Similarly we can transpose other two 4x4 blocks and we get */ XX\ - /* tranposed 8x8 */ XX\ + /* transposed 8x8 */ XX\ XX\ TRANSPOSE_4X4_S32 v8, v10, v12, v14, v0, v1, v2, v3 XX\ TRANSPOSE_4X4_S32 v16, v18, v20, v22, v4, v5, v6, v7 XX\ diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x8c8x1-dq-packedA-aarch64-neon.S b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x8c8x1-dq-packedA-aarch64-neon.S index 2ba033c57c835..58602beb030d1 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x8c8x1-dq-packedA-aarch64-neon.S +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x8c8x1-dq-packedA-aarch64-neon.S @@ -31,11 +31,11 @@ # Args passed via stack. # TOS -# |-----------| -# |c_stride | 0 -# |out ch indx| 8 -# |params | 16 -# |-----------| +# |------------| +# |c_stride | 0 +# |out ch index| 8 +# |params | 16 +# |------------| # void pytorch_q8gemm_dq_sparse_8x1_ukernel_8x8_packedA_w##W_INDEX_DTYPE_NUM_BITS##__aarch64_neon( # size_t mr, diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/pack.h b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/pack.h index 14ea256124856..14365d1ab3ddc 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/pack.h +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/pack.h @@ -238,7 +238,7 @@ static inline void pytorch_pack_q8conv_wrq( } } if (kzp != 0) { - // This part fills the packed wights with zero points for output channels + // This part fills the packed weights with zero points for output channels // when they are not divisible by nr blocking parameter. // In that case for (size_t nr_block_offset = 0; nr_block_offset < (nr - nr_block_size); @@ -360,7 +360,7 @@ static inline void pytorch_pack_q8deconv_wrq( } } if (kzp != 0) { - // This part fills the packed wights with zero points for output channels + // This part fills the packed weights with zero points for output channels // when they are not divisible by nr blocking parameter. // In that case for (size_t nr_block_offset = 0; nr_block_offset < (nr - nr_block_size); diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/q31-scalar.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/q31-scalar.c index e86130f2ccb61..74961b51ff638 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/q31-scalar.c +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/q31-scalar.c @@ -93,7 +93,7 @@ void pytorch_qnnp_requantize_q31__scalar( * overflow is possible only when input is positive, and even when addition * of a rounding constant overflows 32-bit signed integer, it still doesn't * overflow 32-bit unsigned integer. Thus, in case of signed overflow, we - * can compute the result using unsigned arithmetics, specifically using + * can compute the result using unsigned arithmetic, specifically using * logical shift right instead of arithmetic shift right. * 3. Performs arithmetic shift as is, which will produce division result * rounded down. Then compute remainder of this division by a power of 2, diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/requantization.cc b/aten/src/ATen/native/quantized/cpu/qnnpack/test/requantization.cc index a837974dd9fc0..f535e4b99ed76 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/test/requantization.cc +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/requantization.cc @@ -17,7 +17,7 @@ #include "requantization-tester.h" /* - * Precise scalar implementation using unsigned 32-bit arithmetics. + * Precise scalar implementation using unsigned 32-bit arithmetic. */ TEST(PRECISE__SCALAR_UNSIGNED32, exact_divide_by_po2) { @@ -83,7 +83,7 @@ TEST(PRECISE__SCALAR_UNSIGNED32, random_cases) { } /* - * Precise scalar implementation using unsigned 64-bit arithmetics. + * Precise scalar implementation using unsigned 64-bit arithmetic. */ TEST(PRECISE__SCALAR_UNSIGNED64, exact_divide_by_po2) { @@ -149,7 +149,7 @@ TEST(PRECISE__SCALAR_UNSIGNED64, random_cases) { } /* - * Precise scalar implementation using signed 64-bit arithmetics. + * Precise scalar implementation using signed 64-bit arithmetic. */ TEST(PRECISE__SCALAR_SIGNED64, exact_divide_by_po2) { @@ -302,7 +302,7 @@ TEST(GEMMLOWP__SCALAR, random_cases) { } /* - * Precise PSIMD implementation using unsigned 32-bit arithmetics. + * Precise PSIMD implementation using unsigned 32-bit arithmetic. */ TEST(PRECISE__PSIMD, exact_divide_by_po2) { diff --git a/aten/src/ATen/native/quantized/cudnn/Linear.cpp b/aten/src/ATen/native/quantized/cudnn/Linear.cpp index ea776fdf450f1..230850998fda1 100644 --- a/aten/src/ATen/native/quantized/cudnn/Linear.cpp +++ b/aten/src/ATen/native/quantized/cudnn/Linear.cpp @@ -171,7 +171,7 @@ void PackedLinearWeightCudnn::apply_impl_helper(const at::Tensor& quantized_outp return; } - // linear_op computes act_int8 * tranpose(w_int8) (matrix multiplication) + // linear_op computes act_int8 * transpose(w_int8) (matrix multiplication) // where act_int8 and w_int8 are the input and weight variables, resp. // output is a fp32 tensor auto linear_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) diff --git a/aten/src/ATen/native/quantized/cudnn/Pooling.cpp b/aten/src/ATen/native/quantized/cudnn/Pooling.cpp index ba2cc9592d6cf..7fe44de11e54c 100644 --- a/aten/src/ATen/native/quantized/cudnn/Pooling.cpp +++ b/aten/src/ATen/native/quantized/cudnn/Pooling.cpp @@ -54,7 +54,7 @@ void check_maxpool2d_params( Tensor adaptive_avg_pool2d_quantized_cuda( const at::Tensor& input, IntArrayRef output_size) { -// TODO: renable these cudnn preprocessors like quantized_max_pool2d_cudnn below when we implement this function with cudnn +// TODO: re-enable these cudnn preprocessors like quantized_max_pool2d_cudnn below when we implement this function with cudnn #ifdef USE_CUDA // #if AT_CUDNN_ENABLED() // TODO: limit this to per tensor quantized tensors for now, though should be easy to adapt diff --git a/tools/linter/dictionary.txt b/tools/linter/dictionary.txt index 61eaeaf8600d7..706881a8f10f6 100644 --- a/tools/linter/dictionary.txt +++ b/tools/linter/dictionary.txt @@ -22,6 +22,7 @@ froms Halfs hsa indexT +inH inp inps inpt @@ -57,5 +58,6 @@ strat supercede supercedes te +THW tne WONT From 7892f5a007859ae02b9cac5441cd51f37147ef04 Mon Sep 17 00:00:00 2001 From: David Berard Date: Wed, 16 Jul 2025 18:00:28 -0700 Subject: [PATCH 182/457] [inductor][triton] Update HAS_WARP_SPEC to check triton.Config params. Update Triton Hash to top of release/3.4.x stack (#158459) Update triton commit hash to `11ec6354315768a85da41032535e3b7b99c5f706`, which is the new release/3.4.x branch in triton-lang/triton. Also, update HAS_WARP_SPEC handling: In triton 3.4, warp spec will have a different interface: num_consumer_groups will be determined automatically by the compiler. This breaks the current Inductor integration, so for now, update HAS_WARP_SPEC to check whether triton.Config takes num_consumer_groups and num_buffers_warp_spec as parameters. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158459 Approved by: https://github.com/atalman --- .ci/docker/ci_commit_pins/triton.txt | 2 +- test/inductor/test_static_cuda_launcher.py | 35 +--------------------- torch/_inductor/runtime/triton_compat.py | 13 +++++++- 3 files changed, 14 insertions(+), 36 deletions(-) diff --git a/.ci/docker/ci_commit_pins/triton.txt b/.ci/docker/ci_commit_pins/triton.txt index 568756a804f07..6dc1c44507ebd 100644 --- a/.ci/docker/ci_commit_pins/triton.txt +++ b/.ci/docker/ci_commit_pins/triton.txt @@ -1 +1 @@ -ae848267bebc65c6181e8cc5e64a6357d2679260 +11ec6354315768a85da41032535e3b7b99c5f706 diff --git a/test/inductor/test_static_cuda_launcher.py b/test/inductor/test_static_cuda_launcher.py index c1af125eb6bc6..2ce294ed0ff55 100644 --- a/test/inductor/test_static_cuda_launcher.py +++ b/test/inductor/test_static_cuda_launcher.py @@ -2,7 +2,6 @@ import os import random import tempfile -import unittest from unittest import mock import torch @@ -13,9 +12,8 @@ from torch._inductor.runtime.triton_compat import CompiledKernel, tl, triton from torch._inductor.runtime.triton_helpers import libdevice from torch._inductor.test_case import TestCase -from torch.testing._internal.common_utils import IS_FBCODE, skipIfRocm +from torch.testing._internal.common_utils import skipIfRocm from torch.testing._internal.triton_utils import requires_cuda -from torch.torch_version import TorchVersion @requires_cuda @@ -141,37 +139,6 @@ def signed_integers( launcher.run(1, 1, 1, stream, new_arg0, 50, 50, 50, 50) self.assertEqual(new_arg0, arg0) - # TODO: floats don't work properly, triton seems to think they're all tl.float32 - # despite type annotations. - # There's also not really a good way for me to make a float16 in python... - @skipIfRocm - @unittest.skipIf(IS_FBCODE, "Not working in fbcode") - def test_floats(self): - @triton.jit - def floats(arg0, arg1: tl.float16, arg2: tl.float32, arg3: tl.float64): - x = tl.load(arg0) - y = arg1 + arg2 + arg3 - tl.store(arg0, x + y) - - arg0 = torch.zeros(1, dtype=torch.float64, device="cuda") - - args = (arg0, 1.0, 1.0, 1.0) - - compiled_kernel = floats[1,](*args) - launcher = self._make_launcher(compiled_kernel) - if TorchVersion(triton.__version__) >= TorchVersion("3.4.0"): - self.assertEqual(launcher.arg_tys, "Offd") - else: - self.assertEqual(launcher.arg_tys, "Offf") - # TODO this line fails on Triton 3.4.0 (https://github.com/triton-lang/triton/issues/6176) - # Add the check back when this is fixed in Triton - # self.assertEqual(arg0, torch.tensor([3.0], dtype=torch.float64, device="cuda")) - new_arg0 = torch.zeros(1, dtype=torch.float64, device="cuda") - device_interface = get_interface_for_device("cuda") - stream = device_interface.get_raw_stream(device_interface.current_device()) - launcher.run(1, 1, 1, stream, new_arg0, 1.0, 1.0, 1.0) - self.assertEqual(new_arg0, arg0) - @skipIfRocm def test_basic_1arg(self): @triton.jit diff --git a/torch/_inductor/runtime/triton_compat.py b/torch/_inductor/runtime/triton_compat.py index 877f72b50c550..645e0f4c8903d 100644 --- a/torch/_inductor/runtime/triton_compat.py +++ b/torch/_inductor/runtime/triton_compat.py @@ -69,7 +69,18 @@ def GPUTarget( def _log2(x: Any) -> Any: raise NotImplementedError - HAS_WARP_SPEC = hasattr(tl, "async_task") + def _triton_config_has(param_name: str) -> bool: + if not hasattr(triton, "Config"): + return False + if not hasattr(triton.Config, "__init__"): + return False + return param_name in inspect.signature(triton.Config.__init__).parameters + + HAS_WARP_SPEC = ( + hasattr(tl, "async_task") + and _triton_config_has("num_consumer_groups") + and _triton_config_has("num_buffers_warp_spec") + ) try: from triton import knobs From 38c04415a9440d9e5348be34f7bd71a12ed58af8 Mon Sep 17 00:00:00 2001 From: Ankita George Date: Thu, 17 Jul 2025 13:05:06 +0000 Subject: [PATCH 183/457] [oss][hf][bug fix] Remove buggy consolidation logic (#158380) Summary: I tried to add some logic that could optimize for the non-row wise sharded case and do it more efficiently, but this has some bugs, so removing it for now and will find a better algorithm for the non-row wise sharded case to find the maximum number of bytes that we can write at a time. Test Plan: ensure tests pass Rollback Plan: Differential Revision: D78366701 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158380 Approved by: https://github.com/Saiteja64 --- .../checkpoint/_consolidate_hf_safetensors.py | 29 ------------------- 1 file changed, 29 deletions(-) diff --git a/torch/distributed/checkpoint/_consolidate_hf_safetensors.py b/torch/distributed/checkpoint/_consolidate_hf_safetensors.py index 86630903a9519..dc988e999c4ed 100644 --- a/torch/distributed/checkpoint/_consolidate_hf_safetensors.py +++ b/torch/distributed/checkpoint/_consolidate_hf_safetensors.py @@ -595,35 +595,6 @@ def _write_sub_tensor_to_file_optimized( ) return - # Check for fully contiguous chunk - expected_chunk_size = math.prod(sub_tensor_shape) * element_size - - if len(sub_tensor_bytes) == expected_chunk_size: - # Calculate if the chunk maps to a contiguous region in the tensor - tensor_strides = [1] - for i in range(len(tensor_shape) - 1, 0, -1): - tensor_strides.insert(0, tensor_strides[0] * tensor_shape[i]) - - # Check if chunk represents a contiguous slice - chunk_start_pos = sum( - offset * stride - for offset, stride in zip(sub_tensor_offsets, tensor_strides) - ) - - # For simple contiguous cases, use direct copy - if all( - offset + size <= dim - for offset, size, dim in zip( - sub_tensor_offsets, sub_tensor_shape, tensor_shape - ) - ): - tensor_start_byte = output_start_byte + chunk_start_pos * element_size - - with fs.open(output_file_path, "r+b") as out_f: - out_f.seek(tensor_start_byte) - out_f.write(sub_tensor_bytes) - return - # Fall back to the original implementation for complex patterns _write_sub_tensor_to_file( fs, From 58c7cf9ede6311da5533dbcaf238a912176a6a85 Mon Sep 17 00:00:00 2001 From: bobrenjc93 Date: Wed, 16 Jul 2025 21:57:27 -0700 Subject: [PATCH 184/457] Unify torch.tensor and torch.ops.aten.scalar_tensor behavior (#158537) Fixes #158376 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158537 Approved by: https://github.com/atalman --- aten/src/ATen/ScalarOps.cpp | 24 +++++++++++++++++++++++- test/dynamo/test_misc.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/ScalarOps.cpp b/aten/src/ATen/ScalarOps.cpp index 693fb46e639f2..20c9772c7003b 100644 --- a/aten/src/ATen/ScalarOps.cpp +++ b/aten/src/ATen/ScalarOps.cpp @@ -8,7 +8,29 @@ namespace at { namespace { template inline void fill_inplace(Tensor& self, const Scalar& value_scalar) { - auto value = value_scalar.to(); + scalar_t value{}; + + if constexpr (std::is_floating_point_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v) { + // relaxed float cast: allow inf similar to the torch.tensor constructor + // + // without this, we had the following divergence: + // torch.tensor(1123581321.0, dtype=torch.float16) + // => tensor(inf, dtype=torch.float16) + // torch.ops.aten.scalar_tensor.default(1123581321, dtype=torch.float16) + // => RuntimeError: value cannot be converted to type at::Half without overflow + + value = static_cast(value_scalar.to()); + } else { + value = value_scalar.to(); + } + scalar_t* dptr = static_cast(self.data_ptr()); *dptr = value; } diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 632ebdc39278a..b8d759c66e302 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -12962,6 +12962,38 @@ def f(actions, n_act, epsilon=0.1): y = torch.tensor(5) f(x, y) + def test_dynamic_float_scalar_tensor_coersion(self): + # Minified version of https://github.com/pytorch/pytorch/issues/158376#issuecomment-3079591367 + class Foo: + def __init__(self): + self.config = type( + "Config", (), {"pad_val": 1123581321.0, "tolerance": 1e-6} + ) + + @torch.compile(fullgraph=True) + def forward(self, input): + outputs = torch.where( + torch.abs(input - self.config.pad_val) < self.config.tolerance, + torch.tensor( + self.config.pad_val, dtype=input.dtype, device=input.device + ), + torch.tensor( + self.config.pad_val + 1, dtype=input.dtype, device=input.device + ), + ) + return outputs + + foo = Foo() + inputs = torch.randn(3, 4) + result = foo.forward(inputs) + + original_pad_val = foo.config.pad_val + foo.config.pad_val += 1.0 + result2 = foo.forward(inputs) + + # Previously would crash with: + # RuntimeError: value cannot be converted to type at::Half without overflow + devices = ("cuda", "hpu", "xpu") instantiate_device_type_tests( From a04bd11895f7523f8b210dc42bbc064cb2ca06e8 Mon Sep 17 00:00:00 2001 From: Xu Han Date: Thu, 17 Jul 2025 14:40:34 +0000 Subject: [PATCH 185/457] [AOTI] Use format_consts_to_cpp on Windows. (#158543) `format_consts_to_asm` is not supported on Windows, force use `format_consts_to_cpp` on Windows. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158543 Approved by: https://github.com/desertfire --- torch/_inductor/codecache.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 8c95847d87f42..33766c541f77d 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -1804,6 +1804,10 @@ def _compile_consts(consts: bytes, platform: str) -> str: elif platform == "darwin": section_attr = "__DATA,__data" symbol_prefix = "_" + elif platform == "win32": + symbol_prefix = "" + # ASM build is not supported on Windows, force use CPP build. + use_asm_build = False else: raise RuntimeError(f"Unsupported platform: {platform}") From da4c7b4cedbaaf10754cca34cf5b052d9e880e6a Mon Sep 17 00:00:00 2001 From: Xu Han Date: Thu, 17 Jul 2025 14:44:29 +0000 Subject: [PATCH 186/457] [AOTI] align signature to model_base.h (#158554) Remove `const` keyword, align its signature to `model_base.h` https://github.com/pytorch/pytorch/blob/eeda1a75ace75ce8a6763050fb91d236a6d3287b/torch/csrc/inductor/aoti_runtime/model_base.h#L51-L53 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158554 Approved by: https://github.com/desertfire --- torch/_inductor/codecache.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 33766c541f77d..75e129437065e 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -1856,7 +1856,7 @@ def format_consts_to_cpp( ATTRIBUTE_NO_SANITIZE_ADDRESS\t\n""" const_cpp = asan_attr const_cpp += f"alignas({align_bytes}) extern " - const_cpp += f"const unsigned char {symbol_prefix}_binary_constants_bin_start[{consts_size}] = {{\t\n" + const_cpp += f"unsigned char {symbol_prefix}_binary_constants_bin_start[{consts_size}] = {{\t\n" count_bytes = 0 for c in consts: const_cpp += f"{c}, " @@ -1864,7 +1864,7 @@ def format_consts_to_cpp( if count_bytes % 16 == 0: const_cpp += "\t\n" const_cpp += "};\t\n" - const_cpp += f"alignas({align_bytes}) extern const unsigned char * {symbol_prefix}_binary_constants_bin_end;\t\n" + const_cpp += f"alignas({align_bytes}) extern unsigned char * {symbol_prefix}_binary_constants_bin_end;\t\n" return const_cpp, "cpp" if use_asm_build: From 288bf54a23a49dd3b765b4e1c7313c706b46a08a Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 17 Jul 2025 14:55:28 +0000 Subject: [PATCH 187/457] Revert "Move off of deprecated API in 2.9 (#158527)" This reverts commit 9636e2cfd3e995ef977f670ad47e8e895296d992. Reverted https://github.com/pytorch/pytorch/pull/158527 on behalf of https://github.com/albanD due to breaks trunk ([comment](https://github.com/pytorch/pytorch/pull/158527#issuecomment-3084385585)) --- torch/_inductor/kernel/flex_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index aed4a03edd186..9a7507631cc49 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -138,7 +138,7 @@ def maybe_realize(args: list[Optional[IRNode]]): def get_float32_precision(): if ( - torch.backends.cuda.matmul.fp32_precision == "ieee" + torch.get_float32_matmul_precision() == "highest" or torch.version.hip or torch.mtia.is_available() ): From 813c76b98d5bffbffb087502c4f02a043b924d59 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 17 Jul 2025 15:06:43 +0000 Subject: [PATCH 188/457] Revert "Unify torch.tensor and torch.ops.aten.scalar_tensor behavior (#158537)" This reverts commit 58c7cf9ede6311da5533dbcaf238a912176a6a85. Reverted https://github.com/pytorch/pytorch/pull/158537 on behalf of https://github.com/albanD due to This broke C++ tests ([comment](https://github.com/pytorch/pytorch/pull/158537#issuecomment-3084425920)) --- aten/src/ATen/ScalarOps.cpp | 24 +----------------------- test/dynamo/test_misc.py | 32 -------------------------------- 2 files changed, 1 insertion(+), 55 deletions(-) diff --git a/aten/src/ATen/ScalarOps.cpp b/aten/src/ATen/ScalarOps.cpp index 20c9772c7003b..693fb46e639f2 100644 --- a/aten/src/ATen/ScalarOps.cpp +++ b/aten/src/ATen/ScalarOps.cpp @@ -8,29 +8,7 @@ namespace at { namespace { template inline void fill_inplace(Tensor& self, const Scalar& value_scalar) { - scalar_t value{}; - - if constexpr (std::is_floating_point_v || - std::is_same_v || - std::is_same_v || - std::is_same_v || - std::is_same_v || - std::is_same_v || - std::is_same_v || - std::is_same_v) { - // relaxed float cast: allow inf similar to the torch.tensor constructor - // - // without this, we had the following divergence: - // torch.tensor(1123581321.0, dtype=torch.float16) - // => tensor(inf, dtype=torch.float16) - // torch.ops.aten.scalar_tensor.default(1123581321, dtype=torch.float16) - // => RuntimeError: value cannot be converted to type at::Half without overflow - - value = static_cast(value_scalar.to()); - } else { - value = value_scalar.to(); - } - + auto value = value_scalar.to(); scalar_t* dptr = static_cast(self.data_ptr()); *dptr = value; } diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index b8d759c66e302..632ebdc39278a 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -12962,38 +12962,6 @@ def f(actions, n_act, epsilon=0.1): y = torch.tensor(5) f(x, y) - def test_dynamic_float_scalar_tensor_coersion(self): - # Minified version of https://github.com/pytorch/pytorch/issues/158376#issuecomment-3079591367 - class Foo: - def __init__(self): - self.config = type( - "Config", (), {"pad_val": 1123581321.0, "tolerance": 1e-6} - ) - - @torch.compile(fullgraph=True) - def forward(self, input): - outputs = torch.where( - torch.abs(input - self.config.pad_val) < self.config.tolerance, - torch.tensor( - self.config.pad_val, dtype=input.dtype, device=input.device - ), - torch.tensor( - self.config.pad_val + 1, dtype=input.dtype, device=input.device - ), - ) - return outputs - - foo = Foo() - inputs = torch.randn(3, 4) - result = foo.forward(inputs) - - original_pad_val = foo.config.pad_val - foo.config.pad_val += 1.0 - result2 = foo.forward(inputs) - - # Previously would crash with: - # RuntimeError: value cannot be converted to type at::Half without overflow - devices = ("cuda", "hpu", "xpu") instantiate_device_type_tests( From 2ecf083b7247f265a03ec296ba9d7b795f035118 Mon Sep 17 00:00:00 2001 From: Oguz Ulgen Date: Thu, 17 Jul 2025 00:22:45 -0700 Subject: [PATCH 189/457] Add torch compile force disable caches alias (#158072) Bunch of people keep thinking current alias only disables inductor cache because it has the name inductor in it. lets globalize the name Pull Request resolved: https://github.com/pytorch/pytorch/pull/158072 Approved by: https://github.com/ezyang --- docs/source/torch.compiler_troubleshooting_old.md | 2 +- torch/_dynamo/pgo.py | 6 +++--- torch/_functorch/_aot_autograd/autograd_cache.py | 4 ++-- torch/_inductor/config.py | 8 ++------ torch/compiler/config.py | 12 ++++++++++++ 5 files changed, 20 insertions(+), 12 deletions(-) diff --git a/docs/source/torch.compiler_troubleshooting_old.md b/docs/source/torch.compiler_troubleshooting_old.md index 03555d74e817c..ef13fc1772374 100644 --- a/docs/source/torch.compiler_troubleshooting_old.md +++ b/docs/source/torch.compiler_troubleshooting_old.md @@ -717,5 +717,5 @@ backtrace is slow and very spammy so it is not included by default with extended In order to measure the cold start compilation time or debug a cache corruption, it is possible pass `TORCHINDUCTOR_FORCE_DISABLE_CACHES=1` or set -`torch._inductor.config.force_disable_caches = True` which will override any +`torch.compiler.config.force_disable_caches = True` which will override any other caching config option and disable all compile time caching. diff --git a/torch/_dynamo/pgo.py b/torch/_dynamo/pgo.py index 9bdec2df05c26..403187bc6bde8 100644 --- a/torch/_dynamo/pgo.py +++ b/torch/_dynamo/pgo.py @@ -521,9 +521,9 @@ def process_automatic_dynamic( def get_cache_key() -> Optional[str]: # TODO: info versions of these logs that log only once - if torch._inductor.config.force_disable_caches: + if torch.compiler.config.force_disable_caches: warn_once( - "dynamo_pgo force disabled by torch._inductor.config.force_disable_caches" + "dynamo_pgo force disabled by torch.compiler.config.force_disable_caches" ) return None @@ -566,7 +566,7 @@ def code_state_path(cache_key: str) -> Optional[str]: def should_use_remote_dynamo_pgo_cache() -> bool: - if torch._inductor.config.force_disable_caches: + if torch.compiler.config.force_disable_caches: return False if (r := torch._dynamo.config.automatic_dynamic_remote_pgo) is not None: diff --git a/torch/_functorch/_aot_autograd/autograd_cache.py b/torch/_functorch/_aot_autograd/autograd_cache.py index e66ffefe0a00c..c6a4e11ce81d3 100644 --- a/torch/_functorch/_aot_autograd/autograd_cache.py +++ b/torch/_functorch/_aot_autograd/autograd_cache.py @@ -95,7 +95,7 @@ class FXGraphCacheMiss(BypassAOTAutogradCache): def should_use_remote_autograd_cache(): - if torch._inductor.config.force_disable_caches: + if torch.compiler.config.force_disable_caches: return False if config.enable_remote_autograd_cache is not None: return config.enable_remote_autograd_cache @@ -116,7 +116,7 @@ def should_use_remote_autograd_cache(): def should_use_local_autograd_cache(): - if torch._inductor.config.force_disable_caches: + if torch.compiler.config.force_disable_caches: return False return config.enable_autograd_cache diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 6e77283aacf2e..f4c54e2812674 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -138,12 +138,8 @@ def prologue_fusion_enabled() -> bool: # None: Not set -- Off for OSS, JustKnobs based for internal bundled_autotune_remote_cache: Optional[bool] = bundled_autotune_remote_cache_default() -# Force disabled all inductor level caching -- This will override any other caching flag -force_disable_caches: bool = Config( - justknob="pytorch/remote_cache:force_disable_caches", - env_name_force="TORCHINDUCTOR_FORCE_DISABLE_CACHES", - default=False, -) +# See torch.compiler.force_disable_caches +force_disable_caches: bool = Config(alias="torch.compiler.config.force_disable_caches") # Unsafe way to skip dynamic shape guards to get faster cache load unsafe_skip_cache_dynamic_shape_guards: bool = False diff --git a/torch/compiler/config.py b/torch/compiler/config.py index f9ec226c25489..4009f04e4a0ae 100644 --- a/torch/compiler/config.py +++ b/torch/compiler/config.py @@ -66,6 +66,18 @@ A common use case for such a tag is to break caches. """ +force_disable_caches: bool = Config( + justknob="pytorch/remote_cache:force_disable_caches", + env_name_force=[ + "TORCHINDUCTOR_FORCE_DISABLE_CACHES", + "TORCH_COMPILE_FORCE_DISABLE_CACHES", + ], + default=False, +) +""" +Force disables all caching -- This will take precedence over and override any other caching flag +""" + dynamic_sources: str = Config( env_name_default="TORCH_COMPILE_DYNAMIC_SOURCES", default="" ) From 16b21fa8b288140e5067d63e46f670aca495b4cd Mon Sep 17 00:00:00 2001 From: Xu Han Date: Thu, 17 Jul 2025 15:43:20 +0000 Subject: [PATCH 190/457] [AOTI] skip ld and objcopy on Windows. (#158545) Skip `ld` and `objcopy` on Windows. They are not support on Windows. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158545 Approved by: https://github.com/desertfire --- torch/_inductor/codecache.py | 68 +++++++++++++++++++----------------- 1 file changed, 36 insertions(+), 32 deletions(-) diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 75e129437065e..c8b23aded15c2 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -2171,40 +2171,44 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes: cubins_o = [] asm_files = [] - ld, objcopy = get_ld_and_objcopy(use_relative_path) - for kernel_name, value in CudaKernelParamCache.cache.items(): - if asm_file := value["asm"]: - asm_files.append(asm_file) - - cubin_file = value[get_cpp_wrapper_cubin_path_name()] - if config.aot_inductor.emit_multi_arch_kernel and device_type == "cuda": - current_arch = _nvcc_arch_as_compile_option() - cmd = ( - f"{_cuda_compiler()} -fatbin {asm_file} -o {cubin_file} " - # Triton only allows generating PTX version as same as the current arch - f"-gencode arch=compute_{current_arch},code=compute_{current_arch} " - # Include SASS for the current specific arch - f"-gencode arch=compute_{current_arch},code=sm_{current_arch} " - ) - try: - subprocess.run( - cmd.split(), - capture_output=True, - text=True, - check=True, - ) - except subprocess.CalledProcessError as e: - print( - f"{cmd} failed with:\nstdout:\n{e.stdout}\nstderr:\n{e.stderr}", - file=sys.stderr, + if not _IS_WINDOWS: + ld, objcopy = get_ld_and_objcopy(use_relative_path) + for kernel_name, value in CudaKernelParamCache.cache.items(): + if asm_file := value["asm"]: + asm_files.append(asm_file) + + cubin_file = value[get_cpp_wrapper_cubin_path_name()] + if ( + config.aot_inductor.emit_multi_arch_kernel + and device_type == "cuda" + ): + current_arch = _nvcc_arch_as_compile_option() + cmd = ( + f"{_cuda_compiler()} -fatbin {asm_file} -o {cubin_file} " + # Triton only allows generating PTX version as same as the current arch + f"-gencode arch=compute_{current_arch},code=compute_{current_arch} " + # Include SASS for the current specific arch + f"-gencode arch=compute_{current_arch},code=sm_{current_arch} " ) - raise + try: + subprocess.run( + cmd.split(), + capture_output=True, + text=True, + check=True, + ) + except subprocess.CalledProcessError as e: + print( + f"{cmd} failed with:\nstdout:\n{e.stdout}\nstderr:\n{e.stderr}", + file=sys.stderr, + ) + raise - if config.aot_inductor.embed_kernel_binary: - # Embed cubin files into model.so using objcopy - cubins_o.append( - convert_cubin_to_obj(cubin_file, kernel_name, ld, objcopy) - ) + if config.aot_inductor.embed_kernel_binary: + # Embed cubin files into model.so using objcopy + cubins_o.append( + convert_cubin_to_obj(cubin_file, kernel_name, ld, objcopy) + ) output_name, output_dir = get_name_and_dir_from_output_file_path(output_so) so_build_options = CppTorchDeviceOptions( From 23550ab735eee1b9cc90609788dc64ccfb242af2 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 17 Jul 2025 16:20:02 +0000 Subject: [PATCH 191/457] Revert "DDE-Free select with unbacked index. (#157605)" This reverts commit 79d7c754ab8ae0e5c3a614521632d2cfbfa0fdba. Reverted https://github.com/pytorch/pytorch/pull/157605 on behalf of https://github.com/laithsakka due to fail pr time benchmarks ([comment](https://github.com/pytorch/pytorch/pull/157605#issuecomment-3084663020)) --- test/export/test_export.py | 22 ------ test/test_dynamic_shapes.py | 82 --------------------- torch/_export/passes/_node_metadata_hook.py | 1 - torch/_inductor/codegen/cpp_wrapper_cpu.py | 14 ---- torch/_inductor/codegen/wrapper.py | 8 -- torch/_inductor/dependencies.py | 35 --------- torch/_inductor/graph.py | 4 +- torch/_inductor/ir.py | 61 ++------------- torch/_inductor/lowering.py | 69 +++-------------- torch/_inductor/scheduler.py | 3 +- torch/_inductor/utils.py | 16 +--- torch/_meta_registrations.py | 33 +++++++++ torch/_subclasses/fake_impls.py | 42 ----------- torch/fx/experimental/symbolic_shapes.py | 2 - torch/fx/passes/runtime_assert.py | 10 --- 15 files changed, 53 insertions(+), 349 deletions(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index d1cecb55329c4..dea000556960d 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -15782,28 +15782,6 @@ def forward(self, x, mask): ignore_empty_lines=True, ) - def test_unbacked_select_index(self): - class MyModel(torch.nn.Module): - def forward(self, x, y): - u0 = y.item() - return x.select(0, u0) - - example_inputs = ( - torch.randn((3, 3), dtype=torch.bfloat16), - torch.tensor([0]), - ) - - traced = export(MyModel(), example_inputs).run_decompositions({}) - self.assertExpectedInline( - traced.graph_module.code, - """\ -def forward(self, x, y): - item = torch.ops.aten.item.default(y); y = None - select = torch.ops.aten.select.int(x, 0, item); x = item = None - return (select,)""", - ignore_empty_lines=True, - ) - if __name__ == "__main__": run_tests() diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index 59c08f71671e0..0f299cd6b6c79 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -3529,88 +3529,6 @@ def func(x): ignore_empty_lines=True, ) - @fresh_cache() - @torch._dynamo.config.patch("capture_scalar_outputs", True) - def test_unbacked_select_index(self): - cnt = CompileCounterWithBackend("inductor") - - def func(x, y): - u0 = y.item() - return ( - torch.select(x, 0, u0), - torch.select(x, 1, u0), - torch.select(x, 2, u0), - ) - - compiled_func = torch.compile(fullgraph=True, backend=cnt, dynamic=True)(func) - x = torch.rand(3, 3, 3) - zero = torch.tensor([0]) - pos = torch.tensor([1]) - # code can handle both negative and positive indices. - neg = torch.tensor([-1]) - - log_stream, ctx = logs_to_string( - "torch._inductor.compile_fx", "post_grad_graphs" - ) - with ctx(): - self.assertEqual(compiled_func(x, zero), func(x, zero)) - output = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip() - self.assertExpectedInline( - output, - """\ - _local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(arg0_1); arg0_1 = None - select: "f32[s77, s77][s77, 1]cpu" = torch.ops.aten.select.int(arg2_1, 0, _local_scalar_dense) - select_1: "f32[s77, s77][s77**2, 1]cpu" = torch.ops.aten.select.int(arg2_1, 1, _local_scalar_dense) - select_2: "f32[s77, s77][s77**2, s77]cpu" = torch.ops.aten.select.int(arg2_1, 2, _local_scalar_dense); arg2_1 = _local_scalar_dense = None - return (select, select_1, select_2)""", # noqa: B950 - ignore_comments=True, - ignore_empty_lines=True, - ) - self.assertEqual(compiled_func(x, pos), func(x, pos)) - self.assertEqual(compiled_func(x, neg), func(x, neg)) - self.assertEqual(cnt.frame_count, 1) - - def func2(x, y): - u0, u1 = y.tolist() - return torch.select(x, 0, u0 + u1) - - compiled_func2 = torch.compile(fullgraph=True, backend=cnt, dynamic=False)( - func2 - ) - zero = torch.tensor([0, 0]) - pos = torch.tensor([1, 1]) - neg = torch.tensor([-1, -1]) - - self.assertEqual(compiled_func2(x, pos), func2(x, pos)) - self.assertEqual(compiled_func2(x, neg), func2(x, neg)) - self.assertEqual(compiled_func2(x, zero), func2(x, zero)) - self.assertEqual(cnt.frame_count, 2) - - @torch._dynamo.config.patch("capture_scalar_outputs", True) - def test_unbacked_select_index_with_check(self): - def func3(x, y): - u0 = y.item() - # Test that taking the non-unbacked path works fine also. - torch._check(u0 >= 0) - return (torch.select(x, 1, u0),) - - compiled_func3 = torch.compile( - fullgraph=True, backend="inductor", dynamic=True - )(func3) - x = torch.rand(3, 3, 3) - zero = torch.tensor([0]) - pos = torch.tensor([1]) - print(compiled_func3(x, pos)) - - self.assertEqual(compiled_func3(x, pos), func3(x, pos)) - self.assertEqual(compiled_func3(x, zero), func3(x, zero)) - - @fresh_cache() - @torch._dynamo.config.patch("capture_scalar_outputs", True) - @torch._inductor.config.patch("cpp_wrapper", True) - def test_unbacked_select_index_cpp_wrapper(self): - self.test_unbacked_select_index() - instantiate_parametrized_tests(TestUnbacked) diff --git a/torch/_export/passes/_node_metadata_hook.py b/torch/_export/passes/_node_metadata_hook.py index b1195cf421288..41005e5009738 100644 --- a/torch/_export/passes/_node_metadata_hook.py +++ b/torch/_export/passes/_node_metadata_hook.py @@ -54,7 +54,6 @@ def _node_metadata_hook(node: torch.fx.Node, stack_trace: Optional[str] = None) ) }, ) - node.meta["torch_fn"] = ( f"{node.target.__name__}_0", f"{node.target.__class__.__name__}.{node.target.__name__}", diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 8a7f1b2aaa028..cbca6d9fe5d28 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -1447,20 +1447,6 @@ def codegen_dynamic_scalar(self, node): # record in unbacked_symbol_decls so we won't generate a declaration of the symbol again self.unbacked_symbol_decls.add(str(node.sym)) - def codegen_dynamic_select_index(self, node): - index_cpp_str = self.val_to_arg_str_for_prim_type(node.index, int) - - index_compute_str = ( - f"{index_cpp_str} < 0 ? {index_cpp_str} + " - f"{self.val_to_arg_str_for_prim_type(node.size, int)}: {index_cpp_str}" - ) - self.writeline( - f"auto {node.unbacked_offset_symbol} = {self.val_to_arg_str_for_prim_type(node.base_offset, int)} + " - f"{self.val_to_arg_str_for_prim_type(node.base_dim_stride, int)} * ({index_compute_str});" - ) - # record in unbacked_symbol_decls so we won't generate a declaration of the symbol again - self.unbacked_symbol_decls.add(str(node.unbacked_offset_symbol)) - def make_buffer_free(self, buffer): return ( "" diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index e601cbb8ed894..0b8ba86c3c185 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -1802,14 +1802,6 @@ def codegen_multi_output(self, node: ir.MultiOutput): arg_name = node.input_name(0) self.writeline(MultiOutputLine(self, result_name, arg_name, node.indices)) - def codegen_dynamic_select_index(self, node): - index_str = f"{node.index} + {node.size} if {node.index} < 0 else {node.index}" - self.writeline( - f"{node.unbacked_offset_symbol} = {node.base_offset} + {node.base_dim_stride} * ({index_str})" - ) - # record in unbacked_symbol_decls so we won't generate a declaration of the symbol again - self.unbacked_symbol_decls.add(str(node.unbacked_offset_symbol)) - def codegen_dynamic_scalar(self, node): (data,) = (t.codegen_reference() for t in node.inputs) if len(node.keypath) == 0: diff --git a/torch/_inductor/dependencies.py b/torch/_inductor/dependencies.py index f948a7a534c8f..9de52061c6489 100644 --- a/torch/_inductor/dependencies.py +++ b/torch/_inductor/dependencies.py @@ -11,7 +11,6 @@ import sympy import torch -from torch._inductor.utils import get_free_symbols from torch.fx.experimental.symbolic_shapes import free_symbols, free_unbacked_symbols from torch.utils._ordered_set import OrderedSet @@ -39,12 +38,6 @@ class Dep(abc.ABC): name: str index: sympy.Expr - @abc.abstractmethod - def get_free_symbol_uses( - self, unbacked_only: bool = False - ) -> OrderedSet[sympy.Symbol]: - pass - @abc.abstractmethod def rename(self, renames: dict[str, str]) -> Self: pass @@ -77,15 +70,6 @@ class MemoryDep(Dep): size: tuple[sympy.Expr, ...] mode: Optional[str] = None - def get_free_symbol_uses( - self, unbacked_only: bool = False - ) -> OrderedSet[sympy.Symbol]: - return ( - get_free_symbols(self.index, unbacked_only) - | get_free_symbols(self.size, unbacked_only) - | get_free_symbols(self.var_names, unbacked_only) - ) - def __repr__(self) -> str: maybe_mode = "" if self.mode is not None: @@ -323,11 +307,6 @@ def rename(self, renames: dict[str, str]) -> "StarDep": return StarDep(renames[self.name], self.mode) return self - def get_free_symbol_uses( - self, unbacked_only: bool = False - ) -> OrderedSet[sympy.Symbol]: - return OrderedSet() - def numbytes_hint(self) -> int: try: return V.graph.sizevars.size_hint(self.get_numel()) * get_dtype_size( @@ -364,11 +343,6 @@ class WeakDep(Dep): # Buffer that is doing the mutation mutating_buf: str - def get_free_symbol_uses( - self, unbacked_only: bool = False - ) -> OrderedSet[sympy.Symbol]: - return OrderedSet() - @property def index(self) -> sympy.Expr: raise NotImplementedError("WeakDep does not have an index") @@ -466,15 +440,6 @@ def buffer_names(self, ignore_integer_index: bool = True) -> OrderedSet[str]: names.add(dep.name) return names - def get_free_symbol_uses( - self, unbacked_only: bool = False - ) -> OrderedSet[sympy.Symbol]: - result: OrderedSet[sympy.Symbol] = OrderedSet() - - for dep in self.reads_and_writes(): - result |= dep.get_free_symbol_uses(unbacked_only) - return result - class _RecordLoadStoreInner(V.MockHandler): # type: ignore[name-defined] def __init__(self, var_ranges: VarRanges, normalize: bool) -> None: diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index 660b01b69233b..ac299d5b0c2d0 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -341,7 +341,6 @@ def __init__( shape_env.deferred_runtime_asserts.copy() ) self.bound_unbacked_symbols = OrderedSet[sympy.Symbol]() - self.sizevars = SizeVarAllocator(shape_env) self.graph_input_names: list[str] = [] self.graph_inputs: dict[str, Union[TensorBox, TorchBindObject, sympy.Expr]] = {} @@ -1822,7 +1821,7 @@ def debug(msg: str) -> None: shape_env = V.graph.sizevars.shape_env - # An input can be unbacked symint i.e.: when mark_unabcked is used. + # An input can an unbacked symint i.e.: when mark_unabcked is used. # in that case add it to new_unbacked_defs. if ( n.op == "placeholder" @@ -1889,7 +1888,6 @@ def format_new_defs() -> str: V.fake_mode.shape_env.unbacked_renamings.get(s, s) for s in unbacked_bindings.keys() ) - assert new_unbacked_defs >= renamed_unbacked_bindings, ( f"failed {new_unbacked_defs} >= {renamed_unbacked_bindings} (inductor >= fx)\n" f"fx node is: {n.format_node()}\n" diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 25f57a503dfaa..d6dd82aa52f2d 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -49,7 +49,6 @@ from torch._export.serde.serialize import GraphModuleSerializer from torch._higher_order_ops.auto_functionalize import can_auto_functionalize from torch._inductor import metrics -from torch._inductor.utils import get_free_symbols from torch._prims_common import ( compute_required_storage_length, is_boolean_dtype, @@ -63,6 +62,7 @@ compute_unbacked_bindings, free_symbols, free_unbacked_symbols, + IterateExprs, rebind_unbacked, resolve_unbacked_bindings, ShapeEnv, @@ -304,6 +304,13 @@ def reindex(index: Sequence[_T]) -> Sequence[_V]: return reindex +def get_free_symbols(x: IterateExprs, unbacked_only: bool) -> OrderedSet[sympy.Symbol]: + if unbacked_only: + return free_unbacked_symbols(x) + else: + return free_symbols(x) + + NHWC_STRIDE_ORDER = [3, 0, 2, 1] NHWDC_STRIDE_ORDER = [4, 0, 3, 2, 1] @@ -4322,13 +4329,6 @@ def get_read_names(self) -> OrderedSet[str]: return self.data.get_read_names() def get_read_writes(self) -> dependencies.ReadWrites: - if not isinstance(self.data, (Reduction, Scan, Sort, Pointwise)): - return dependencies.ReadWrites( - reads=OrderedSet(), - writes=OrderedSet(), - index_exprs=OrderedSet(), - ) - with patch.object(FlexibleLayout, "allow_indexing", True): if self.data.get_reduction_type(): return extract_read_writes( @@ -4367,7 +4367,6 @@ def get_free_symbol_uses( | get_free_symbols(self.get_stride(), unbacked_only) | get_free_symbols(self.get_offset(), unbacked_only) | self.data.get_free_symbol_uses(unbacked_only) - | self.get_read_writes().get_free_symbol_uses(unbacked_only) ) def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]: @@ -6976,50 +6975,6 @@ def codegen(self, wrapper: PythonWrapperCodegen) -> None: wrapper.codegen_device_copy(args[0], self.codegen_reference(), args[1]) -class DynamicSelectStorageOffset(ExternKernel): - """ - The result of computing a dynamic selection index is determined as follows: when the index in the - select operation is unbacked, the actual index calculation is ambiguous for negative indices - (index + size) versus non-negative indices (just index). To resolve this, we allocate an unbacked - SymInt to represent the storage offset and decompose the select operation into a call to as_strided, - computing the storage offset at runtime with this node. - """ - - def get_reads(self) -> OrderedSet[Dep]: - return OrderedSet() - - def should_allocate(self) -> bool: - return False - - def __init__( - self, - unbacked_offset_symbol: sympy.Symbol, - index: sympy.Symbol, - base_offset: Union[sympy.Symbol, int], - base_dim_stride: Union[sympy.Symbol, int], - size: Union[sympy.Symbol, int], - ) -> None: - super().__init__(None, NoneLayout(device=torch.device("cpu")), []) - # This node codegen the following: - # unbacked_offset_symbol = base_offset + base_dim_stride * (index if index >=0 else index + size) - self.unbacked_offset_symbol = unbacked_offset_symbol - self.index = index - self.base_offset = base_offset - self.base_dim_stride = base_dim_stride - self.size = size - - def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: - return OrderedSet([self.unbacked_offset_symbol]) - - def get_free_symbol_uses( - self, unbacked_only: bool = False - ) -> OrderedSet[sympy.Symbol]: - return get_free_symbols(self.index, unbacked_only) - - def codegen(self, wrapper: PythonWrapperCodegen) -> None: - wrapper.codegen_dynamic_select_index(self) - - class DynamicScalar(ExternKernel): """ The result of a call to aten._local_scalar_dense. diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index f6b08499e4d5c..c4c8f70003c60 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -40,11 +40,7 @@ Number, ) from torch.fx.experimental.sym_node import magic_methods, method_to_operator -from torch.fx.experimental.symbolic_shapes import ( - free_unbacked_symbols, - has_free_unbacked_symbols, - resolve_unbacked_bindings, -) +from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols from torch.utils._ordered_set import OrderedSet from torch.utils._sympy.functions import CeilDiv, FloorDiv, Identity, ModularIndexing @@ -994,7 +990,10 @@ def squeeze(x, dim=None): new_shape = [] for d, s in enumerate(x.get_size()): - if not (d in dims and V.graph.sizevars.guard_or_false(sympy.Eq(s, 1))): + if not ( + d in dims + and V.graph.sizevars.evaluate_expr(sympy.Eq(s, 1), size_oblivious=True) + ): new_shape.append(s) # squeeze does nothing if the size isn't 1 @@ -1760,60 +1759,8 @@ def diagonal_scatter(input, src, offset: int = 0, dim1: int = 0, dim2: int = 1): @register_lowering(aten.select, type_promotion_kind=None) def select(x, dim, idx): - idx = sympy.expand(idx) - size = sympy.expand(x.get_size()[dim]) - actual_index = None - - if V.graph.sizevars.guard_or_false(sympy.Lt(idx, 0)): - actual_index = idx + size - elif V.graph.sizevars.guard_or_false(sympy.Ge(idx, 0)): - actual_index = idx - - if actual_index is not None: - if has_free_unbacked_symbols(idx): - # Inductor could generate incorrect views for tensors with unbacked symbols here; - # Squeeze operations are translated to views, resulting in incorrect strides. - # Additionally, we want to avoid accidental unbacked unsqueeze semantics. To resolve this, - # we use as_strided instead. - # Removing this branch will cause test_unbacked_select_index_with_check to fail. - new_size = x.get_size() - new_stride = x.get_stride() - new_storage_offset = x.get_layout().offset + new_stride[dim] * actual_index - - del new_size[dim] - del new_stride[dim] - return as_strided(x, new_size, new_stride, new_storage_offset) - else: - slice_result = slice_(x, dim, actual_index, actual_index + 1) - return squeeze(slice_result, dim) - - # Unbacked Semantics: - # When the index idx is unbacked (e.g., u0), we compute the index dynamically - # during the lowering of the select operation using DynamicSelectStorageOffset. - - unbacked_bindings = resolve_unbacked_bindings( - V.graph.sizevars.shape_env, V.graph.current_node.meta["unbacked_bindings"] - ) - assert unbacked_bindings is not None - assert len(unbacked_bindings) == 1, unbacked_bindings - unbacked_offset_sym, _ = next(iter(unbacked_bindings.items())) - - new_size = x.get_size() - new_stride = x.get_stride() - new_storage_offset = unbacked_offset_sym - buffer = ir.DynamicSelectStorageOffset( - unbacked_offset_sym, - idx, - x.get_layout().offset, - new_stride[dim], - x.get_size()[dim], - ) - buffer.name = V.graph.register_buffer(buffer) - V.graph.register_operation(buffer) - - del new_size[dim] - del new_stride[dim] - return as_strided(x, new_size, new_stride, new_storage_offset) + idx = View.handle_negative_index(idx, x.get_size()[dim]) + return squeeze(slice_(x, dim, idx, idx + 1), dim) @register_lowering(aten.split, type_promotion_kind=None) @@ -3139,6 +3086,8 @@ def long_tensor(data): @register_lowering(aten._local_scalar_dense) def _local_scalar_dense(data): + from torch.fx.experimental.symbolic_shapes import resolve_unbacked_bindings + # This is interesting! Most lowerings return tensors, so you can just # return the buffer you allocated and it will get used (or not used, if # it's dead.) But _local_scalar_dense (aka item) returns an int, diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index a4507990400fd..34f15869085f0 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -2130,11 +2130,9 @@ def _init(self, nodes: list[ir.Operation]) -> None: self.logged_slow_fusion = OrderedSet[tuple[str, str]]() if config._pre_fusion_custom_pass is not None: self.nodes = config._pre_fusion_custom_pass(self.nodes) - self.nodes = self.fuse_nodes(self.nodes) if config._post_fusion_custom_pass is not None: self.nodes = config._post_fusion_custom_pass(self.nodes) - self.merge_loops() self.finalize_multi_template_buffers() if config.combo_kernels: @@ -2368,6 +2366,7 @@ def add_user( for node in self.nodes: log.debug("scheduling %s", node.node) + # unbacked symbols don't follow ordinary buffer dependencies, so # we track their def/uses separately assert node.node is not None diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 7b3f495382f76..5f9ce0b814eba 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -69,20 +69,13 @@ "inductor_autotune_lookup_table", ] -from torch.fx.experimental.symbolic_shapes import ( - free_symbols, - free_unbacked_symbols, - IterateExprs, - ShapeEnv, -) - - if TYPE_CHECKING: from collections.abc import Iterable, Sequence, ValuesView from torch import SymBool, SymFloat, SymInt from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND from torch.fx import GraphModule + from torch.fx.experimental.symbolic_shapes import ShapeEnv from torch.fx.node import Node from .codegen.common import WorkspaceArg @@ -3366,10 +3359,3 @@ def aoti_model_name_from_config() -> str: model_name = config.aot_inductor.model_name_for_generated_files model_name = "aoti_model" if model_name is None else model_name return model_name - - -def get_free_symbols(x: IterateExprs, unbacked_only: bool) -> OrderedSet[sympy.Symbol]: - if unbacked_only: - return free_unbacked_symbols(x) - else: - return free_symbols(x) diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 2933a37c37fd8..ae87e0e17fb37 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -5553,6 +5553,39 @@ def meta_zeros( ) +@register_meta(aten.select.int) +def meta_select(self, dim, index): + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + ndim = self.dim() + torch._check_index( + ndim != 0, + lambda: "select() cannot be applied to a 0-dim tensor.", + ) + + dim = dim if dim >= 0 else dim + ndim + size = self.size(dim) + + torch._check_index( + not ( + guard_size_oblivious(-index > size) or guard_size_oblivious(index >= size) + ), + lambda: f"select(): index {index} out of range for tensor of size " + f"{self.size()} at dimension {dim}", + ) + + index = index if index >= 0 else index + size + + new_size = list(self.size()) + new_stride = list(self.stride()) + + new_storage_offset = self.storage_offset() + index * new_stride[dim] + del new_size[dim] + del new_stride[dim] + + return self.as_strided(new_size, new_stride, new_storage_offset) + + @register_meta(aten.select_scatter.default) def meta_select_scatter(self, src, dim, index): return utils.clone_preserve_strides(self) diff --git a/torch/_subclasses/fake_impls.py b/torch/_subclasses/fake_impls.py index e2e24cb59bc27..e802d9a4389d4 100644 --- a/torch/_subclasses/fake_impls.py +++ b/torch/_subclasses/fake_impls.py @@ -359,48 +359,6 @@ def unique2( return _unique(fake_mode, func, arg, None, sorted, return_inverse, return_counts) -@register_op_impl(aten.select.int) -def meta_select(fake_mode, func, self, dim, index): - from torch.fx.experimental.symbolic_shapes import guard_or_false - - if self.is_sparse: - return NotImplemented - - ndim = self.dim() - torch._check_index( - ndim != 0, - lambda: "select() cannot be applied to a 0-dim tensor.", - ) - - dim = dim if dim >= 0 else dim + ndim - size = self.size(dim) - - new_size = list(self.size()) - new_stride = list(self.stride()) - - new_storage_offset = None - if guard_or_false(index >= 0): - new_storage_offset = self.storage_offset() + index * new_stride[dim] - elif guard_or_false(index < 0): - new_storage_offset = self.storage_offset() + (index + size) * new_stride[dim] - - if new_storage_offset is None: - if fake_mode.shape_env is None or ( - not fake_mode.shape_env.allow_scalar_outputs - and not fake_mode.allow_scalar_outputs - ): - raise DataDependentOutputException(func) - - # index is data-dependent, we do not know which index we are accessing it could be index or index+size! - # we assign a new data-dependent symbol for the storage offset. - new_storage_offset = fake_mode.shape_env.create_unbacked_symint() - - del new_size[dim] - del new_stride[dim] - assert new_storage_offset is not None - return self.as_strided(new_size, new_stride, new_storage_offset) - - @register_op_impl(aten.unique_dim.default) def unique_dim( fake_mode, func, arg, dim, sorted=True, return_inverse=False, return_counts=False diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 4814e2daefe33..e38e5f777d669 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -1282,7 +1282,6 @@ def compute_unbacked_bindings( return None fs = shape_env.pending_fresh_unbacked_symbols - pending = set(fs) if not pending: return None @@ -4810,7 +4809,6 @@ def create_unbacked_symfloat(self) -> SymFloat: ) self.counter["create_unbacked_symbol"] += 1 if not self._ignore_fresh_unbacked_symbols_tls(): - print(f"adding {symbol}") self.pending_fresh_unbacked_symbols.append(symbol) self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1) vr = self.var_to_range[symbol] = ValueRanges.unknown() diff --git a/torch/fx/passes/runtime_assert.py b/torch/fx/passes/runtime_assert.py index bb71a25971da7..38c64c527aff0 100644 --- a/torch/fx/passes/runtime_assert.py +++ b/torch/fx/passes/runtime_assert.py @@ -461,7 +461,6 @@ def go(node, keypath): ), keypath[2:], ) - return go( graph.call_method( keypath[0].name, (node, keypath[1].idx) @@ -469,15 +468,6 @@ def go(node, keypath): keypath[2:], ) elif isinstance(keypath[0], CallMethodKey): - if keypath[0].name == "storage_offset": - return go( - graph.call_function( - torch.ops.aten.sym_storage_offset.default, - (node,), - ), - keypath[1:], - ) - return go( graph.call_method(keypath[0].name, (node,)), keypath[1:] ) From 94d7f0c1ef9a4cb4db0eb5d6b1ffc55941cbeab1 Mon Sep 17 00:00:00 2001 From: albanD Date: Thu, 17 Jul 2025 16:50:01 +0000 Subject: [PATCH 192/457] Cleanup old caffe2 scripts (#158475) Testing on this one is grep based: if there were no reference to that script I can find, I deleted. We can easily add any of these back if needed! Pull Request resolved: https://github.com/pytorch/pytorch/pull/158475 Approved by: https://github.com/seemethere, https://github.com/huydhn, https://github.com/cyyever --- .github/workflows/pull.yml | 15 -- scripts/README.md | 39 ----- scripts/add_apache_header.sh | 1 - scripts/apache_header.txt | 15 -- scripts/apache_python.txt | 14 -- scripts/build_android.sh | 189 ----------------------- scripts/build_android_gradle.sh | 102 ------------ scripts/build_host_protoc.sh | 59 ------- scripts/build_ios.sh | 155 ------------------- scripts/build_local.sh | 82 ---------- scripts/build_mobile.sh | 107 ------------- scripts/build_pytorch_android.sh | 51 ------ scripts/build_raspbian.sh | 44 ------ scripts/build_tegra_x1.sh | 51 ------ scripts/build_tizen.sh | 118 -------------- scripts/build_windows.bat | 80 ---------- scripts/diagnose_protobuf.py | 92 ----------- scripts/fbcode-dev-setup/ccache_setup.sh | 92 ----------- scripts/get_python_cmake_flags.py | 24 --- scripts/proto.ps1 | 18 --- scripts/remove_apache_header.sh | 13 -- scripts/temp.sh | 7 - scripts/xcode_build.rb | 76 --------- 23 files changed, 1444 deletions(-) delete mode 100755 scripts/add_apache_header.sh delete mode 100644 scripts/apache_header.txt delete mode 100644 scripts/apache_python.txt delete mode 100755 scripts/build_android.sh delete mode 100755 scripts/build_android_gradle.sh delete mode 100755 scripts/build_host_protoc.sh delete mode 100755 scripts/build_ios.sh delete mode 100755 scripts/build_local.sh delete mode 100755 scripts/build_mobile.sh delete mode 100755 scripts/build_pytorch_android.sh delete mode 100755 scripts/build_raspbian.sh delete mode 100755 scripts/build_tegra_x1.sh delete mode 100755 scripts/build_tizen.sh delete mode 100644 scripts/build_windows.bat delete mode 100644 scripts/diagnose_protobuf.py delete mode 100755 scripts/fbcode-dev-setup/ccache_setup.sh delete mode 100644 scripts/get_python_cmake_flags.py delete mode 100644 scripts/proto.ps1 delete mode 100755 scripts/remove_apache_header.sh delete mode 100755 scripts/temp.sh delete mode 100644 scripts/xcode_build.rb diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index 59a7265173800..be0bdc527cc11 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -315,21 +315,6 @@ jobs: test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.test-matrix }} secrets: inherit - linux-jammy-py3-clang18-mobile-build: - name: linux-jammy-py3-clang18-mobile-build - uses: ./.github/workflows/_linux-build.yml - needs: get-label-type - with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-py3-clang12-mobile-build - docker-image-name: ci-image:pytorch-linux-jammy-py3-clang18-asan - build-generates-artifacts: false - test-matrix: | - { include: [ - { config: "default", shard: 1, num_shards: 1 }, - ]} - secrets: inherit - linux-jammy-cuda12_8-cudnn9-py3_9-clang12-build: name: linux-jammy-cuda12.8-cudnn9-py3.9-clang12 uses: ./.github/workflows/_linux-build.yml diff --git a/scripts/README.md b/scripts/README.md index a1c5ae5f93e67..367e7261f6a60 100644 --- a/scripts/README.md +++ b/scripts/README.md @@ -1,40 +1 @@ This directory contains the useful tools. - - -## build_android.sh -This script is to build PyTorch/Caffe2 library for Android. Take the following steps to start the build: - -- set ANDROID_NDK to the location of ndk - -```bash -export ANDROID_NDK=YOUR_NDK_PATH -``` - -- run build_android.sh -```bash -#in your PyTorch root directory -bash scripts/build_android.sh -``` -If succeeded, the libraries and headers would be generated to build_android/install directory. You can then copy these files from build_android/install to your Android project for further usage. - -You can also override the cmake flags via command line, e.g., following command will also compile the executable binary files: -```bash -bash scripts/build_android.sh -DBUILD_BINARY=ON -``` - -## build_ios.sh -This script is to build PyTorch/Caffe2 library for iOS, and can only be performed on macOS. Take the following steps to start the build: - -- Install Xcode from App Store, and configure "Command Line Tools" properly on Xcode. -- Install the dependencies: - -```bash -brew install cmake automake libtool -``` - -- run build_ios.sh -```bash -#in your PyTorch root directory -bash scripts/build_ios.sh -``` -If succeeded, the libraries and headers would be generated to build_ios/install directory. You can then copy these files to your Xcode project for further usage. diff --git a/scripts/add_apache_header.sh b/scripts/add_apache_header.sh deleted file mode 100755 index a29a059d2d033..0000000000000 --- a/scripts/add_apache_header.sh +++ /dev/null @@ -1 +0,0 @@ -cat apache_header.txt $1 > _add_apache_header.txt && mv _add_apache_header.txt $1 diff --git a/scripts/apache_header.txt b/scripts/apache_header.txt deleted file mode 100644 index b4eff258eb04d..0000000000000 --- a/scripts/apache_header.txt +++ /dev/null @@ -1,15 +0,0 @@ -/** - * Copyright (c) 2016-present, Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ diff --git a/scripts/apache_python.txt b/scripts/apache_python.txt deleted file mode 100644 index bc104d8845154..0000000000000 --- a/scripts/apache_python.txt +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright (c) 2016-present, Facebook, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -############################################################################## diff --git a/scripts/build_android.sh b/scripts/build_android.sh deleted file mode 100755 index 43f11b86828d4..0000000000000 --- a/scripts/build_android.sh +++ /dev/null @@ -1,189 +0,0 @@ -#!/bin/bash -############################################################################## -# Example command to build the android target. -############################################################################## -# -# This script shows how one can build a Caffe2 binary for the Android platform -# using android-cmake. A few notes: -# -# (1) This build also does a host build for protobuf. You will need autoconf -# to carry out this. If autoconf is not possible, you will need to provide -# a pre-built protoc binary that is the same version as the protobuf -# version under third_party. -# If you are building on Mac, you might need to install autotool and -# libtool. The easiest way is via homebrew: -# brew install automake -# brew install libtool -# (2) You will need to have android ndk installed. The current script assumes -# that you set ANDROID_NDK to the location of ndk. -# (3) The toolchain and the build target platform can be specified with the -# cmake arguments below. For more details, check out android-cmake's doc. - -set -e - -# Android specific flags -if [ -z "$ANDROID_ABI" ]; then - ANDROID_ABI="armeabi-v7a with NEON" -fi -ANDROID_NATIVE_API_LEVEL="21" -echo "Build with ANDROID_ABI[$ANDROID_ABI], ANDROID_NATIVE_API_LEVEL[$ANDROID_NATIVE_API_LEVEL]" - -CAFFE2_ROOT="$( cd "$(dirname "$0")"/.. ; pwd -P)" -if [ -z "$ANDROID_NDK" ]; then - echo "ANDROID_NDK not set; please set it to the Android NDK directory" - exit 1 -fi - -if [ ! -d "$ANDROID_NDK" ]; then - echo "ANDROID_NDK not a directory; did you install it under $ANDROID_NDK?" - exit 1 -fi - -if [ -z "$PYTHON" ]; then - PYTHON=python - PYTHON_VERSION_MAJOR=$($PYTHON -c 'import sys; print(sys.version_info[0])') - if [ "${PYTHON_VERSION_MAJOR}" -le 2 ]; then - echo "Default python executable is Python-2, trying to use python3 alias" - PYTHON=python3 - fi -fi - -ANDROID_NDK_PROPERTIES="$ANDROID_NDK/source.properties" -[ -f "$ANDROID_NDK_PROPERTIES" ] && ANDROID_NDK_VERSION=$(sed -n 's/^Pkg.Revision[^=]*= *\([0-9]*\)\..*$/\1/p' "$ANDROID_NDK_PROPERTIES") - -echo "Bash: $(/bin/bash --version | head -1)" -echo "Python: $($PYTHON -c 'import sys; print(sys.version)')" -echo "Caffe2 path: $CAFFE2_ROOT" -echo "Using Android NDK at $ANDROID_NDK" -echo "Android NDK version: $ANDROID_NDK_VERSION" - -CMAKE_ARGS=() - -# Build PyTorch mobile -CMAKE_ARGS+=("-DCMAKE_PREFIX_PATH=$($PYTHON -c 'import sysconfig; print(sysconfig.get_path("purelib"))')") -CMAKE_ARGS+=("-DPython_EXECUTABLE=$($PYTHON -c 'import sys; print(sys.executable)')") -CMAKE_ARGS+=("-DBUILD_CUSTOM_PROTOBUF=OFF") - -# custom build with selected ops -if [ -n "${SELECTED_OP_LIST}" ]; then - SELECTED_OP_LIST="$(cd $(dirname $SELECTED_OP_LIST); pwd -P)/$(basename $SELECTED_OP_LIST)" - echo "Choose SELECTED_OP_LIST file: $SELECTED_OP_LIST" - if [ ! -r ${SELECTED_OP_LIST} ]; then - echo "Error: SELECTED_OP_LIST file ${SELECTED_OP_LIST} not found." - exit 1 - fi - CMAKE_ARGS+=("-DSELECTED_OP_LIST=${SELECTED_OP_LIST}") -fi - -# If Ninja is installed, prefer it to Make -if [ -x "$(command -v ninja)" ]; then - CMAKE_ARGS+=("-GNinja") -fi - -# Use android-cmake to build Android project from CMake. -CMAKE_ARGS+=("-DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake") - -if [ -z "$BUILD_MOBILE_BENCHMARK" ]; then - BUILD_MOBILE_BENCHMARK=0 -fi - -if [ -z "$BUILD_MOBILE_TEST" ]; then - BUILD_MOBILE_TEST=0 -fi -# Don't build artifacts we don't need -CMAKE_ARGS+=("-DBUILD_TEST=OFF") -CMAKE_ARGS+=("-DBUILD_BINARY=OFF") - -# If there exists env variable and it equals to 0, build full jit interpreter. -# Default behavior is to build lite interpreter -# cmd: BUILD_LITE_INTERPRETER=0 ./scripts/build_android.sh -if [ "${BUILD_LITE_INTERPRETER}" == 0 ]; then - CMAKE_ARGS+=("-DBUILD_LITE_INTERPRETER=OFF") -else - CMAKE_ARGS+=("-DBUILD_LITE_INTERPRETER=ON") -fi -if [ "${TRACING_BASED}" == 1 ]; then - CMAKE_ARGS+=("-DTRACING_BASED=ON") -else - CMAKE_ARGS+=("-DTRACING_BASED=OFF") -fi -if [ "${USE_LIGHTWEIGHT_DISPATCH}" == 1 ]; then - CMAKE_ARGS+=("-DUSE_LIGHTWEIGHT_DISPATCH=ON") - CMAKE_ARGS+=("-DSTATIC_DISPATCH_BACKEND=CPU") -else - CMAKE_ARGS+=("-DUSE_LIGHTWEIGHT_DISPATCH=OFF") -fi - -CMAKE_ARGS+=("-DBUILD_MOBILE_BENCHMARK=$BUILD_MOBILE_BENCHMARK") -CMAKE_ARGS+=("-DBUILD_MOBILE_TEST=$BUILD_MOBILE_TEST") -CMAKE_ARGS+=("-DBUILD_PYTHON=OFF") -CMAKE_ARGS+=("-DBUILD_SHARED_LIBS=OFF") -if (( "${ANDROID_NDK_VERSION:-0}" < 18 )); then - CMAKE_ARGS+=("-DANDROID_TOOLCHAIN=gcc") -else - CMAKE_ARGS+=("-DANDROID_TOOLCHAIN=clang") -fi -# Disable unused dependencies -CMAKE_ARGS+=("-DUSE_CUDA=OFF") -CMAKE_ARGS+=("-DUSE_ITT=OFF") -CMAKE_ARGS+=("-DUSE_GFLAGS=OFF") -CMAKE_ARGS+=("-DUSE_OPENCV=OFF") -CMAKE_ARGS+=("-DUSE_MPI=OFF") -CMAKE_ARGS+=("-DUSE_OPENMP=OFF") -# Only toggle if VERBOSE=1 -if [ "${VERBOSE:-}" == '1' ]; then - CMAKE_ARGS+=("-DCMAKE_VERBOSE_MAKEFILE=1") -fi - -# Android specific flags -CMAKE_ARGS+=("-DANDROID_NDK=$ANDROID_NDK") -CMAKE_ARGS+=("-DANDROID_ABI=$ANDROID_ABI") -CMAKE_ARGS+=("-DANDROID_NATIVE_API_LEVEL=$ANDROID_NATIVE_API_LEVEL") -CMAKE_ARGS+=("-DANDROID_CPP_FEATURES=rtti exceptions") -if [ "${ANDROID_STL_SHARED:-}" == '1' ]; then - CMAKE_ARGS+=("-DANDROID_STL=c++_shared") -fi -if [ "${ANDROID_DEBUG_SYMBOLS:-}" == '1' ]; then - CMAKE_ARGS+=("-DANDROID_DEBUG_SYMBOLS=1") -fi - -if [ -n "${USE_VULKAN}" ]; then - CMAKE_ARGS+=("-DUSE_VULKAN=ON") - if [ -n "${USE_VULKAN_FP16_INFERENCE}" ]; then - CMAKE_ARGS+=("-DUSE_VULKAN_FP16_INFERENCE=ON") - fi - if [ -n "${USE_VULKAN_RELAXED_PRECISION}" ]; then - CMAKE_ARGS+=("-DUSE_VULKAN_RELAXED_PRECISION=ON") - fi -fi - -# Use-specified CMake arguments go last to allow overriding defaults -CMAKE_ARGS+=($@) - -# Patch pocketfft (as Android does not have aligned_alloc even if compiled with c++17 -if [ -f third_party/pocketfft/pocketfft_hdronly.h ]; then - sed -i -e "s/__cplusplus >= 201703L/0/" third_party/pocketfft/pocketfft_hdronly.h -fi - -# Now, actually build the Android target. -BUILD_ROOT=${BUILD_ROOT:-"$CAFFE2_ROOT/build_android"} -INSTALL_PREFIX=${BUILD_ROOT}/install -mkdir -p $BUILD_ROOT -cd $BUILD_ROOT -cmake "$CAFFE2_ROOT" \ - -DCMAKE_INSTALL_PREFIX=$INSTALL_PREFIX \ - -DCMAKE_BUILD_TYPE=Release \ - "${CMAKE_ARGS[@]}" - -# Cross-platform parallel build -if [ -z "$MAX_JOBS" ]; then - if [ "$(uname)" == 'Darwin' ]; then - MAX_JOBS=$(sysctl -n hw.ncpu) - else - MAX_JOBS=$(nproc) - fi -fi - -echo "Will install headers and libs to $INSTALL_PREFIX for further Android project usage." -cmake --build . --target install -- "-j${MAX_JOBS}" -echo "Installation completed, now you can copy the headers/libs from $INSTALL_PREFIX to your Android project directory." diff --git a/scripts/build_android_gradle.sh b/scripts/build_android_gradle.sh deleted file mode 100755 index fc27c5dd2516b..0000000000000 --- a/scripts/build_android_gradle.sh +++ /dev/null @@ -1,102 +0,0 @@ -#!/usr/bin/env bash -set -eux -o pipefail - -env -echo "BUILD_ENVIRONMENT:$BUILD_ENVIRONMENT" - -export ANDROID_NDK_HOME=/opt/ndk -export ANDROID_NDK=/opt/ndk -export ANDROID_HOME=/opt/android/sdk - -# Must be in sync with GRADLE_VERSION in docker image for android -# https://github.com/pietern/pytorch-dockerfiles/blob/master/build.sh#L155 -export GRADLE_VERSION=6.8.3 -export GRADLE_HOME=/opt/gradle/gradle-$GRADLE_VERSION -export GRADLE_PATH=$GRADLE_HOME/bin/gradle - -# touch gradle cache files to prevent expiration -while IFS= read -r -d '' file -do - touch "$file" || true -done < <(find /var/lib/jenkins/.gradle -type f -print0) - -# Patch pocketfft (as Android does not have aligned_alloc even if compiled with c++17 -if [ -f ~/workspace/third_party/pocketfft/pocketfft_hdronly.h ]; then - sed -i -e "s/__cplusplus >= 201703L/0/" ~/workspace/third_party/pocketfft/pocketfft_hdronly.h -fi - -export GRADLE_LOCAL_PROPERTIES=~/workspace/android/local.properties -rm -f $GRADLE_LOCAL_PROPERTIES -echo "sdk.dir=/opt/android/sdk" >> $GRADLE_LOCAL_PROPERTIES -echo "ndk.dir=/opt/ndk" >> $GRADLE_LOCAL_PROPERTIES -echo "cmake.dir=/usr/local" >> $GRADLE_LOCAL_PROPERTIES - -retry () { - $* || (sleep 1 && $*) || (sleep 2 && $*) || (sleep 4 && $*) || (sleep 8 && $*) -} - -# Run custom build script -if [[ "${BUILD_ENVIRONMENT}" == *-gradle-custom-build* ]]; then - # Install torch & torchvision - used to download & dump used ops from test model. - retry pip install torch torchvision --progress-bar off - - exec "$(dirname "${BASH_SOURCE[0]}")/../android/build_test_app_custom.sh" armeabi-v7a -fi - -# Run default build -BUILD_ANDROID_INCLUDE_DIR_x86=~/workspace/build_android/install/include -BUILD_ANDROID_LIB_DIR_x86=~/workspace/build_android/install/lib - -BUILD_ANDROID_INCLUDE_DIR_x86_64=~/workspace/build_android_install_x86_64/install/include -BUILD_ANDROID_LIB_DIR_x86_64=~/workspace/build_android_install_x86_64/install/lib - -BUILD_ANDROID_INCLUDE_DIR_arm_v7a=~/workspace/build_android_install_arm_v7a/install/include -BUILD_ANDROID_LIB_DIR_arm_v7a=~/workspace/build_android_install_arm_v7a/install/lib - -BUILD_ANDROID_INCLUDE_DIR_arm_v8a=~/workspace/build_android_install_arm_v8a/install/include -BUILD_ANDROID_LIB_DIR_arm_v8a=~/workspace/build_android_install_arm_v8a/install/lib - -PYTORCH_ANDROID_SRC_MAIN_DIR=~/workspace/android/pytorch_android/src/main - -JNI_INCLUDE_DIR=${PYTORCH_ANDROID_SRC_MAIN_DIR}/cpp/libtorch_include -mkdir -p $JNI_INCLUDE_DIR - -JNI_LIBS_DIR=${PYTORCH_ANDROID_SRC_MAIN_DIR}/jniLibs -mkdir -p $JNI_LIBS_DIR - -ln -s ${BUILD_ANDROID_INCLUDE_DIR_x86} ${JNI_INCLUDE_DIR}/x86 -ln -s ${BUILD_ANDROID_LIB_DIR_x86} ${JNI_LIBS_DIR}/x86 - -if [[ "${BUILD_ENVIRONMENT}" != *-gradle-build-only-x86_32* ]]; then -ln -s ${BUILD_ANDROID_INCLUDE_DIR_x86_64} ${JNI_INCLUDE_DIR}/x86_64 -ln -s ${BUILD_ANDROID_LIB_DIR_x86_64} ${JNI_LIBS_DIR}/x86_64 - -ln -s ${BUILD_ANDROID_INCLUDE_DIR_arm_v7a} ${JNI_INCLUDE_DIR}/armeabi-v7a -ln -s ${BUILD_ANDROID_LIB_DIR_arm_v7a} ${JNI_LIBS_DIR}/armeabi-v7a - -ln -s ${BUILD_ANDROID_INCLUDE_DIR_arm_v8a} ${JNI_INCLUDE_DIR}/arm64-v8a -ln -s ${BUILD_ANDROID_LIB_DIR_arm_v8a} ${JNI_LIBS_DIR}/arm64-v8a -fi - -GRADLE_PARAMS="-p android assembleRelease --debug --stacktrace" -if [[ "${BUILD_ENVIRONMENT}" == *-gradle-build-only-x86_32* ]]; then - GRADLE_PARAMS+=" -PABI_FILTERS=x86" -fi - -if [ -n "${GRADLE_OFFLINE:-}" ]; then - GRADLE_PARAMS+=" --offline" -fi - -$GRADLE_PATH $GRADLE_PARAMS - -find . -type f -name "*.a" -exec ls -lh {} \; - -while IFS= read -r -d '' file -do - echo - echo "$file" - ls -lah "$file" - zipinfo -l "$file" -done < <(find . -type f -name '*.aar' -print0) - -find . -type f -name *aar -print | xargs tar cfvz ~/workspace/android/artifacts.tgz diff --git a/scripts/build_host_protoc.sh b/scripts/build_host_protoc.sh deleted file mode 100755 index cd37db3b31713..0000000000000 --- a/scripts/build_host_protoc.sh +++ /dev/null @@ -1,59 +0,0 @@ -#!/bin/bash -############################################################################## -# Build script to build the protoc compiler for the host platform. -############################################################################## -# This script builds the protoc compiler for the host platform, which is needed -# for any cross-compilation as we will need to convert the protobuf source -# files to cc files. -# -# --other-flags accepts flags that should be passed to cmake. Optional. -# -# After the execution of the file, one should be able to find the host protoc -# binary at build_host_protoc/bin/protoc. - -set -e - -CAFFE2_ROOT="$( cd "$(dirname -- "$0")"/.. ; pwd -P)" -BUILD_ROOT=${BUILD_ROOT:-"$CAFFE2_ROOT/build_host_protoc"} -mkdir -p $BUILD_ROOT/build -cd $BUILD_ROOT/build - -CMAKE_ARGS=() -CMAKE_ARGS+=("-DCMAKE_INSTALL_PREFIX=$BUILD_ROOT") -CMAKE_ARGS+=("-Dprotobuf_BUILD_TESTS=OFF") - -# If Ninja is installed, prefer it to Make -if [ -x "$(command -v ninja)" ]; then - CMAKE_ARGS+=("-GNinja") -fi - -while true; do - case "$1" in - --other-flags) - shift; - CMAKE_ARGS+=("$@") - break ;; - "") - break ;; - *) - echo "Unknown option passed as argument: $1" - break ;; - esac -done - -# Use ccache if available (this path is where Homebrew installs ccache symlinks) -if [ "$(uname)" == 'Darwin' ] && [ -d /usr/local/opt/ccache/libexec ]; then - CMAKE_ARGS+=("-DCMAKE_C_COMPILER=/usr/local/opt/ccache/libexec/gcc") - CMAKE_ARGS+=("-DCMAKE_CXX_COMPILER=/usr/local/opt/ccache/libexec/g++") -fi - -cmake "$CAFFE2_ROOT/third_party/protobuf/cmake" ${CMAKE_ARGS[@]} - -if [ -z "$MAX_JOBS" ]; then - if [ "$(uname)" == 'Darwin' ]; then - MAX_JOBS=$(sysctl -n hw.ncpu) - else - MAX_JOBS=$(nproc) - fi -fi -cmake --build . -- "-j${MAX_JOBS}" install diff --git a/scripts/build_ios.sh b/scripts/build_ios.sh deleted file mode 100755 index ad16cb940dcb8..0000000000000 --- a/scripts/build_ios.sh +++ /dev/null @@ -1,155 +0,0 @@ -#!/bin/bash -xe -############################################################################## -# Example command to build the iOS target. -############################################################################## -# -# This script shows how one can build a Caffe2 binary for the iOS platform -# using ios-cmake. This is very similar to the android-cmake - see -# build_android.sh for more details. - -CAFFE2_ROOT="$( cd "$(dirname "$0")"/.. ; pwd -P)" - -if [ -z "$PYTHON" ]; then - PYTHON=python - PYTHON_VERSION_MAJOR=$($PYTHON -c 'import sys; print(sys.version_info[0])') - if [ "${PYTHON_VERSION_MAJOR}" -le 2 ]; then - echo "Default python executable is Python-2, trying to use python3 alias" - PYTHON=python3 - fi -fi - -echo "Bash: $(/bin/bash --version | head -1)" -echo "Python: $($PYTHON -c 'import sys; print(sys.version)')" -echo "Caffe2 path: $CAFFE2_ROOT" - -CMAKE_ARGS=() - -# Build PyTorch mobile -CMAKE_ARGS+=("-DCMAKE_PREFIX_PATH=$($PYTHON -c 'import sysconfig; print(sysconfig.get_path("purelib"))')") -CMAKE_ARGS+=("-DPython_EXECUTABLE=$($PYTHON -c 'import sys; print(sys.executable)')") -CMAKE_ARGS+=("-DBUILD_CUSTOM_PROTOBUF=OFF") - -# custom build with selected ops -if [ -n "${SELECTED_OP_LIST}" ]; then - SELECTED_OP_LIST="$(cd $(dirname $SELECTED_OP_LIST); pwd -P)/$(basename $SELECTED_OP_LIST)" - echo "Choose SELECTED_OP_LIST file: $SELECTED_OP_LIST" - if [ ! -r ${SELECTED_OP_LIST} ]; then - echo "Error: SELECTED_OP_LIST file ${SELECTED_OP_LIST} not found." - exit 1 - fi - CMAKE_ARGS+=("-DSELECTED_OP_LIST=${SELECTED_OP_LIST}") -fi - -# bitcode -if [ "${ENABLE_BITCODE:-}" == '1' ]; then - CMAKE_ARGS+=("-DCMAKE_C_FLAGS=-fembed-bitcode") - CMAKE_ARGS+=("-DCMAKE_CXX_FLAGS=-fembed-bitcode") -fi - -# Use ios-cmake to build iOS project from CMake. -# This projects sets CMAKE_C_COMPILER to /usr/bin/gcc and -# CMAKE_CXX_COMPILER to /usr/bin/g++. In order to use ccache (if it is available) we -# must override these variables via CMake arguments. -CMAKE_ARGS+=("-DCMAKE_TOOLCHAIN_FILE=$CAFFE2_ROOT/cmake/iOS.cmake") -if [ -n "${CCACHE_WRAPPER_PATH:-}"]; then - CCACHE_WRAPPER_PATH=/usr/local/opt/ccache/libexec -fi -if [ -d "$CCACHE_WRAPPER_PATH" ]; then - CMAKE_ARGS+=("-DCMAKE_C_COMPILER=$CCACHE_WRAPPER_PATH/gcc") - CMAKE_ARGS+=("-DCMAKE_CXX_COMPILER=$CCACHE_WRAPPER_PATH/g++") -fi - -# IOS_PLATFORM controls type of iOS platform (see ios-cmake) -if [ -n "${IOS_PLATFORM:-}" ]; then - CMAKE_ARGS+=("-DIOS_PLATFORM=${IOS_PLATFORM}") - if [ "${IOS_PLATFORM}" == "WATCHOS" ]; then - # enable bitcode by default for watchos - CMAKE_ARGS+=("-DCMAKE_C_FLAGS=-fembed-bitcode") - CMAKE_ARGS+=("-DCMAKE_CXX_FLAGS=-fembed-bitcode") - # disable the QNNPACK - CMAKE_ARGS+=("-DUSE_PYTORCH_QNNPACK=OFF") - fi -else - # IOS_PLATFORM is not set, default to OS, which builds iOS. - CMAKE_ARGS+=("-DIOS_PLATFORM=OS") -fi - -if [ -n "${IOS_ARCH:-}" ]; then - CMAKE_ARGS+=("-DIOS_ARCH=${IOS_ARCH}") -fi - -if [ "${BUILD_LITE_INTERPRETER}" == 0 ]; then - CMAKE_ARGS+=("-DBUILD_LITE_INTERPRETER=OFF") -else - CMAKE_ARGS+=("-DBUILD_LITE_INTERPRETER=ON") -fi -if [ "${TRACING_BASED}" == 1 ]; then - CMAKE_ARGS+=("-DTRACING_BASED=ON") -else - CMAKE_ARGS+=("-DTRACING_BASED=OFF") -fi -if [ "${USE_LIGHTWEIGHT_DISPATCH}" == 1 ]; then - CMAKE_ARGS+=("-DUSE_LIGHTWEIGHT_DISPATCH=ON") - CMAKE_ARGS+=("-DSTATIC_DISPATCH_BACKEND=CPU") -else - CMAKE_ARGS+=("-DUSE_LIGHTWEIGHT_DISPATCH=OFF") -fi - -CMAKE_ARGS+=("-DUSE_LITE_INTERPRETER_PROFILER=OFF") - -# Don't build binaries or tests (only the library) -CMAKE_ARGS+=("-DBUILD_TEST=OFF") -CMAKE_ARGS+=("-DBUILD_BINARY=OFF") -CMAKE_ARGS+=("-DBUILD_PYTHON=OFF") - -# Disable unused dependencies -CMAKE_ARGS+=("-DUSE_CUDA=OFF") -CMAKE_ARGS+=("-DUSE_ITT=OFF") -CMAKE_ARGS+=("-DUSE_GFLAGS=OFF") -CMAKE_ARGS+=("-DUSE_OPENCV=OFF") -CMAKE_ARGS+=("-DUSE_MPI=OFF") -CMAKE_ARGS+=("-DUSE_NUMPY=OFF") -CMAKE_ARGS+=("-DUSE_NNPACK=OFF") -CMAKE_ARGS+=("-DUSE_MKLDNN=OFF") - -# Metal -if [ "${USE_PYTORCH_METAL:-}" == "1" ]; then - CMAKE_ARGS+=("-DUSE_PYTORCH_METAL=ON") -fi - -# Core ML -if [ "${USE_COREML_DELEGATE}" == "1" ]; then - CMAKE_ARGS+=("-DUSE_COREML_DELEGATE=ON") -fi - -# pthreads -CMAKE_ARGS+=("-DCMAKE_THREAD_LIBS_INIT=-lpthread") -CMAKE_ARGS+=("-DCMAKE_HAVE_THREADS_LIBRARY=1") -CMAKE_ARGS+=("-DCMAKE_USE_PTHREADS_INIT=1") - -# Only toggle if VERBOSE=1 -if [ "${VERBOSE:-}" == '1' ]; then - CMAKE_ARGS+=("-DCMAKE_VERBOSE_MAKEFILE=1") -fi - -# enable ARC -CMAKE_ARGS+=("-DCMAKE_CXX_FLAGS=-fobjc-arc") - -# Now, actually build the iOS target. -BUILD_ROOT=${BUILD_ROOT:-"$CAFFE2_ROOT/build_ios"} -INSTALL_PREFIX=${BUILD_ROOT}/install -mkdir -p $BUILD_ROOT -cd $BUILD_ROOT -cmake "$CAFFE2_ROOT" \ - -DCMAKE_INSTALL_PREFIX=$INSTALL_PREFIX \ - -DCMAKE_BUILD_TYPE=MinSizeRel \ - -DBUILD_SHARED_LIBS=OFF \ - ${CMAKE_ARGS[@]} \ - $@ - -cmake --build . -- "-j$(sysctl -n hw.ncpu)" - -# copy headers and libs to install directory -echo "Will install headers and libs to $INSTALL_PREFIX for further Xcode project usage." -make install -echo "Installation completed, now you can copy the headers/libs from $INSTALL_PREFIX to your Xcode project directory." diff --git a/scripts/build_local.sh b/scripts/build_local.sh deleted file mode 100755 index b843671501256..0000000000000 --- a/scripts/build_local.sh +++ /dev/null @@ -1,82 +0,0 @@ -#!/bin/bash -# -############################################################################## -# Example command to build Caffe2 -############################################################################## -# - -set -ex - -CAFFE2_ROOT="$( cd "$(dirname "$0")"/.. ; pwd -P)" - -CMAKE_ARGS=() - -# If Ninja is installed, prefer it to Make -if [ -x "$(command -v ninja)" ]; then - CMAKE_ARGS+=("-GNinja") -fi - -# Use ccache if available (this path is where Homebrew installs ccache symlinks) -if [ "$(uname)" == 'Darwin' ]; then - if [ -n "${CCACHE_WRAPPER_PATH:-}"]; then - CCACHE_WRAPPER_PATH=/usr/local/opt/ccache/libexec - fi - if [ -d "$CCACHE_WRAPPER_PATH" ]; then - CMAKE_ARGS+=("-DCMAKE_C_COMPILER=$CCACHE_WRAPPER_PATH/gcc") - CMAKE_ARGS+=("-DCMAKE_CXX_COMPILER=$CCACHE_WRAPPER_PATH/g++") - fi -fi - -# Use special install script with Anaconda -if [ -n "${USE_ANACONDA}" ]; then - export SKIP_CONDA_TESTS=1 - export CONDA_INSTALL_LOCALLY=1 - "${ROOT_DIR}/scripts/build_anaconda.sh" "$@" -else - # Make sure that pyyaml is installed for the codegen of building Aten to work - if [[ -n "$(python -c 'import yaml' 2>&1)" ]]; then - echo "Installing pyyaml with pip at $(which pip)" - pip install --user pyyaml - fi - - # Make sure that typing is installed for the codegen of building Aten to work - if [[ -n "$(python -c 'import typing' 2>&1)" ]]; then - echo "Installing typing with pip at $(which pip)" - pip install --user typing - fi - - # Build protobuf compiler from third_party if configured to do so - if [ -n "${USE_HOST_PROTOC:-}" ]; then - echo "USE_HOST_PROTOC is set; building protoc before building Caffe2..." - "$CAFFE2_ROOT/scripts/build_host_protoc.sh" - CUSTOM_PROTOC_EXECUTABLE="$CAFFE2_ROOT/build_host_protoc/bin/protoc" - echo "Built protoc $("$CUSTOM_PROTOC_EXECUTABLE" --version)" - CMAKE_ARGS+=("-DCAFFE2_CUSTOM_PROTOC_EXECUTABLE=$CUSTOM_PROTOC_EXECUTABLE") - fi - - # We are going to build the target into build. - BUILD_ROOT=${BUILD_ROOT:-"$CAFFE2_ROOT/build"} - mkdir -p "$BUILD_ROOT" - cd "$BUILD_ROOT" - echo "Building Caffe2 in: $BUILD_ROOT" - - cmake "$CAFFE2_ROOT" \ - -DCMAKE_BUILD_TYPE=Release \ - "${CMAKE_ARGS[@]}" \ - "$@" - - # Determine the number of CPUs to build with. - # If the `CAFFE_MAKE_NCPUS` variable is not specified, use them all. - if [ -n "${MAX_JOBS}" ]; then - CAFFE_MAKE_NCPUS="$MAX_JOBS" - elif [ -n "${CAFFE_MAKE_NCPUS}" ]; then - CAFFE_MAKE_NCPUS="$CAFFE_MAKE_NCPUS" - elif [ "$(uname)" == 'Darwin' ]; then - CAFFE_MAKE_NCPUS="$(sysctl -n hw.ncpu)" - else - CAFFE_MAKE_NCPUS="$(nproc)" - fi - - # Now, actually build the target. - cmake --build . -- "-j$CAFFE_MAKE_NCPUS" -fi diff --git a/scripts/build_mobile.sh b/scripts/build_mobile.sh deleted file mode 100755 index 7b1995a61ebc7..0000000000000 --- a/scripts/build_mobile.sh +++ /dev/null @@ -1,107 +0,0 @@ -#!/bin/bash -############################################################################## -# Example command to build the mobile target. -############################################################################## -# -# This script shows how one can build a libtorch library optimized for mobile -# devices using host toolchain. - -set -e - -export BUILD_PYTORCH_MOBILE_WITH_HOST_TOOLCHAIN=1 -CAFFE2_ROOT="$( cd "$(dirname "$0")"/.. ; pwd -P)" - -echo "Bash: $(/bin/bash --version | head -1)" -echo "Caffe2 path: $CAFFE2_ROOT" - -CMAKE_ARGS=() -CMAKE_ARGS+=("-DCMAKE_PREFIX_PATH=$(python -c 'import sysconfig; print(sysconfig.get_path("purelib"))')") -CMAKE_ARGS+=("-DPython_EXECUTABLE=$(python -c 'import sys; print(sys.executable)')") -CMAKE_ARGS+=("-DBUILD_CUSTOM_PROTOBUF=OFF") -CMAKE_ARGS+=("-DBUILD_SHARED_LIBS=OFF") - -# custom build with selected ops -if [ -n "${SELECTED_OP_LIST}" ]; then - SELECTED_OP_LIST="$(cd $(dirname $SELECTED_OP_LIST); pwd -P)/$(basename $SELECTED_OP_LIST)" - echo "Choose SELECTED_OP_LIST file: $SELECTED_OP_LIST" - if [ ! -r ${SELECTED_OP_LIST} ]; then - echo "Error: SELECTED_OP_LIST file ${SELECTED_OP_LIST} not found." - exit 1 - fi - CMAKE_ARGS+=("-DSELECTED_OP_LIST=${SELECTED_OP_LIST}") -fi - -# If Ninja is installed, prefer it to Make -if [ -x "$(command -v ninja)" ]; then - CMAKE_ARGS+=("-GNinja") -fi - -# Don't build artifacts we don't need -CMAKE_ARGS+=("-DBUILD_TEST=OFF") -CMAKE_ARGS+=("-DBUILD_BINARY=OFF") - -# If there exists env variable and it equals to 1, build lite interpreter. -# Default behavior is to build full jit interpreter. -# cmd: BUILD_LITE_INTERPRETER=1 ./scripts/build_mobile.sh -if [ "x${BUILD_LITE_INTERPRETER}" == "x1" ]; then - CMAKE_ARGS+=("-DBUILD_LITE_INTERPRETER=ON") -else - CMAKE_ARGS+=("-DBUILD_LITE_INTERPRETER=OFF") -fi -if [ "x${TRACING_BASED}" == "x1" ]; then - CMAKE_ARGS+=("-DTRACING_BASED=ON") -else - CMAKE_ARGS+=("-DTRACING_BASED=OFF") -fi - -# Lightweight dispatch bypasses the PyTorch Dispatcher. -if [ "${USE_LIGHTWEIGHT_DISPATCH}" == 1 ]; then - CMAKE_ARGS+=("-DUSE_LIGHTWEIGHT_DISPATCH=ON") - CMAKE_ARGS+=("-DSTATIC_DISPATCH_BACKEND=CPU") -else - CMAKE_ARGS+=("-DUSE_LIGHTWEIGHT_DISPATCH=OFF") -fi - -# Disable unused dependencies -CMAKE_ARGS+=("-DUSE_ROCM=OFF") -CMAKE_ARGS+=("-DUSE_CUDA=OFF") -CMAKE_ARGS+=("-DUSE_ITT=OFF") -CMAKE_ARGS+=("-DUSE_GFLAGS=OFF") -CMAKE_ARGS+=("-DUSE_OPENCV=OFF") -CMAKE_ARGS+=("-DUSE_MPI=OFF") -CMAKE_ARGS+=("-DUSE_OPENMP=OFF") -CMAKE_ARGS+=("-DUSE_MKLDNN=OFF") -CMAKE_ARGS+=("-DUSE_NNPACK=OFF") -CMAKE_ARGS+=("-DUSE_NUMPY=OFF") -CMAKE_ARGS+=("-DUSE_BLAS=OFF") - -# Only toggle if VERBOSE=1 -if [ "${VERBOSE:-}" == '1' ]; then - CMAKE_ARGS+=("-DCMAKE_VERBOSE_MAKEFILE=1") -fi - -# Use-specified CMake arguments go last to allow overriding defaults -CMAKE_ARGS+=("$@") - -# Now, actually build the Android target. -BUILD_ROOT=${BUILD_ROOT:-"$CAFFE2_ROOT/build_mobile"} -INSTALL_PREFIX=${BUILD_ROOT}/install -mkdir -p $BUILD_ROOT -cd $BUILD_ROOT -cmake "$CAFFE2_ROOT" \ - -DCMAKE_INSTALL_PREFIX=$INSTALL_PREFIX \ - -DCMAKE_BUILD_TYPE=Release \ - "${CMAKE_ARGS[@]}" - -# Cross-platform parallel build -if [ -z "$MAX_JOBS" ]; then - if [ "$(uname)" == 'Darwin' ]; then - MAX_JOBS=$(sysctl -n hw.ncpu) - else - MAX_JOBS=$(nproc) - fi -fi - -echo "Will install headers and libs to $INSTALL_PREFIX for further project usage." -cmake --build . --target install -- "-j${MAX_JOBS}" -echo "Installation completed, now you can copy the headers/libs from $INSTALL_PREFIX to your project directory." diff --git a/scripts/build_pytorch_android.sh b/scripts/build_pytorch_android.sh deleted file mode 100755 index 7b80965e34b5c..0000000000000 --- a/scripts/build_pytorch_android.sh +++ /dev/null @@ -1,51 +0,0 @@ -#!/bin/bash -set -eux - -############################################################################## -# Master script to build PyTorch Android library with Java bindings. -############################################################################## -# Example usage: -# - Build default AARs: -# scripts/build_pytorch_android.sh -# -# - Build for specific ABI(s): -# scripts/build_pytorch_android.sh armeabi-v7a -# scripts/build_pytorch_android.sh arm64-v8a,x86,x86_64 -# -# Script's workflow: -# 1. Builds libtorch for android for specified android abisi (by default for all 4). -# Custom list of android abis can be specified as a bash argument as comma separated list. -# For example just for testing on android x86 emulator we need only x86 build. -# ./scripts/build_pytorch_android.sh x86 -# 2. Creates symbolic links to android/pytorch_android/src/main/jniLibs/${abi} for libtorch build output, -# android/pytorch_android/src/main/cpp/libtorch_include/${abi} for headers. -# 3. Runs pyotrch_android gradle build: -# gradle assembleRelease - -PYTORCH_DIR="$(cd $(dirname $0)/..; pwd -P)" -PYTORCH_ANDROID_DIR=$PYTORCH_DIR/android - -echo "PYTORCH_DIR:$PYTORCH_DIR" - -source "$PYTORCH_ANDROID_DIR/common.sh" - -check_android_sdk -check_gradle -parse_abis_list "$@" -build_android - -# To set proxy for gradle add following lines to ./gradle/gradle.properties: -# systemProp.http.proxyHost=... -# systemProp.http.proxyPort=8080 -# systemProp.https.proxyHost=... -# systemProp.https.proxyPort=8080 - -if [ "$CUSTOM_ABIS_LIST" = true ]; then - # Skipping clean task here as android gradle plugin 3.3.2 exteralNativeBuild has problems - # with it when abiFilters are specified. - $GRADLE_PATH -PABI_FILTERS=$ABIS_LIST -p $PYTORCH_ANDROID_DIR assembleRelease -else - $GRADLE_PATH -p $PYTORCH_ANDROID_DIR clean assembleRelease -fi - -find $PYTORCH_ANDROID_DIR -type f -name *aar | xargs ls -lah diff --git a/scripts/build_raspbian.sh b/scripts/build_raspbian.sh deleted file mode 100755 index b1fe85926219e..0000000000000 --- a/scripts/build_raspbian.sh +++ /dev/null @@ -1,44 +0,0 @@ -#!/bin/bash -############################################################################## -# Example command to build the Raspbian target. -############################################################################## -# -# This script shows how one can build a Caffe2 binary for raspbian. The build -# is essentially much similar to a host build, with one additional change -# which is to specify -mfpu=neon for optimized speed. - -CAFFE2_ROOT="$( cd "$(dirname -- "$0")"/.. ; pwd -P)" -echo "Caffe2 codebase root is: $CAFFE2_ROOT" -BUILD_ROOT=${BUILD_ROOT:-"$CAFFE2_ROOT/build"} -mkdir -p $BUILD_ROOT -echo "Build Caffe2 raspbian into: $BUILD_ROOT" - -# obtain dependencies. -echo "Installing dependencies." -sudo apt-get install \ - cmake \ - libgflags-dev \ - libgoogle-glog-dev \ - libprotobuf-dev \ - libpython-dev \ - python-pip \ - python-numpy \ - protobuf-compiler \ - python-protobuf -# python dependencies -sudo pip install hypothesis - -# Now, actually build the raspbian target. -echo "Building caffe2" -cd $BUILD_ROOT - -# Note: you can add more dependencies above if you need libraries such as -# leveldb, lmdb, etc. -cmake "$CAFFE2_ROOT" \ - -DCMAKE_VERBOSE_MAKEFILE=1 \ - -DCAFFE2_CPU_FLAGS="-mfpu=neon -mfloat-abi=hard" \ - || exit 1 - -# Note: while Raspberry pi has 4 cores, running too many builds in parallel may -# cause out of memory errors so we will simply run -j 2 only. -make -j 2 || exit 1 diff --git a/scripts/build_tegra_x1.sh b/scripts/build_tegra_x1.sh deleted file mode 100755 index 063e17dfe3514..0000000000000 --- a/scripts/build_tegra_x1.sh +++ /dev/null @@ -1,51 +0,0 @@ -#!/bin/bash -############################################################################## -# Example command to build Caffe2 on Tegra X1. -############################################################################## -# -# This script shows how one can build a Caffe2 binary for NVidia's TX1. -# The build script assumes that you have the most recent libraries installed -# via the JetPack toolkit available at -# https://developer.nvidia.com/embedded/jetpack -# and it assumes that we are starting from a fresh system after the jetpack -# installation. If you have already installed some of the dependencies, you -# may be able to skip quite a few of the apt-get installs. - -CAFFE2_ROOT="$( cd "$(dirname -- "$0")"/.. ; pwd -P)" -echo "Caffe2 codebase root is: $CAFFE2_ROOT" -BUILD_ROOT=${BUILD_ROOT:-"$CAFFE2_ROOT/build"} -mkdir -p $BUILD_ROOT -echo "Build Caffe2 raspbian into: $BUILD_ROOT" - -# obtain necessary dependencies -echo "Installing dependencies." -sudo apt-get install \ - cmake \ - libgflags-dev \ - libgoogle-glog-dev \ - libprotobuf-dev \ - protobuf-compiler - -# obtain optional dependencies that are usually useful to have. -echo "Installing optional dependencies." -sudo apt-get install \ - libpython-dev \ - python-numpy \ - python-pip \ - python-protobuf - -# Obtain python hypothesis, which Caffe2 uses for unit testing. Note that -# the one provided by apt-get is quite old so we install it via pip -sudo pip install hypothesis - -# Now, actually build the android target. -echo "Building caffe2" -cd $BUILD_ROOT - -# CUDA_USE_STATIC_CUDA_RUNTIME needs to be set to off so that opencv can be -# properly used. Otherwise, opencv will complain that opencv_dep_cudart cannot -# be found. -cmake "$CAFFE2_ROOT" -DCUDA_USE_STATIC_CUDA_RUNTIME=OFF \ - || exit 1 - -make -j 4 || exit 1 diff --git a/scripts/build_tizen.sh b/scripts/build_tizen.sh deleted file mode 100755 index 2262a2503c1d0..0000000000000 --- a/scripts/build_tizen.sh +++ /dev/null @@ -1,118 +0,0 @@ -#!/usr/bin/env bash -############################################################################## -# Example command to build the Tizen target (RPi3). -############################################################################## -# -# This script shows how one can build a Caffe2 binary for a Tizen device (RPi3). -# The build is essentially much similar to a host build, with one additional change -# which is to specify -mfpu=neon for optimized speed. - -setup_environment(){ -# The rootfs image for a Tizen target (RPi3)is located at the below webpage: -# https://cdn.download.tizen.org/archive/releases/milestone/tizen/4.0.m1/tizen-unified_20170529.1/images/ -# If you do not have a Tizen device, Please, run qemu-arm-static and chroot command. -# $ sudo chroot ~/tizen-rootfs qemu-arm-static /usr/bin/bash - -CAFFE2_ROOT="$( cd "$(dirname -- "$0")"/.. ; pwd -P)" -echo "Caffe2 codebase root is: $CAFFE2_ROOT" -BUILD_ROOT=${BUILD_ROOT:-"$CAFFE2_ROOT/build"} -mkdir -p $BUILD_ROOT -echo "Build Caffe2 Tizen into: $BUILD_ROOT" -} - -caffe2_lite_dep_packages(){ -# Obtain necessary dependencies -# You can set-up a rpm repository with zypper, yum, and dnf because Tizen -# software platform officially support rpm format such as Fedora, OpenSUSE. -# The official Tizen repository is as following: -# https://cdn.download.tizen.org/archive/releases/milestone/tizen/4.0.m1/ -echo "Installing dependencies." -sudo zypper install \ - make \ - strace \ - cmake \ - gcc* \ - binutils \ - glibc* \ - cpp \ - protobuf-devel \ - libstdc++* -} - -caffe2_lite_build(){ -# Now, actually build the android target. -echo "Building caffe2" -cd $BUILD_ROOT - -# Note: add more dependencies above if you need libraries such as leveldb, lmdb, etc. -# If you have to disable a specific package due to a package absence -# from https://git.tizen.org/cgit/, append -Dxxx_xxx=OFF option before executing cmake. -cmake .. \ - -DCMAKE_VERBOSE_MAKEFILE=1 \ - -DUSE_GFLAGS=OFF \ - -DUSE_GLOG=OFF -DUSE_NNPACK=OFF \ - -DRUN_HAVE_STD_REGEX=0 \ - -DRUN_HAVE_POSIX_REGEX=0 \ - -DHAVE_GNU_POSIX_REGEX=0 \ - -DUSE_MPI=OFF -DUSE_OPENMP=OFF \ - -DBUILD_PYTHON=OFF \ - -DUSE_GLOO=OFF \ - -DUSE_OPENCV=OFF \ - -DCAFFE2_CPU_FLAGS="-mfpu=neon -mfloat-abi=soft" \ - || exit 1 - -make -j`nproc` || exit 1 -} - -caffe2_full_dep_packages(){ -# Obtain necessary dependencies -# You can set-up a rpm repository with zypper, yum, and dnf because Tizen -# software platform officially support rpm format such as Fedora, OpenSUSE. -# The official Tizen repository is as following: -# https://cdn.download.tizen.org/archive/releases/milestone/tizen/4.0.m1/ -echo "Installing dependencies." -sudo zypper install \ - cmake \ - libgflags-dev \ - libgoogle-glog-dev \ - libprotobuf-dev \ - protobuf-compiler - -# Obtain optional dependencies that are usually useful to have. -echo "Installing optional dependencies." -sudo zypper install \ - libpython-dev \ - python-numpy \ - python-pip \ - python-protobuf - -# Obtain python hypothesis, which Caffe2 uses for unit testing. Note that -# the one provided by zypper is quite old so we install it via pip -sudo pip install hypothesis -} - -caffe2_full_build(){ -# Now, actually build the android target. -echo "Building caffe2" -cd $BUILD_ROOT - -# Note: add more dependencies above if you need libraries such as leveldb, lmdb, etc. -# If you have to disable a specific package due to a package absence -# from https://git.tizen.org/cgit/, append -Dxxx_xxx=OFF option before executing cmake. -cmake "$CAFFE2_ROOT" \ - -DCMAKE_VERBOSE_MAKEFILE=1 \ - -DUSE_CUDA=OFF \ - -DUSE_ITT=OFF \ - -DUSE_OPENCV=OFF \ - -DCAFFE2_CPU_FLAGS="-mfpu=neon -mfloat-abi=soft" \ - || exit 1 - -make -j`nproc` || exit 1 -} - -#### Main -# Setup a build environment to compile Caffe2 deeplearning framework in Tizen platform. -setup_environment -# There are two build options to support 'full' version and 'lite' version (by default). -caffe2_lite_dep_packages -caffe2_lite_build diff --git a/scripts/build_windows.bat b/scripts/build_windows.bat deleted file mode 100644 index 60bfebad08c01..0000000000000 --- a/scripts/build_windows.bat +++ /dev/null @@ -1,80 +0,0 @@ -:: ############################################################################# -:: Example command to build on Windows. -:: ############################################################################# - -:: This script shows how one can build a Caffe2 binary for windows. - -@echo off -setlocal - -SET ORIGINAL_DIR=%cd% -SET CAFFE2_ROOT=%~dp0%.. - -if NOT DEFINED BUILD_BINARY ( - set BUILD_BINARY=OFF -) - -if NOT DEFINED BUILD_SHARED_LIBS ( - :: On CI, we test with BUILD_SHARED_LIBS=OFF. - :: By default, it will be BUILD_SHARED_LIBS=ON. - if NOT DEFINED BUILD_ENVIRONMENT ( - set BUILD_SHARED_LIBS=OFF - ) -) - -if NOT DEFINED CAFFE2_STATIC_LINK_CUDA ( - set CAFFE2_STATIC_LINK_CUDA=OFF -) - -if NOT DEFINED CMAKE_BUILD_TYPE ( - set CMAKE_BUILD_TYPE=Release -) - -if NOT DEFINED ONNX_NAMESPACE ( - set ONNX_NAMESPACE=onnx_c2 -) - -if NOT DEFINED TORCH_CUDA_ARCH_LIST ( - set TORCH_CUDA_ARCH_LIST=5.0 -) - -if NOT DEFINED USE_CUDA ( - set USE_CUDA=OFF -) - -if NOT DEFINED USE_OBSERVERS ( - set USE_OBSERVERS=OFF -) - -if NOT DEFINED MSVC_Z7_OVERRIDE ( - set MSVC_Z7_OVERRIDE=OFF -) - -if NOT DEFINED CMAKE_GENERATOR ( - set CMAKE_GENERATOR=Ninja -) - -set CMAKE_VERBOSE_MAKEFILE=1 - -:: Install pyyaml for Aten codegen -pip install pyyaml ninja - -echo CAFFE2_ROOT=%CAFFE2_ROOT% -echo CMAKE_GENERATOR=%CMAKE_GENERATOR% -echo CMAKE_BUILD_TYPE=%CMAKE_BUILD_TYPE% - -:: Set up cmake. We will skip building the test files right now. -pushd %CAFFE2_ROOT% -python tools\build_libtorch.py || goto :label_error -popd - -echo "Caffe2 built successfully" -cd %ORIGINAL_DIR% -endlocal -exit /b 0 - -:label_error -echo "Caffe2 building failed" -cd %ORIGINAL_DIR% -endlocal -exit /b 1 diff --git a/scripts/diagnose_protobuf.py b/scripts/diagnose_protobuf.py deleted file mode 100644 index 65af4618228db..0000000000000 --- a/scripts/diagnose_protobuf.py +++ /dev/null @@ -1,92 +0,0 @@ -## @package diagnose_protobuf -# Module scripts.diagnose_protobuf -"""Diagnoses the current protobuf situation. - -Protocol buffer needs to be properly installed for Caffe2 to work, and -sometimes it is rather tricky. Specifically, we will need to have a -consistent version between C++ and python simultaneously. This is a -convenience script for one to quickly check if this is so on one's local -machine. - -Usage: - [set your environmental variables like PATH and PYTHONPATH] - python scripts/diagnose_protobuf.py -""" - -import os -import re -from subprocess import PIPE, Popen - - -# Get python protobuf version. -try: - import google.protobuf - - python_version = google.protobuf.__version__ - python_protobuf_installed = True -except ImportError: - print("DEBUG: cannot find python protobuf install.") - python_protobuf_installed = False - -if os.name == "nt": - protoc_name = "protoc.exe" -else: - protoc_name = "protoc" - -try: - p = Popen([protoc_name, "--version"], stdout=PIPE, stderr=PIPE) - out, err = p.communicate() -except: - print("DEBUG: did not find protoc binary.") - print("DEBUG: out: " + out) - print("DEBUG: err: " + err) - native_protobuf_installed = False -else: - if p.returncode: - print("DEBUG: protoc returned a non-zero return code.") - print("DEBUG: out: " + out) - print("DEBUG: err: " + err) - native_protobuf_installed = False - else: - tmp = re.search(r"\d\.\d\.\d", out) - if tmp: - native_version = tmp.group(0) - native_protobuf_installed = True - else: - print("DEBUG: cannot parse protoc version string.") - print("DEBUG: out: " + out) - native_protobuf_installed = False - -PYTHON_PROTOBUF_NOT_INSTALLED = """ -You have not installed python protobuf. Protobuf is needed to run caffe2. You -can install protobuf via pip or conda (if you are using anaconda python). -""" - -NATIVE_PROTOBUF_NOT_INSTALLED = """ -You have not installed the protoc binary. Protoc is needed to compile Caffe2 -protobuf source files. Depending on the platform you are on, you can install -protobuf via: - (1) Mac: using homebrew and do brew install protobuf. - (2) Linux: use apt and do apt-get install libprotobuf-dev - (3) Windows: install from source, or from the releases here: - https://github.com/google/protobuf/releases/ -""" - -VERSION_MISMATCH = f""" -Your python protobuf is of version {python_version} but your native protoc version is of -version {native_version}. This will cause the installation to produce incompatible -protobuf files. This is bad in general - consider installing the same version. -""" - -# Now, give actual recommendations -if not python_protobuf_installed: - print(PYTHON_PROTOBUF_NOT_INSTALLED) - -if not native_protobuf_installed: - print(NATIVE_PROTOBUF_NOT_INSTALLED) - -if python_protobuf_installed and native_protobuf_installed: - if python_version != native_version: - print(VERSION_MISMATCH) - else: - print("All looks good.") diff --git a/scripts/fbcode-dev-setup/ccache_setup.sh b/scripts/fbcode-dev-setup/ccache_setup.sh deleted file mode 100755 index cb461bee2dd27..0000000000000 --- a/scripts/fbcode-dev-setup/ccache_setup.sh +++ /dev/null @@ -1,92 +0,0 @@ -#!/bin/bash - -# This script installs CCache with CUDA support. -# Example usage: -# ./ccache_setup.sh --path /installed/folder - -set -e -shopt -s expand_aliases - -# Setup the proxy -alias with_proxy="HTTPS_PROXY=http://fwdproxy:8080 HTTP_PROXY=http://fwdproxy:8080 FTP_PROXY=http://fwdproxy:8080 https_proxy=http://fwdproxy:8080 http_proxy=http://fwdproxy:8080 ftp_proxy=http://fwdproxy:8080 http_no_proxy='*.facebook.com|*.tfbnw.net|*.fb.com'" - -# Parse options -path="$HOME/ccache" -force=false - -while [[ $# -gt 0 ]]; do - case "$1" in - --path) - shift - path="$1" - path=$(realpath "$path") - ;; - --force) # Force install - force=true - ;; - --help) - echo 'usage: ./ccache_setup.py --path /installed/folder [--force]' - exit 0 - ;; - *) - echo "Invalid option: $1" - exit 1 - ;; - esac - shift -done - -# Check whether you put nvcc in PATH -set +e -nvcc_path=$(which nvcc) -if [[ -z "$nvcc_path" ]]; then - nvcc_path="/usr/local/cuda/bin/nvcc" - export PATH="/usr/local/cuda/bin:$PATH" -fi -set -e -if [ ! -f "$nvcc_path" ] && ! $force; then - # shellcheck disable=SC2016 - echo 'nvcc is not detected in $PATH' - exit 1 -fi -echo "nvcc is detected at $nvcc_path" - -if [ -f "$CUDA_NVCC_EXECUTABLE" ] && [[ "$CUDA_NVCC_EXECUTABLE" == *"ccache"* ]]; then # Heuristic rule - if $CUDA_NVCC_EXECUTABLE --version; then - if ! $force; then - echo "CCache with nvcc support is already installed at $CUDA_NVCC_EXECUTABLE, please add --force" - exit 0 - fi - fi -fi - -# Installing CCache -echo "CCache will be installed at $path" -if [ -e "$path" ]; then - mv --backup=t -T "$path" "${path}.old" -fi - -with_proxy git clone https://github.com/colesbury/ccache.git "$path" -b ccbin -cd "$path" -./autogen.sh -./configure -make install prefix="$path" - -mkdir -p "$path/lib" -mkdir -p "$path/cuda" -ln -sf "$path/bin/ccache" "$path/lib/cc" -ln -sf "$path/bin/ccache" "$path/lib/c++" -ln -sf "$path/bin/ccache" "$path/lib/gcc" -ln -sf "$path/bin/ccache" "$path/lib/g++" -ln -sf "$path/bin/ccache" "$path/cuda/nvcc" -"$path/bin/ccache" -M 25Gi - -# Make sure the nvcc wrapped in CCache is runnable -"$path/cuda/nvcc" --version -echo 'Congrats! The CCache with nvcc support is installed!' -echo -e "Please add the following lines to your bash init script:\\n" -echo "################ Env Var for CCache with CUDA support ################" -# shellcheck disable=SC2016 -echo 'export PATH="'"$path"'/lib:$PATH"' -echo 'export CUDA_NVCC_EXECUTABLE="'"$path"'/cuda/nvcc"' -echo '######################################################################' diff --git a/scripts/get_python_cmake_flags.py b/scripts/get_python_cmake_flags.py deleted file mode 100644 index a49debcc884ad..0000000000000 --- a/scripts/get_python_cmake_flags.py +++ /dev/null @@ -1,24 +0,0 @@ -## @package get_python_cmake_flags -# Module scripts.get_python_cmake_flags -############################################################################## -# Use this script to find your preferred python installation. -############################################################################## -# -# You can use the following to build with your preferred version of python -# if your installation is not being properly detected by CMake. -# -# mkdir -p build && cd build -# cmake $(python ../scripts/get_python_cmake_flags.py) .. -# make -# - - -import sys -import sysconfig - - -flags = [ - f"-DPython_EXECUTABLE:FILEPATH={sys.executable}", -] - -print(" ".join(flags), end="") diff --git a/scripts/proto.ps1 b/scripts/proto.ps1 deleted file mode 100644 index a6bce82ff682d..0000000000000 --- a/scripts/proto.ps1 +++ /dev/null @@ -1,18 +0,0 @@ -param( - [string]$protoc, - [string]$srcdir, - [string]$unprocessed, - [string]$processed, - [string]$out -) -$ErrorActionPreference = "Stop" -Get-Content $unprocessed | % {$_ -Replace "caffe2/proto/caffe2.proto", "caffe2.proto"} | Set-Content $processed -Add-Content -Path $processed -Value "option optimize_for = LITE_RUNTIME;`n" -NoNewline -$dir = (Get-Item $processed).DirectoryName - -copy $srcdir/caffe2/proto/caffe2.proto $srcdir/caffe2.proto -Add-Content -Path $srcdir/caffe2.proto -Value "option optimize_for = LITE_RUNTIME;`n" -NoNewline - -$processed = (Get-Item $processed).Name -$cmd = "$protoc -I${dir} --cpp_out=$out $processed" -Invoke-Expression $cmd diff --git a/scripts/remove_apache_header.sh b/scripts/remove_apache_header.sh deleted file mode 100755 index 97980bfbb0ef6..0000000000000 --- a/scripts/remove_apache_header.sh +++ /dev/null @@ -1,13 +0,0 @@ -if [[ "$1" == *.py ]]; then - apache_header="apache_python.txt" -else - apache_header="apache_header.txt" -fi -apache_lines=$(wc -l < "${apache_header}") -apache_md5=$(cat "${apache_header}" | md5) -header_md5=$(head -n ${apache_lines} $1 | md5) -if [ "${header_md5}" == "${apache_md5}" ]; then - keep_lines=$(($(wc -l < $1) - ${apache_lines})) - tail -n ${keep_lines} $1 > _remove_apache_header.txt - mv _remove_apache_header.txt $1 -fi diff --git a/scripts/temp.sh b/scripts/temp.sh deleted file mode 100755 index 18eb2b4733816..0000000000000 --- a/scripts/temp.sh +++ /dev/null @@ -1,7 +0,0 @@ -find ../caffe2 -name "*.py" -exec ./remove_apache_header.sh {} \; -find ../caffe2 -name "*.h" -exec ./remove_apache_header.sh {} \; -find ../caffe2 -name "*.cc" -exec ./remove_apache_header.sh {} \; -find ../caffe2 -name "*.cpp" -exec ./remove_apache_header.sh {} \; -find ../caffe2 -name "*.cu" -exec ./remove_apache_header.sh {} \; -find ../caffe2 -name "*.mm" -exec ./remove_apache_header.sh {} \; -find ../caffe2 -name "*.m" -exec ./remove_apache_header.sh {} \; diff --git a/scripts/xcode_build.rb b/scripts/xcode_build.rb deleted file mode 100644 index 0734167bdda11..0000000000000 --- a/scripts/xcode_build.rb +++ /dev/null @@ -1,76 +0,0 @@ -require 'optparse' -require 'xcodeproj' - -options = {} -option_parser = OptionParser.new do |opts| - opts.banner = 'Tools for building PyTorch iOS framework on MacOS' - opts.on('-i', '--install_path ', 'path to the cmake install folder') { |value| - options[:install] = value - } - opts.on('-x', '--xcodeproj_path ', 'path to the XCode project file') { |value| - options[:xcodeproj] = value - } - opts.on('-p', '--platform ', 'platform for the current build, OS or SIMULATOR') { |value| - options[:platform] = value - } -end.parse! -puts options.inspect - -install_path = File.expand_path(options[:install]) -if not Dir.exist? (install_path) - raise "path don't exist:#{install_path}!" -end -xcodeproj_path = File.expand_path(options[:xcodeproj]) -if not File.exist? (xcodeproj_path) - raise "path don't exist:#{xcodeproj_path}!" -end - -project = Xcodeproj::Project.open(xcodeproj_path) -target = project.targets.first #TestApp -header_search_path = ['$(inherited)', "#{install_path}/include"] -libraries_search_path = ['$(inherited)', "#{install_path}/lib"] -other_linker_flags = ['$(inherited)', "-all_load"] - -target.build_configurations.each do |config| - config.build_settings['HEADER_SEARCH_PATHS'] = header_search_path - config.build_settings['LIBRARY_SEARCH_PATHS'] = libraries_search_path - config.build_settings['OTHER_LDFLAGS'] = other_linker_flags - config.build_settings['ENABLE_BITCODE'] = 'No' -end - -# link static libraries -target.frameworks_build_phases.clear -libs = ['libc10.a', 'libclog.a', 'libpthreadpool.a', 'libXNNPACK.a', 'libmicrokernels-prod.a', 'libeigen_blas.a', 'libcpuinfo.a', 'libpytorch_qnnpack.a', 'libtorch_cpu.a', 'libtorch.a', 'libkineto.a'] -for lib in libs do - path = "#{install_path}/lib/#{lib}" - if File.exist?(path) - libref = project.frameworks_group.new_file(path) - target.frameworks_build_phases.add_file_reference(libref) - end -end -# link system frameworks -frameworks = ['CoreML', 'Metal', 'MetalPerformanceShaders', 'Accelerate', 'UIKit'] -if frameworks - frameworks.each do |framework| - path = "System/Library/Frameworks/#{framework}.framework" - framework_ref = project.frameworks_group.new_reference(path) - framework_ref.name = "#{framework}.framework" - framework_ref.source_tree = 'SDKROOT' - target.frameworks_build_phases.add_file_reference(framework_ref) - end -end -project.save - -sdk = nil -arch = nil -if options[:platform] == 'SIMULATOR' - sdk = 'iphonesimulator' - arch = 'arm64' -elsif options[:platform] == 'OS' - sdk = 'iphoneos' - arch = 'arm64' -else - raise "unsupported platform #{options[:platform]}" -end - -exec "xcodebuild clean build -project #{xcodeproj_path} -alltargets -sdk #{sdk} -configuration Release -arch #{arch}" From bfe5674e2294a6c73ff671116a91f6ae7220b3f8 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 17 Jul 2025 16:55:55 +0000 Subject: [PATCH 193/457] Revert "[cuDNN][SDPA] cuDNN SDPA refactor/cleanup, nested tensor backward, test priority bump for `sm90`, `sm100` (#149282)" This reverts commit 0797b2b6a80cf70a7accc3d5413186e7693d4451. Reverted https://github.com/pytorch/pytorch/pull/149282 on behalf of https://github.com/wdvr due to reverting as discussed with @drisspg - @eqy please reach out to @drisspg for more info ([comment](https://github.com/pytorch/pytorch/pull/149282#issuecomment-3084759671)) --- aten/src/ATen/native/cudnn/MHA.cpp | 1055 ++++++----------- aten/src/ATen/native/cudnn/MHA.h | 27 - aten/src/ATen/native/native_functions.yaml | 6 - .../cuda/NestedTensorTransformerFunctions.cpp | 57 - .../native/transformers/cuda/attention.cu | 10 + .../transformers/cuda/attention_backward.cu | 192 +-- .../native/transformers/cuda/sdp_utils.cpp | 72 +- ...asDecompTest.test_has_decomposition.expect | 1 - test/inductor/test_cuda_repro.py | 8 +- test/test_nestedtensor.py | 9 +- test/test_transformers.py | 21 +- tools/autograd/derivatives.yaml | 4 - 12 files changed, 443 insertions(+), 1019 deletions(-) diff --git a/aten/src/ATen/native/cudnn/MHA.cpp b/aten/src/ATen/native/cudnn/MHA.cpp index 4f9e612e8e752..48119a6a3b4c3 100644 --- a/aten/src/ATen/native/cudnn/MHA.cpp +++ b/aten/src/ATen/native/cudnn/MHA.cpp @@ -84,37 +84,6 @@ void run_cudnn_SDP_bprop( false, "PyTorch was not compiled with cuDNN Flash Attention enabled!"); } -void run_cudnn_SDP_bprop_nestedtensor( - int64_t b, - int64_t h_q, - int64_t h_k, - int64_t h_v, - int64_t s_q, - int64_t s_kv, - int64_t d_qk, - int64_t d_v, - - float scaling_factor, - bool is_causal, - float dropout_probability, - const Tensor& cum_seqlen_q, - const Tensor& cum_seqlen_kv, - const Tensor& q, - const Tensor& k, - const Tensor& v, - const std::optional& attn_bias, - const Tensor& o, - const Tensor& dO, - const Tensor& softmaxstats, - Tensor& dQ, - Tensor& dK, - Tensor& dV, - const Tensor& dropoutseed, - const Tensor& dropoutoffset) { - TORCH_CHECK( - false, "PyTorch was not compiled with cuDNN Flash Attention enabled!"); -} - } // namespace native } // namespace at @@ -142,6 +111,40 @@ namespace native { #include namespace fe = cudnn_frontend; +using graph_and_tensors = std::tuple< + std::shared_ptr, + std::shared_ptr, // Q, + std::shared_ptr, // K, + std::shared_ptr, // V, + std::optional>, // Bias + std::shared_ptr, // Attn_scale, + // TODO(eqy): additional options + // std::shared_ptr, // SEQ_LEN_Q, + // std::shared_ptr, // SEQ_LEN_KV, + std::shared_ptr, // Seed, + std::shared_ptr, // Offset, + // std::shared_ptr, // Dropout_mask, + // std::shared_ptr, // Dropout_scale + std::shared_ptr, // O + std::shared_ptr // Stats + >; + +using graph_and_tensors_backward = std::tuple< + std::shared_ptr, + std::shared_ptr, // Q, + std::shared_ptr, // K, + std::shared_ptr, // V, + std::optional>, // Bias, + std::shared_ptr, // Attn_scale, + std::shared_ptr, // Seed, + std::shared_ptr, // Offset, + std::shared_ptr, // O, + std::shared_ptr, // dO, + std::shared_ptr, // stats, + std::shared_ptr, // dQ, + std::shared_ptr, // dK,, + std::shared_ptr // dV, + >; #define MAX_MHA_DIM 4 @@ -295,45 +298,11 @@ struct MHAGraphCache { // @eqy: use thread local caches as cuDNN Execution Plans are not guaranteed to // be thread safe across all engines see Limitations in // https://docs.nvidia.com/deeplearning/cudnn/backend/latest/release-notes.html -// We also leak the caches to workaround potential teardown race issues. - -auto& getMHAGraphCache_() { - thread_local auto& instance = - *new MHAGraphCache, MHACacheKeyWrapper>; - return instance; -} - -auto& getMHAGraphBackwardCache_() { - thread_local auto& instance = - *new MHAGraphCache, MHACacheKeyWrapper>; - return instance; -} +thread_local MHAGraphCache mhagraphcache; +thread_local MHAGraphCache + mhagraphbackwardcache; namespace { - -enum UIDS { - Q, - K, - V, - O, - BIAS, - SCALE, - SEED, - OFFSET, - LSE, - DO, - DQ, - DK, - DV, - SEQ_LEN_Q, - SEQ_LEN_KV, - RAG_Q_OFF, - RAG_K_OFF, - RAG_V_OFF, - RAG_O_OFF, - RAG_LSE_OFF -}; - // analogous to the same function in Descriptors.h for cuDNN Convolutions... auto fixSizeOneDimStrideSDPA( const IntArrayRef sizes, @@ -351,10 +320,9 @@ auto fixSizeOneDimStrideSDPA( } return strides; } - } // namespace -auto build_graph( +auto build_graph_and_tensors( int64_t b, int64_t h, int64_t s_q, @@ -387,55 +355,46 @@ auto build_graph( .set_compute_data_type(fe::DataType_t::FLOAT); auto attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(SCALE) .set_name("Attn_scale") .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_is_pass_by_value(true) .set_data_type(fe::DataType_t::FLOAT)); - auto scaled_dot_product_flash_attention_options = - fe::graph::SDPA_attributes() - .set_name("CUDNN_SDPA") - .set_is_inference(return_softmaxstats == false) - .set_causal_mask(is_causal) - .set_attn_scale(attn_scale); - if (dropout_probability != 0.0f) { - auto seed = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(SEED) - .set_name("Seed") + auto seed = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Seed") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type( + dropoutseed.dtype() == kInt + ? fe::DataType_t::INT32 + : fe::DataType_t::INT64)); + auto offset = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Offset") .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type( - dropoutseed.dtype() == kInt + dropoutoffset.dtype() == kInt ? fe::DataType_t::INT32 : fe::DataType_t::INT64)); - auto offset = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(OFFSET) - .set_name("Offset") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type( - dropoutoffset.dtype() == kInt - ? fe::DataType_t::INT32 - : fe::DataType_t::INT64)); - scaled_dot_product_flash_attention_options.set_dropout( - dropout_probability, seed, offset); - } - auto Q_ = mha_graph->tensor( + auto scaled_dot_product_flash_attention_options = + fe::graph::SDPA_attributes() + .set_name("CUDNN_SDPA") + .set_is_inference(return_softmaxstats == false) + .set_causal_mask(is_causal) + .set_attn_scale(attn_scale) + .set_dropout(dropout_probability, seed, offset); + auto Q = mha_graph->tensor( fe::graph::Tensor_attributes() - .set_uid(Q) .set_name("Q") .set_dim(q.sizes().vec()) .set_stride(fixSizeOneDimStrideSDPA(q.sizes(), q.strides().vec()))); - auto K_ = mha_graph->tensor( + auto K = mha_graph->tensor( fe::graph::Tensor_attributes() - .set_uid(K) .set_name("K") .set_dim(k.sizes().vec()) .set_stride(fixSizeOneDimStrideSDPA(k.sizes(), k.strides().vec()))); - auto V_ = mha_graph->tensor( + auto V = mha_graph->tensor( fe::graph::Tensor_attributes() - .set_uid(V) .set_name("V") .set_dim(v.sizes().vec()) .set_stride(fixSizeOneDimStrideSDPA(v.sizes(), v.strides().vec()))); @@ -443,20 +402,17 @@ auto build_graph( if (attn_bias.has_value()) { bias = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(BIAS) .set_name("bias") .set_dim(attn_bias.value().sizes().vec()) .set_stride(attn_bias.value().strides().vec())); scaled_dot_product_flash_attention_options.set_bias(bias.value()); } - auto [O_, Stats] = - mha_graph->sdpa(Q_, K_, V_, scaled_dot_product_flash_attention_options); - O_->set_uid(O); - O_->set_output(true).set_dim(o.sizes().vec()).set_stride(o.strides().vec()); + auto [O, Stats] = + mha_graph->sdpa(Q, K, V, scaled_dot_product_flash_attention_options); + O->set_output(true).set_dim(o.sizes().vec()).set_stride(o.strides().vec()); if (Stats) { - Stats->set_uid(LSE); Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT); } @@ -467,10 +423,20 @@ auto build_graph( AT_CUDNN_FRONTEND_CHECK(mha_graph->check_support(handle)); AT_CUDNN_FRONTEND_CHECK(mha_graph->build_plans(handle)); - return mha_graph; + return std::make_tuple( + std::move(mha_graph), + std::move(Q), + std::move(K), + std::move(V), + std::move(bias), + std::move(attn_scale), + std::move(seed), + std::move(offset), + std::move(O), + std::move(Stats)); } -auto build_graph_nestedtensor( +auto build_graph_and_tensors_nestedtensor( int64_t b, int64_t h_q, int64_t h_k, @@ -507,22 +473,28 @@ auto build_graph_nestedtensor( .set_compute_data_type(fe::DataType_t::FLOAT); auto attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(SCALE) .set_name("Attn_scale") .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_is_pass_by_value(true) .set_data_type(fe::DataType_t::FLOAT)); - auto SEQ_LEN_Q_ = - mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(SEQ_LEN_Q) - .set_name("Seq_q") - .set_dim({b, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - auto SEQ_LEN_KV_ = + auto seed = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Seed") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + auto offset = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Offset") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + auto SEQ_LEN_Q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Seq_q") + .set_dim({b, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + auto SEQ_LEN_KV = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(SEQ_LEN_KV) .set_name("Seq_kv") .set_dim({b, 1, 1, 1}) .set_stride({1, 1, 1, 1}) @@ -534,66 +506,41 @@ auto build_graph_nestedtensor( .set_is_inference(return_softmaxstats == false) .set_causal_mask(is_causal) .set_attn_scale(attn_scale) - .set_seq_len_q(SEQ_LEN_Q_) - .set_seq_len_kv(SEQ_LEN_KV_) + .set_dropout(dropout_probability, seed, offset) + .set_seq_len_q(SEQ_LEN_Q) + .set_seq_len_kv(SEQ_LEN_KV) .set_padding_mask(true); - if (dropout_probability != 0.0f) { - auto seed = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(SEED) - .set_name("Seed") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type( - dropoutseed.dtype() == kInt - ? fe::DataType_t::INT32 - : fe::DataType_t::INT64)); - auto offset = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(OFFSET) - .set_name("Offset") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type( - dropoutoffset.dtype() == kInt - ? fe::DataType_t::INT32 - : fe::DataType_t::INT64)); - scaled_dot_product_flash_attention_options.set_dropout( - dropout_probability, seed, offset); - } // We hardcode BSHD to cuDNN even though the underlying layout is THD auto q_strides = q.strides(); auto k_strides = k.strides(); auto v_strides = v.strides(); - // NB: cuDNN API shape is transposed constexpr int strideidx0 = 1; constexpr int strideidx1 = 0; constexpr int strideidx2 = 2; - auto Q_ = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(Q) - .set_name("Q") - .set_dim({b, h_q, s_q, d_qk}) - .set_stride( - {INT_MAX, - q_strides[strideidx0], - q_strides[strideidx1], - q_strides[strideidx2]})); - auto K_ = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(K) - .set_name("K") - .set_dim({b, h_k, s_kv, d_qk}) - .set_stride( - {INT_MAX, - k_strides[strideidx0], - k_strides[strideidx1], - k_strides[strideidx2]})); - auto V_ = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(V) - .set_name("V") - .set_dim({b, h_v, s_kv, d_v}) - .set_stride( - {INT_MAX, - v_strides[strideidx0], - v_strides[strideidx1], - v_strides[strideidx2]})); + auto Q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Q") + .set_dim({b, h_q, s_q, d_qk}) + .set_stride( + {INT_MAX, + q_strides[strideidx0], + q_strides[strideidx1], + q_strides[strideidx2]})); + auto K = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("K") + .set_dim({b, h_k, s_kv, d_qk}) + .set_stride( + {INT_MAX, + k_strides[strideidx0], + k_strides[strideidx1], + k_strides[strideidx2]})); + auto V = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("V") + .set_dim({b, h_v, s_kv, d_v}) + .set_stride( + {INT_MAX, + v_strides[strideidx0], + v_strides[strideidx1], + v_strides[strideidx2]})); std::optional> bias; if (attn_bias.has_value()) { TORCH_CHECK( @@ -601,48 +548,44 @@ auto build_graph_nestedtensor( "attn_bias not yet supportd with cuDNN Attention and NestedTensor"); bias = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(BIAS) .set_name("bias") .set_dim(attn_bias.value().sizes().vec()) .set_stride(attn_bias.value().strides().vec())); scaled_dot_product_flash_attention_options.set_bias(bias.value()); } - auto RAG_Q_OFF_ = - mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(RAG_Q_OFF) - .set_name("cum_seq_q") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - auto RAG_K_OFF_ = - mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(RAG_K_OFF) - .set_name("cum_seq_k") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - auto RAG_V_OFF_ = - mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(RAG_V_OFF) - .set_name("cum_seq_v") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - auto RAG_O_OFF_ = - mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(RAG_O_OFF) - .set_name("cum_seq_o") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - Q_->set_ragged_offset(RAG_Q_OFF_); - K_->set_ragged_offset(RAG_K_OFF_); - V_->set_ragged_offset(RAG_V_OFF_); - auto [O_, Stats] = - mha_graph->sdpa(Q_, K_, V_, scaled_dot_product_flash_attention_options); + auto RAG_Q_OFF = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("cum_seq_q") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + auto RAG_K_OFF = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("cum_seq_k") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + auto RAG_V_OFF = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("cum_seq_v") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + auto RAG_O_OFF = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("cum_seq_o") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + // auto RAG_STATS_OFF = mha_graph->tensor(fe::graph::Tensor_attributes() + // .set_name("cum_seq_stats") + // .set_dim({b + 1, 1, 1, 1}) + // .set_stride({1, 1, 1, 1}) + // .set_data_type(fe::DataType_t::INT32)); + auto RAG_STATS_OFF = nullptr; + Q->set_ragged_offset(RAG_Q_OFF); + K->set_ragged_offset(RAG_K_OFF); + V->set_ragged_offset(RAG_V_OFF); + auto [O, Stats] = + mha_graph->sdpa(Q, K, V, scaled_dot_product_flash_attention_options); auto o_strides = o.strides(); - O_->set_output(true) - .set_uid(O) + O->set_output(true) .set_dim({b, h_q, s_q, d_v}) .set_stride( {INT_MAX, @@ -650,20 +593,16 @@ auto build_graph_nestedtensor( o_strides[strideidx1], o_strides[strideidx2]}); - O_->set_ragged_offset(RAG_O_OFF_); + O->set_ragged_offset(RAG_O_OFF); if (Stats) { - auto RAG_STATS_OFF = - mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(RAG_LSE_OFF) - .set_name("cum_seq_stats") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); + TORCH_CHECK( + false, + "cuDNN SDPA Nested Tensor does not yet handle backwards/logsumexp computation"); + // TODO(eqy): fix when stats (backward) support is added Stats->set_output(true) - .set_uid(LSE) .set_data_type(fe::DataType_t::FLOAT) .set_dim({b, h_q, s_q, 1}) - .set_stride({h_q * s_q, 1, h_q, 1}); + .set_stride({h_q * s_q * d_v, d_v, s_q * d_v, 1}); Stats->set_ragged_offset(RAG_STATS_OFF); } AT_CUDNN_FRONTEND_CHECK(mha_graph->validate()); @@ -672,10 +611,27 @@ auto build_graph_nestedtensor( mha_graph->create_execution_plans({fe::HeurMode_t::A})); AT_CUDNN_FRONTEND_CHECK(mha_graph->check_support(handle)); AT_CUDNN_FRONTEND_CHECK(mha_graph->build_plans(handle)); - return mha_graph; + return std::make_tuple( + std::move(mha_graph), + std::move(Q), + std::move(K), + std::move(V), + std::move(bias), + std::move(attn_scale), + std::move(seed), + std::move(offset), + std::move(O), + std::move(Stats), + std::move(RAG_Q_OFF), + std::move(RAG_K_OFF), + std::move(RAG_V_OFF), + std::move(RAG_O_OFF), + std::move(RAG_STATS_OFF), + std::move(SEQ_LEN_Q), + std::move(SEQ_LEN_KV)); } -auto build_graph_backward( +auto build_graph_and_tensors_backward( int64_t b, int64_t h, int64_t s_q, @@ -711,7 +667,6 @@ auto build_graph_backward( .set_compute_data_type(fe::DataType_t::FLOAT); auto attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(SCALE) .set_name("Attn_scale") .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) @@ -721,327 +676,87 @@ auto build_graph_backward( .set_name("CUDNN_SDPA_BACKWARD") .set_causal_mask(is_causal) .set_attn_scale(attn_scale); - auto Q_ = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(Q) - .set_name("Q") - .set_dim(q.sizes().vec()) - .set_stride(q.strides().vec())); - auto K_ = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(K) - .set_name("K") - .set_dim(k.sizes().vec()) - .set_stride(k.strides().vec())); - auto V_ = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(V) - .set_name("V") - .set_dim(v.sizes().vec()) - .set_stride(v.strides().vec())); + auto Q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Q") + .set_dim(q.sizes().vec()) + .set_stride(q.strides().vec())); + auto K = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("K") + .set_dim(k.sizes().vec()) + .set_stride(k.strides().vec())); + auto V = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("V") + .set_dim(v.sizes().vec()) + .set_stride(v.strides().vec())); std::optional> bias; if (attn_bias.has_value()) { bias = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(BIAS) .set_name("bias") .set_dim(attn_bias.value().sizes().vec()) .set_stride(attn_bias.value().strides().vec())); sdpa_backward_options.set_bias(bias.value()); } - if (dropout_probability != 0.0f) { - auto seed = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(SEED) - .set_name("Seed") + auto Seed = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Seed") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type( + dropoutseed.dtype() == kInt + ? fe::DataType_t::INT32 + : fe::DataType_t::INT64)); + + auto Offset = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Offset") .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type( - dropoutseed.dtype() == kInt + dropoutoffset.dtype() == kInt ? fe::DataType_t::INT32 : fe::DataType_t::INT64)); - auto offset = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(OFFSET) - .set_name("Offset") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type( - dropoutoffset.dtype() == kInt - ? fe::DataType_t::INT32 - : fe::DataType_t::INT64)); - sdpa_backward_options.set_dropout(dropout_probability, seed, offset); - } - auto O_ = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(O) - .set_name("O") - .set_dim(o.sizes().vec()) - .set_stride(o.strides().vec())); - auto Stats = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(LSE) + auto O = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("O") + .set_dim(o.sizes().vec()) + .set_stride(o.strides().vec())); + auto STATS = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Stats") .set_dim(softmaxstats.sizes().vec()) .set_stride(softmaxstats.strides().vec()) .set_data_type(fe::DataType_t::FLOAT)); - auto Do = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(DO) + auto DO = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("DO") .set_dim(dO.sizes().vec()) .set_stride(dO.strides().vec())); - auto [Dq, Dk, Dv] = mha_graph->sdpa_backward( - Q_, K_, V_, O_, Do, Stats, sdpa_backward_options); - Dq->set_uid(DQ); - Dq->set_output(true).set_dim(dQ.sizes().vec()).set_stride(dQ.strides().vec()); - Dk->set_uid(DK); - Dk->set_output(true).set_dim(dK.sizes().vec()).set_stride(dK.strides().vec()); - Dv->set_uid(DV); - Dv->set_output(true).set_dim(dV.sizes().vec()).set_stride(dV.strides().vec()); - AT_CUDNN_FRONTEND_CHECK(mha_graph->validate()); - AT_CUDNN_FRONTEND_CHECK(mha_graph->build_operation_graph(handle)); - AT_CUDNN_FRONTEND_CHECK( - mha_graph->create_execution_plans({fe::HeurMode_t::A})); - AT_CUDNN_FRONTEND_CHECK(mha_graph->check_support(handle)); - AT_CUDNN_FRONTEND_CHECK(mha_graph->build_plans(handle)); - return mha_graph; -} - -auto build_graph_backward_nestedtensor( - int64_t b, - int64_t h_q, - int64_t h_k, - int64_t h_v, - int64_t s_q, - int64_t s_kv, - int64_t d_qk, - int64_t d_v, - float scaling_factor, - bool is_causal, - float dropout_probability, - const Tensor& cum_seqlen_q, - const Tensor& cum_seqlen_kv, - const Tensor& q, - const Tensor& k, - const Tensor& v, - const std::optional& attn_bias, - const Tensor& o, - const Tensor& dO, - const Tensor& softmaxstats, - Tensor& dQ, - Tensor& dK, - Tensor& dV, - const Tensor& dropoutseed, - const Tensor& dropoutoffset, - cudnnHandle_t& handle) { - auto dtype = fe::DataType_t::HALF; - if (q.scalar_type() == kBFloat16) { - dtype = fe::DataType_t::BFLOAT16; - } - auto mha_graph = std::make_shared(); - // We're baking in float accumulation and scale types - // in theory the graph may support other types, but they - // have not been tested - mha_graph->set_io_data_type(dtype) - .set_intermediate_data_type(fe::DataType_t::FLOAT) - .set_compute_data_type(fe::DataType_t::FLOAT); - auto attn_scale = - mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(SCALE) - .set_name("Attn_scale") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_is_pass_by_value(true) - .set_data_type(fe::DataType_t::FLOAT)); - - auto SEQ_LEN_Q_ = - mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(SEQ_LEN_Q) - .set_name("Seq_q") - .set_dim({b, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - auto SEQ_LEN_KV_ = - mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(SEQ_LEN_KV) - .set_name("Seq_kv") - .set_dim({b, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - auto sdpa_backward_options = fe::graph::SDPA_backward_attributes() - .set_name("CUDNN_SDPA_NESTEDTENSOR_BACKWARD") - .set_causal_mask(is_causal) - .set_attn_scale(attn_scale) - .set_seq_len_q(SEQ_LEN_Q_) - .set_seq_len_kv(SEQ_LEN_KV_) - .set_padding_mask(true); if (dropout_probability != 0.0f) { - auto seed = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(SEED) - .set_name("Seed") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type( - dropoutseed.dtype() == kInt - ? fe::DataType_t::INT32 - : fe::DataType_t::INT64)); - auto offset = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(OFFSET) - .set_name("Offset") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type( - dropoutoffset.dtype() == kInt - ? fe::DataType_t::INT32 - : fe::DataType_t::INT64)); - sdpa_backward_options.set_dropout(dropout_probability, seed, offset); + sdpa_backward_options.set_dropout(dropout_probability, Seed, Offset); } - auto q_strides = q.strides(); - auto k_strides = k.strides(); - auto v_strides = v.strides(); - // NB: cuDNN API shape is transposed - constexpr int strideidx0 = 1; - constexpr int strideidx1 = 0; - constexpr int strideidx2 = 2; - auto Q_ = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(Q) - .set_name("Q") - .set_dim({b, h_q, s_q, d_qk}) - .set_stride( - {INT_MAX, - q_strides[strideidx0], - q_strides[strideidx1], - q_strides[strideidx2]})); - auto K_ = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(K) - .set_name("K") - .set_dim({b, h_k, s_kv, d_qk}) - .set_stride( - {INT_MAX, - k_strides[strideidx0], - k_strides[strideidx1], - k_strides[strideidx2]})); - auto V_ = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(V) - .set_name("V") - .set_dim({b, h_v, s_kv, d_v}) - .set_stride( - {INT_MAX, - v_strides[strideidx0], - v_strides[strideidx1], - v_strides[strideidx2]})); - auto o_strides = o.strides(); - auto O_ = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(O) - .set_name("O") - .set_dim({b, h_q, s_q, d_v}) - .set_stride( - {INT_MAX, - o_strides[strideidx0], - o_strides[strideidx1], - o_strides[strideidx2]})); - - std::optional> bias; - if (attn_bias.has_value()) { - TORCH_CHECK( - false, - "attn_bias not yet supportd with cuDNN Attention and NestedTensor"); - bias = - mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(BIAS) - .set_name("bias") - .set_dim(attn_bias.value().sizes().vec()) - .set_stride(attn_bias.value().strides().vec())); - sdpa_backward_options.set_bias(bias.value()); - } - auto RAG_Q_OFF_ = - mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(RAG_Q_OFF) - .set_name("cum_seq_q") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - auto RAG_K_OFF_ = - mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(RAG_K_OFF) - .set_name("cum_seq_k") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - auto RAG_V_OFF_ = - mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(RAG_V_OFF) - .set_name("cum_seq_v") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - auto RAG_O_OFF_ = - mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(RAG_O_OFF) - .set_name("cum_seq_o") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - auto RAG_STATS_OFF_ = - mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(RAG_LSE_OFF) - .set_name("cum_seq_stats") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - O_->set_ragged_offset(RAG_O_OFF_); - Q_->set_ragged_offset(RAG_Q_OFF_); - K_->set_ragged_offset(RAG_K_OFF_); - V_->set_ragged_offset(RAG_V_OFF_); - auto STATS = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(LSE) - .set_name("stats") - .set_dim({b, h_q, s_q, 1}) - .set_stride({s_q * h_q, 1, h_q, 1}) - .set_data_type(fe::DataType_t::FLOAT)); - STATS->set_ragged_offset(RAG_STATS_OFF_); - auto do_strides = dO.strides(); - auto DO_ = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_ragged_offset(RAG_O_OFF_) - .set_uid(DO) - .set_name("DO") - .set_dim({b, h_q, s_q, d_v}) - .set_stride( - {INT_MAX, - do_strides[strideidx0], - do_strides[strideidx1], - do_strides[strideidx2]})); - auto [Dq, Dk, Dv] = mha_graph->sdpa_backward( - Q_, K_, V_, O_, DO_, STATS, sdpa_backward_options); - Dq->set_output(true) - .set_uid(DQ) - .set_ragged_offset(RAG_Q_OFF_) - .set_dim({b, h_q, s_q, d_qk}) - .set_stride( - {INT_MAX, - q_strides[strideidx0], - q_strides[strideidx1], - q_strides[strideidx2]}); - Dk->set_output(true) - .set_uid(DK) - .set_ragged_offset(RAG_K_OFF_) - .set_dim({b, h_k, s_kv, d_qk}) - .set_stride( - {INT_MAX, - k_strides[strideidx0], - k_strides[strideidx1], - k_strides[strideidx2]}); - Dv->set_output(true) - .set_uid(DV) - .set_ragged_offset(RAG_V_OFF_) - .set_dim({b, h_v, s_kv, d_v}) - .set_stride( - {INT_MAX, - v_strides[strideidx0], - v_strides[strideidx1], - v_strides[strideidx2]}); - + auto [DQ, DK, DV] = + mha_graph->sdpa_backward(Q, K, V, O, DO, STATS, sdpa_backward_options); + DQ->set_output(true).set_dim(dQ.sizes().vec()).set_stride(dQ.strides().vec()); + DK->set_output(true).set_dim(dK.sizes().vec()).set_stride(dK.strides().vec()); + DV->set_output(true).set_dim(dV.sizes().vec()).set_stride(dV.strides().vec()); AT_CUDNN_FRONTEND_CHECK(mha_graph->validate()); AT_CUDNN_FRONTEND_CHECK(mha_graph->build_operation_graph(handle)); AT_CUDNN_FRONTEND_CHECK( mha_graph->create_execution_plans({fe::HeurMode_t::A})); AT_CUDNN_FRONTEND_CHECK(mha_graph->check_support(handle)); AT_CUDNN_FRONTEND_CHECK(mha_graph->build_plans(handle)); - return mha_graph; + return std::make_tuple( + std::move(mha_graph), + std::move(Q), + std::move(K), + std::move(V), + std::move(bias), + std::move(attn_scale), + std::move(Seed), + std::move(Offset), + std::move(O), + std::move(DO), + std::move(STATS), + std::move(DQ), + std::move(DK), + std::move(DV)); } void run_cudnn_SDP_fprop( @@ -1102,12 +817,12 @@ void run_cudnn_SDP_fprop( dropout_probability, is_causal, return_softmaxstats); - auto graph_ptr = getMHAGraphCache_().find(key); - std::shared_ptr mha_graph; - if (graph_ptr) { - mha_graph = *graph_ptr; + auto graph_and_tensors_ptr = mhagraphcache.find(key); + graph_and_tensors graph_and_tensors_values; + if (graph_and_tensors_ptr) { + graph_and_tensors_values = *graph_and_tensors_ptr; } else { - mha_graph = build_graph( + graph_and_tensors_values = build_graph_and_tensors( b, h, s_q, @@ -1128,28 +843,29 @@ void run_cudnn_SDP_fprop( _dropoutoffset, handle); } - std::unordered_map variant_pack = { - {Q, q.data_ptr()}, - {K, k.data_ptr()}, - {V, v.data_ptr()}, - {SCALE, &scaling_factor}, - {O, o.data_ptr()}}; + auto [mha_graph, Q, K, V, bias, attn_scale, seed, offset, O, Stats] = + graph_and_tensors_values; + std::unordered_map, void*> + variant_pack = { + {Q, q.data_ptr()}, + {K, k.data_ptr()}, + {V, v.data_ptr()}, + {attn_scale, &scaling_factor}, + {seed, _dropoutseed.data_ptr()}, + {offset, _dropoutoffset.data_ptr()}, + {O, o.data_ptr()}}; if (return_softmaxstats) { - variant_pack[LSE] = softmaxstats.data_ptr(); + variant_pack[Stats] = softmaxstats.data_ptr(); } if (attn_bias.has_value()) { - variant_pack[BIAS] = attn_bias.value().data_ptr(); - } - if (dropout_probability != 0.0f) { - variant_pack[SEED] = _dropoutseed.data_ptr(); - variant_pack[OFFSET] = _dropoutoffset.data_ptr(); + variant_pack[bias.value()] = attn_bias.value().data_ptr(); } auto workspace_size = mha_graph->get_workspace_size(); auto workspace_ptr = c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size); TORCH_CHECK( mha_graph->execute(handle, variant_pack, workspace_ptr.get()).is_good()); - getMHAGraphCache_().update(key, mha_graph); + mhagraphcache.update(key, graph_and_tensors_values); } void run_cudnn_SDP_fprop_nestedtensor( @@ -1188,55 +904,72 @@ void run_cudnn_SDP_fprop_nestedtensor( if (return_softmaxstats && !softmaxstats.defined()) { softmaxstats = at::empty({q.size(0), h_q, 1}, q.options().dtype(kFloat)); } - auto mha_graph = build_graph_nestedtensor( - b, - h_q, - h_k, - h_v, - s_q, - s_kv, - d_qk, - d_v, - scaling_factor, - return_softmaxstats, - is_causal, - dropout_probability, - cum_seqlen_q, - cum_seqlen_kv, - q, - k, - v, - attn_bias, - softmaxstats, - o, - dropoutseed, - dropoutoffset, - handle); + auto + [mha_graph, + Q, + K, + V, + bias, + attn_scale, + seed, + offset, + O, + Stats, + RAG_Q_OFF, + RAG_K_OFF, + RAG_V_OFF, + RAG_O_OFF, + RAG_STATS_OFF, + SEQ_LEN_Q, + SEQ_LEN_KV] = + build_graph_and_tensors_nestedtensor( + b, + h_q, + h_k, + h_v, + s_q, + s_kv, + d_qk, + d_v, + scaling_factor, + return_softmaxstats, + is_causal, + dropout_probability, + cum_seqlen_q, + cum_seqlen_kv, + q, + k, + v, + attn_bias, + softmaxstats, + o, + dropoutseed, + dropoutoffset, + handle); auto seqlen_q = at::diff(cum_seqlen_q, 1, 0); auto seqlen_kv = at::diff(cum_seqlen_kv, 1, 0); auto rag_q_off = cum_seqlen_q.mul(h_q * d_qk); - auto rag_k_off = cum_seqlen_kv.mul(h_k * d_v); + auto rag_k_off = cum_seqlen_kv.mul(h_k * d_qk); auto rag_v_off = cum_seqlen_kv.mul(h_v * d_v); auto rag_stats_off = cum_seqlen_q.mul(h_q); - std::unordered_map variant_pack = { - {Q, q.data_ptr()}, - {K, k.data_ptr()}, - {V, v.data_ptr()}, - {SCALE, &scaling_factor}, - {O, o.data_ptr()}, - {RAG_Q_OFF, rag_q_off.data_ptr()}, - {RAG_O_OFF, rag_q_off.data_ptr()}, - {RAG_K_OFF, rag_k_off.data_ptr()}, - {RAG_V_OFF, rag_v_off.data_ptr()}, - {SEQ_LEN_Q, seqlen_q.data_ptr()}, - {SEQ_LEN_KV, seqlen_kv.data_ptr()}}; + std::unordered_map, void*> + variant_pack = { + {Q, q.data_ptr()}, + {K, k.data_ptr()}, + {V, v.data_ptr()}, + {attn_scale, &scaling_factor}, + {seed, dropoutseed.data_ptr()}, + {offset, dropoutoffset.data_ptr()}, + {O, o.data_ptr()}, + {RAG_Q_OFF, rag_q_off.data_ptr()}, + {RAG_O_OFF, rag_q_off.data_ptr()}, + {RAG_K_OFF, rag_k_off.data_ptr()}, + {RAG_V_OFF, rag_v_off.data_ptr()}, + {SEQ_LEN_Q, seqlen_q.data_ptr()}, + {SEQ_LEN_KV, seqlen_kv.data_ptr()}}; if (return_softmaxstats) { - variant_pack[LSE] = softmaxstats.data_ptr(); - variant_pack[RAG_LSE_OFF] = rag_stats_off.data_ptr(); - } - if (dropout_probability != 0.0f) { - variant_pack[SEED] = dropoutseed.data_ptr(); - variant_pack[OFFSET] = dropoutoffset.data_ptr(); + variant_pack[Stats] = softmaxstats.data_ptr(); + variant_pack[RAG_STATS_OFF] = cum_seqlen_q.data_ptr(); } if (attn_bias.has_value()) { TORCH_CHECK("bias not supported with nestedtensor"); @@ -1320,12 +1053,12 @@ void run_cudnn_SDP_bprop( dropout_probability, is_causal, true); - auto graph_backward_ptr = getMHAGraphBackwardCache_().find(key); - std::shared_ptr mha_graph; - if (graph_backward_ptr) { - mha_graph = *graph_backward_ptr; + auto graph_and_tensors_backward_ptr = mhagraphbackwardcache.find(key); + graph_and_tensors_backward graph_and_tensors_backward_values; + if (graph_and_tensors_backward_ptr) { + graph_and_tensors_backward_values = *graph_and_tensors_backward_ptr; } else { - mha_graph = build_graph_backward( + graph_and_tensors_backward_values = build_graph_and_tensors_backward( b, h, s_q, @@ -1349,25 +1082,41 @@ void run_cudnn_SDP_bprop( _dropoutoffset, handle); } - std::unordered_map variant_pack = { - // inputs - {Q, q.data_ptr()}, - {K, k.data_ptr()}, - {V, v.data_ptr()}, - {O, o.data_ptr()}, - {DO, dO_.data_ptr()}, - {LSE, softmaxstats.data_ptr()}, - // outputs - {DQ, dQ.data_ptr()}, - {DK, dK.data_ptr()}, - {DV, dV.data_ptr()}, - {SCALE, &scaling_factor}}; + auto + [mha_graph, + Q, + K, + V, + bias, + attn_scale, + Seed, + Offset, + O, + Do, + Stats, + Dq, + Dk, + Dv] = graph_and_tensors_backward_values; + std::unordered_map, void*> + variant_pack = {// inputs + {Q, q.data_ptr()}, + {K, k.data_ptr()}, + {V, v.data_ptr()}, + {O, o.data_ptr()}, + {Do, dO_.data_ptr()}, + {Stats, softmaxstats.data_ptr()}, + // outputs + {Dq, dQ.data_ptr()}, + {Dk, dK.data_ptr()}, + {Dv, dV.data_ptr()}, + // pass by value + {attn_scale, &scaling_factor}}; if (dropout_probability != 0.0f) { - variant_pack[SEED] = _dropoutseed.data_ptr(); - variant_pack[OFFSET] = _dropoutoffset.data_ptr(); + variant_pack[Seed] = _dropoutseed.data_ptr(); + variant_pack[Offset] = _dropoutoffset.data_ptr(); } if (attn_bias.has_value()) { - variant_pack[BIAS] = attn_bias.value().data_ptr(); + variant_pack[bias.value()] = attn_bias.value().data_ptr(); } auto workspace_size = mha_graph->get_workspace_size(); auto workspace_ptr = @@ -1375,127 +1124,7 @@ void run_cudnn_SDP_bprop( TORCH_CHECK(!workspace_size || workspace_ptr.get()); TORCH_CHECK( mha_graph->execute(handle, variant_pack, workspace_ptr.get()).is_good()); - getMHAGraphBackwardCache_().update(key, mha_graph); -} - -void run_cudnn_SDP_bprop_nestedtensor( - int64_t b, - int64_t h_q, - int64_t h_k, - int64_t h_v, - int64_t s_q, - int64_t s_kv, - int64_t d_qk, - int64_t d_v, - float scaling_factor, - bool is_causal, - float dropout_probability, - const Tensor& cum_seqlen_q, - const Tensor& cum_seqlen_kv, - const Tensor& q, - const Tensor& k, - const Tensor& v, - const std::optional& attn_bias, - const Tensor& o, - const Tensor& dO, - const Tensor& softmaxstats, - Tensor& dQ, - Tensor& dK, - Tensor& dV, - const Tensor& dropoutseed, - const Tensor& dropoutoffset) { - // do nothing if we got 0-element tensors - if (!q.numel() || !k.numel() || !v.numel() || !o.numel() || !dO.numel() || - !softmaxstats.numel()) { - return; - } - - Tensor dO_ = dO; - const auto innermost_dO_stride = dO.strides()[dO.strides().size() - 1]; - if (innermost_dO_stride != 1) { - permute_to_matching_layout(o, dO_); - } - - auto seqlen_q = at::diff(cum_seqlen_q, 1, 0); - auto seqlen_kv = at::diff(cum_seqlen_kv, 1, 0); - auto rag_q_off = cum_seqlen_q.mul(h_q * d_qk); - auto rag_k_off = cum_seqlen_kv.mul(h_k * d_v); - auto rag_v_off = cum_seqlen_kv.mul(h_v * d_v); - auto rag_stats_off = cum_seqlen_q.mul(h_q); - - auto dprops = at::cuda::getCurrentDeviceProperties(); - auto _dropoutseed = dropoutseed; - auto _dropoutoffset = dropoutoffset; - // cuDNN dropout bug requires these to be in int64 - if (dprops->major == 10 && dprops->minor == 0) { - _dropoutseed = dropoutseed.to(kLong); - _dropoutoffset = dropoutoffset.to(kLong); - } - - cudnnHandle_t handle = getCudnnHandle(); - - auto mha_graph = build_graph_backward_nestedtensor( - b, - h_q, - h_k, - h_v, - s_q, - s_kv, - d_qk, - d_v, - scaling_factor, - is_causal, - dropout_probability, - cum_seqlen_q, - cum_seqlen_kv, - q, - k, - v, - attn_bias, - o, - dO_, - softmaxstats, - dQ, - dK, - dV, - dropoutseed, - dropoutoffset, - handle); - - std::unordered_map variant_pack = { - // inputs - {Q, q.data_ptr()}, - {K, k.data_ptr()}, - {V, v.data_ptr()}, - {O, o.data_ptr()}, - {DO, dO_.data_ptr()}, - {LSE, softmaxstats.data_ptr()}, - // outputs - {DQ, dQ.data_ptr()}, - {DK, dK.data_ptr()}, - {DV, dV.data_ptr()}, - {SCALE, &scaling_factor}, - {RAG_Q_OFF, rag_q_off.data_ptr()}, - {RAG_O_OFF, rag_q_off.data_ptr()}, - {RAG_K_OFF, rag_k_off.data_ptr()}, - {RAG_V_OFF, rag_v_off.data_ptr()}, - {RAG_LSE_OFF, rag_stats_off.data_ptr()}, - {SEQ_LEN_Q, seqlen_q.data_ptr()}, - {SEQ_LEN_KV, seqlen_kv.data_ptr()}}; - if (dropout_probability != 0.0f) { - variant_pack[SEED] = _dropoutseed.data_ptr(); - variant_pack[OFFSET] = _dropoutoffset.data_ptr(); - } - TORCH_CHECK( - !attn_bias.has_value(), - "attn_bias not yet supportd with cuDNN Attention and NestedTensor"); - - auto workspace_size = mha_graph->get_workspace_size(); - auto workspace_ptr = - c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size); - TORCH_CHECK(!workspace_size || workspace_ptr.get()); - TORCH_CHECK( - mha_graph->execute(handle, variant_pack, workspace_ptr.get()).is_good()); + mhagraphbackwardcache.update(key, graph_and_tensors_backward_values); } } // namespace native diff --git a/aten/src/ATen/native/cudnn/MHA.h b/aten/src/ATen/native/cudnn/MHA.h index 620abc1aa0a8e..045e8cf6dee9d 100644 --- a/aten/src/ATen/native/cudnn/MHA.h +++ b/aten/src/ATen/native/cudnn/MHA.h @@ -70,31 +70,4 @@ void run_cudnn_SDP_bprop( const Tensor& dropoutseed, const Tensor& dropoutoffset); -void run_cudnn_SDP_bprop_nestedtensor( - int64_t b, - int64_t h_q, - int64_t h_k, - int64_t h_v, - int64_t s_q, - int64_t s_kv, - int64_t d_qk, - int64_t d_v, - float scaling_factor, - bool is_causal, - float dropout_probability, - const Tensor& cum_seqlen_q, - const Tensor& cum_seqlen_kv, - const Tensor& q, - const Tensor& k, - const Tensor& v, - const std::optional& attn_bias, - const Tensor& o, - const Tensor& dO, - const Tensor& softmaxstats, - Tensor& dQ, - Tensor& dK, - Tensor& dV, - const Tensor& dropoutseed, - const Tensor& dropoutoffset); - } // namespace at::native diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 09ea127555f98..79b7e07e2284b 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -14960,7 +14960,6 @@ - func: _scaled_dot_product_cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor attn_bias, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, *, float? scale=None) -> (Tensor, Tensor, Tensor) dispatch: CUDA: _scaled_dot_product_cudnn_attention_backward_cuda - NestedTensorCUDA: _scaled_dot_product_cudnn_attention_nestedtensor_backward_cuda tags: nondeterministic_seeded - func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None, Tensor? seqused_k=None, Tensor? alibi_slopes=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor rng_state, Tensor unused, Tensor debug_attn_mask) @@ -14993,11 +14992,6 @@ CUDA: _cudnn_attention_forward tags: nondeterministic_seeded -- func: _cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor attn_bias, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, *, float? scale=None) -> (Tensor, Tensor, Tensor) - dispatch: - CUDA: _cudnn_attention_backward - tags: nondeterministic_seeded - - func: _triton_scaled_dot_attention(Tensor q, Tensor k, Tensor v, float dropout_p=0.0) -> Tensor variants: function dispatch: diff --git a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp index 96c6ab8310f80..5b7476453407e 100644 --- a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp +++ b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp @@ -349,63 +349,6 @@ _scaled_dot_product_cudnn_attention_nestedtensor_cuda( return std::make_tuple(std::move(attention), std::move(log_sumexp), cumulative_sequence_length_q, cumulative_sequence_length_kv, max_seqlen_batch_q, max_seqlen_batch_kv, std::move(cudnn_seed), std::move(cudnn_offset), Tensor()); } -std::tuple _scaled_dot_product_cudnn_attention_nestedtensor_backward_cuda( - const Tensor& grad_out, - const Tensor& query, - const Tensor& key, - const Tensor& value, - const Tensor& out, - const Tensor& logsumexp, - const Tensor& philox_seed, - const Tensor& philox_offset, - const Tensor& attn_bias, - const Tensor& cum_seq_q, - const Tensor& cum_seq_k, - const int64_t max_q, - const int64_t max_k, - double dropout_p, - bool is_causal, - std::optional scale) { - if (!grad_out.defined()) { - return std::make_tuple(Tensor{}, Tensor{}, Tensor{}); - } - auto [ - grad_out_buffer_reshaped, - query_buffer_reshaped, - key_buffer_reshaped, - value_buffer_reshaped, - output_buffer_reshaped] = - preprocessing::sdpa_nested_preprocessing_backward( - grad_out, - query, - key, - value, - out, - cum_seq_q, - cum_seq_k, - max_q, - max_k); - - auto [dq, dk, dv] = at::_cudnn_attention_backward(grad_out_buffer_reshaped, - query_buffer_reshaped, - key_buffer_reshaped, - value_buffer_reshaped, - output_buffer_reshaped, - logsumexp, - philox_seed, - philox_offset, - attn_bias, - cum_seq_q, - cum_seq_k, - max_q, - max_k, - dropout_p, - is_causal, - scale); - return std::make_tuple(std::move(dq), std::move(dk), std::move(dv)); -} - - std::tuple _scaled_dot_product_flash_attention_backward_nested( const at::Tensor& grad_out_, const at::Tensor& query, diff --git a/aten/src/ATen/native/transformers/cuda/attention.cu b/aten/src/ATen/native/transformers/cuda/attention.cu index 8513419db0a94..80049aa9a832f 100644 --- a/aten/src/ATen/native/transformers/cuda/attention.cu +++ b/aten/src/ATen/native/transformers/cuda/attention.cu @@ -849,6 +849,16 @@ std::tuple #include -#include -#include #include #include #include @@ -186,7 +184,7 @@ std::tuple _flash_attention_backward( return std::make_tuple(Tensor(), Tensor(), Tensor()); } -std::tuple _cudnn_attention_backward( +std::tuple _scaled_dot_product_cudnn_attention_backward_cuda( const Tensor& grad_out, const Tensor& query, const Tensor& key, @@ -213,117 +211,57 @@ std::tuple _cudnn_attention_backward( } } - const bool is_nested = cum_seq_q.defined(); + const int64_t batch_size = query.size(0); + const int64_t num_heads = query.size(1); + const int64_t head_dim_qk = query.size(3); + const int64_t head_dim_v = value.size(3); const int64_t max_seqlen_batch_q = query.size(2); const int64_t max_seqlen_batch_k = key.size(2); - if (!is_nested) { - const int64_t batch_size = query.size(0); - const int64_t num_heads = query.size(1); - const int64_t head_dim_qk = query.size(3); - const int64_t head_dim_v = value.size(3); - - // This is needed because SaveVariable automatically converts - // std::optional to undefined tensor - std::optional attn_bias_; - if (attn_bias.defined()) { - attn_bias_ = attn_bias; - } - if (attn_bias_.has_value()) { - const auto bias_dim = attn_bias_.value().dim(); - if (bias_dim == 2) { - attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k}); - } else if (bias_dim == 3) { - attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k}); - } else { - TORCH_CHECK(bias_dim == 4, "cuDNN SDPA expects either a 2D, 3D, or 4D attn_bias but got ", attn_bias_.value().dim(), "D"); - attn_bias_ = attn_bias_.value().expand({batch_size, attn_bias_.value().size(1), max_seqlen_batch_q, max_seqlen_batch_k}); - } - } - - const auto softmax_scale = sdp::calculate_scale(query, scale).expect_float(); - auto dq = at::empty_like(query); - auto dk = at::empty_like(key); - auto dv = at::empty_like(value); - run_cudnn_SDP_bprop(batch_size /*int64_t b*/, - num_heads /*int64_t h*/, - max_q/*int64_t s_q*/, - max_k/*int64_t s_kv*/, - head_dim_qk /*int64_t d_qk*/, - head_dim_v /*int64_t d_v*/, - softmax_scale /*float scaling_factor*/, - is_causal /*bool is_causal*/, - dropout_p /*float dropout_probability*/, - query /*const Tensor& q*/, - key /*const Tensor& k*/, - value /*const Tensor& v*/, - attn_bias_ /*const std::optional& attn_bias*/, - out /*const Tensor& o*/, - grad_out/*const Tensor& dO*/, - logsumexp.unsqueeze(-1)/*const Tensor& softmaxstats*/, - dq/*Tensor& dQ*/, - dk/*Tensor& dK*/, - dv/*Tensor& dV*/, - philox_seed/*Tensor& dropoutseed*/, - philox_offset/*Tensor& dropoutoffset*/); - return std::make_tuple(std::move(dq), std::move(dk), std::move(dv)); - } else { - // BHSD ... - const int64_t batch_size = cum_seq_q.size(0) - 1; - const int64_t num_heads_q = query.size(-2); - const int64_t num_heads_k = key.size(-2); - const int64_t num_heads_v = value.size(-2); - const int64_t head_dim_qk = query.size(-1); - const int64_t head_dim_v = value.size(-1); - std::optional attn_bias_; - if (attn_bias.defined()) { - attn_bias_ = attn_bias; - } - if (attn_bias_.has_value()) { - const auto bias_dim = attn_bias_.value().dim(); - if (bias_dim == 2) { - attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k}); - } else if (bias_dim == 3) { - attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k}); - } else { - attn_bias_ = attn_bias_.value().expand({batch_size, attn_bias_.value().size(1), max_seqlen_batch_q, max_seqlen_batch_k}); - TORCH_CHECK(bias_dim == 4, "cuDNN SDPA expects either a 2D, 3D, or 4D attn_bias but got ", attn_bias_.value().dim(), "D"); - } + // This is needed because SaveVariable automatically converts + // std::optional to undefined tensor + std::optional attn_bias_; + if (attn_bias.defined()) { + attn_bias_ = attn_bias; + } + if (attn_bias_.has_value()) { + const auto bias_dim = attn_bias_.value().dim(); + if (bias_dim == 2) { + attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k}); + } else if (bias_dim == 3) { + attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k}); + } else { + TORCH_CHECK(bias_dim == 4, "cuDNN SDPA expects either a 2D, 3D, or 4D attn_bias but got ", attn_bias_.value().dim(), "D"); + attn_bias_ = attn_bias_.value().expand({batch_size, attn_bias_.value().size(1), max_seqlen_batch_q, max_seqlen_batch_k}); } - - auto dq = at::empty_like(query); - auto dk = at::empty_like(key); - auto dv = at::empty_like(value); - - const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked(); - run_cudnn_SDP_bprop_nestedtensor( - batch_size, - num_heads_q, - num_heads_k, - num_heads_v, - max_seqlen_batch_q, - max_seqlen_batch_k, - head_dim_qk, - head_dim_v, - softmax_scale, - is_causal, - dropout_p, - cum_seq_q, - cum_seq_k, - query, - key, - value, - attn_bias_, - out, - grad_out, - logsumexp, - dq, - dk, - dv, - philox_seed, - philox_offset); - return std::make_tuple(std::move(dq), std::move(dk), std::move(dv)); } + + const auto softmax_scale = sdp::calculate_scale(query, scale).expect_float(); + auto dq = at::empty_like(query); + auto dk = at::empty_like(key); + auto dv = at::empty_like(value); + run_cudnn_SDP_bprop(batch_size /*int64_t b*/, + num_heads /*int64_t h*/, + max_q/*int64_t s_q*/, + max_k/*int64_t s_kv*/, + head_dim_qk /*int64_t d_qk*/, + head_dim_v /*int64_t d_v*/, + softmax_scale /*float scaling_factor*/, + is_causal /*bool is_causal*/, + dropout_p /*float dropout_probability*/, + query /*const Tensor& q*/, + key /*const Tensor& k*/, + value /*const Tensor& v*/, + attn_bias_ /*const std::optional& attn_bias*/, + out /*const Tensor& o*/, + grad_out/*const Tensor& dO*/, + logsumexp.unsqueeze(-1)/*const Tensor& softmaxstats*/, + dq/*Tensor& dQ*/, + dk/*Tensor& dK*/, + dv/*Tensor& dV*/, + philox_seed/*Tensor& dropoutseed*/, + philox_offset/*Tensor& dropoutoffset*/); + return std::make_tuple(std::move(dq), std::move(dk), std::move(dv)); } std::tuple @@ -1125,40 +1063,4 @@ std::tuple _scaled_dot_product_e } } -std::tuple _scaled_dot_product_cudnn_attention_backward_cuda( - const Tensor& grad_out, - const Tensor& query, - const Tensor& key, - const Tensor& value, - const Tensor& out, - const Tensor& logsumexp, - const Tensor& philox_seed, - const Tensor& philox_offset, - const Tensor& attn_bias, - const Tensor& cum_seq_q, - const Tensor& cum_seq_k, - const int64_t max_q, - const int64_t max_k, - double dropout_p, - bool is_causal, - std::optional scale) { - return at::_cudnn_attention_backward( - grad_out, - query, - key, - value, - out, - logsumexp, - philox_seed, - philox_offset, - attn_bias, - cum_seq_q, - cum_seq_k, - max_q, - max_k, - dropout_p, - is_causal, - scale); -} - } // namespace at::native diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp index 4b85b2d28753a..4b198f4d6d2de 100644 --- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp +++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp @@ -57,28 +57,21 @@ namespace sdp { namespace { -// tracks whether we've set the default priority order once, to avoid setting -// it redundantly or overwriting a user-specified priority order -// when the priority order context manager is used before the default priority -// order is initialized the following happens: -// (1) the current priority order is queried -// (2) priority_order() is called, which initializes it to the default as init_ is false -// (3) the user-specified priority order is set -// (3.1) we are in the priority context... -// (3.2) we exit the priority context... -// (4) the previous priority order (default) is restored -bool priority_order_init_ = false; - // TODO(eqy): more benchmarking to determine whether this should include sm86/89 // Needs to be kept in-sync with test_fused_chocie in test_transformers.py bool check_prefer_cudnn_attention() { - static const bool prefer_cudnn = c10::utils::check_env("TORCH_CUDNN_SDPA_PREFERRED") == true; - if (!prefer_cudnn) { - return false; - } -#if (defined(CUDNN_VERSION) && (CUDNN_VERSION > 90000)) + // TODO(eqy): Re-enable by default after upgrading to a release later than 9.5.0 + // see context: https://github.com/pytorch/pytorch/issues/138340 + // return false; +#if defined(CUDNN_VERSION) + +#if CUDNN_VERSION > 90000 auto dprops = at::cuda::getCurrentDeviceProperties(); - return dprops->major >= 9 && !dprops->minor; + return dprops->major >= 9; +#else + return false; +#endif + #else return false; #endif @@ -86,16 +79,6 @@ bool check_prefer_cudnn_attention() { // flash_attention V2 is universally faster than efficient_attention and Math std::array priority_order(sdp_params const& params) { - if (!priority_order_init_) { - priority_order_init_ = true; - if (check_prefer_cudnn_attention()) { - const std::vector cudnn_order = {static_cast(at::SDPBackend::cudnn_attention), - static_cast(at::SDPBackend::flash_attention), - static_cast(at::SDPBackend::efficient_attention), - static_cast(at::SDPBackend::math)}; - at::globalContext().setSDPPriorityOrder(cudnn_order); - } - } return at::globalContext().sDPPriorityOrder(); } @@ -431,7 +414,12 @@ bool check_cudnn_tensor_shapes(sdp_params const& params, bool debug) { return false; } auto head_dim_limit = 128; - // TODO(eqy): add head dim >= 256 cases once support is finalized + if (cudnn_version >= 90501) { + auto dprops = at::cuda::getCurrentDeviceProperties(); + if (dprops->major == 9 && !dprops->minor) { + head_dim_limit = 256; + } + } if (d_qk > head_dim_limit || d_v > head_dim_limit) { if (debug) { TORCH_WARN("head_dim should be no more than ", head_dim_limit); @@ -465,15 +453,9 @@ bool check_cudnn_tensor_shapes(sdp_params const& params, bool debug) { return false; } } - if (s_k == 1) { - if (debug) { - TORCH_WARN_ONCE("cudnn SDPA does not support key/value sequence length 1."); - } - return false; - } - if (s_q == 1 && params.dropout != 0.0) { + if (s_q == 1 || s_k == 1) { if (debug) { - TORCH_WARN_ONCE("cudnn SDPA does not support query sequence length 1 with dropout."); + TORCH_WARN_ONCE("cudnn SDPA does not support sequence length 1."); } return false; } @@ -581,9 +563,9 @@ bool check_for_nested_inputs(sdp_params const& params, bool debug) { const auto dprop = at::cuda::getCurrentDeviceProperties(); // Check that the input is nested - if ((dprop->major == 9 || dprop->major == 10) && has_for_nested_inputs(params)) { + if (dprop->major != 9 && has_for_nested_inputs(params)) { if (debug) { - TORCH_WARN("cuDNN SDPA supports nested tensors on SM 9.0, SM 10.0."); + TORCH_WARN("CuDNN SDPA supports nested tensors on SM 9.0."); } return false; } @@ -607,7 +589,7 @@ bool check_runtime_disabled_cudnn(sdp_params const& params, bool debug) { // sdp kernels if (!at::globalContext().userEnabledCuDNNSDP()) { if (debug) { - TORCH_WARN("cuDNN attention has been runtime disabled."); + TORCH_WARN("CuDNN attention has been runtime disabled."); } return false; } @@ -638,7 +620,7 @@ bool can_use_cudnn_attention(const sdp_params& params, bool debug) { #endif #if defined(CUDNN_VERSION) && CUDNN_VERSION < 90000 if (debug) { - TORCH_WARN(CUDNN_VERSION, " cuDNN version too old to use cuDNN Attention (< v9.0.0)"); + TORCH_WARN(CUDNN_VERSION, " cuDNN version too old to use CuDNN Attention (< v9.0.0)"); } return false; #endif @@ -648,8 +630,10 @@ bool can_use_cudnn_attention(const sdp_params& params, bool debug) { c10::array_of( check_runtime_disabled_cudnn, check_for_nested_inputs, + check_nonzero_sequence_lengths_dense, check_all_tensors_on_device, check_tensor_shapes, + check_cudnn_tensor_shapes, check_cudnn_deterministic, check_dtypes_low_precision, check_attn_mask_shape, @@ -662,10 +646,8 @@ bool can_use_cudnn_attention(const sdp_params& params, bool debug) { } constexpr auto dense_constraints = c10::array_of( - check_nonzero_sequence_lengths_dense, check_last_dim_stride_equals_1_dense, - check_batch_size_and_num_heads_dense, - check_cudnn_tensor_shapes + check_batch_size_and_num_heads_dense ); if (has_only_dense_inputs(params)) { @@ -882,7 +864,7 @@ SDPBackend select_sdp_backend(sdp_params const& kernel_params) { sdp::can_use_mem_efficient_attention(kernel_params, print_debug); TORCH_WARN("Flash attention kernel not used because:"); sdp::can_use_flash_attention(kernel_params, print_debug); - TORCH_WARN("cuDNN attention kernel not used because:"); + TORCH_WARN("CuDNN attention kernel not used because:"); sdp::can_use_cudnn_attention(kernel_params, print_debug); TORCH_CHECK(!print_debug, "No available kernel. Aborting execution.") return SDPBackend::error; diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index ee9d466f60832..042959c22cd4a 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -75,7 +75,6 @@ aten::_ctc_loss.out aten::_ctc_loss_backward aten::_ctc_loss_backward.Tensor aten::_ctc_loss_backward.out -aten::_cudnn_attention_backward aten::_cudnn_attention_forward aten::_cudnn_ctc_loss aten::_cudnn_ctc_loss.Tensor diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py index 6007e3f3171f5..bb59b626bef14 100644 --- a/test/inductor/test_cuda_repro.py +++ b/test/inductor/test_cuda_repro.py @@ -26,7 +26,6 @@ run_fw_bw_and_get_code, ) from torch.fx.experimental.proxy_tensor import make_fx -from torch.nn.attention import sdpa_kernel, SDPBackend from torch.testing import FileCheck from torch.testing._internal.common_cuda import ( PLATFORM_SUPPORTS_FLASH_ATTENTION, @@ -178,10 +177,9 @@ def test_effn_attn_bias_padding_misaligned(self): inputs = [q, k, v, mask] def f(q, k, v, mask): - with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION): - return F.scaled_dot_product_attention( - q, k, v, attn_mask=mask, dropout_p=0.0 - ) + return F.scaled_dot_product_attention( + q, k, v, attn_mask=mask, dropout_p=0.0 + ) f_compiled = torch.compile(f) diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index 6dbb0b2cdad85..0e0234b089412 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -6751,10 +6751,11 @@ def check_forward_backward(skip_backward=False): and check_cudnn and (dtype == torch.float16 or dtype == torch.bfloat16) ): - with torch.nn.attention.sdpa_kernel( - torch.nn.attention.SDPBackend.CUDNN_ATTENTION - ): - check_forward_backward() + with self.assertRaisesRegex(RuntimeError, "cuDNN SDPA Nested Tensor"): + with torch.nn.attention.sdpa_kernel( + torch.nn.attention.SDPBackend.CUDNN_ATTENTION + ): + check_forward_backward() @skipIfTorchDynamo("SDPA test compiles internally") @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") diff --git a/test/test_transformers.py b/test/test_transformers.py index 85c9f4a07cec5..7c11cb2833d74 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -49,6 +49,7 @@ PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, PLATFORM_SUPPORTS_FUSED_ATTENTION, PLATFORM_SUPPORTS_CUDNN_ATTENTION, + SM90OrLater, tf32_on_and_off, tf32_enabled, ) @@ -2628,7 +2629,6 @@ def test_cudnn_attention_gqa(self, device): @skipIfRocm # No cuDNN Attention @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system") - @unittest.expectedFailure # cuDNN currently doesn't support this on SM100+/fails graph validation def test_cudnn_attention_d256_heuristic(self, device): dtype = torch.bfloat16 make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True) @@ -2639,7 +2639,7 @@ def test_cudnn_attention_d256_heuristic(self, device): v_shape = SdpaShape(batch, num_heads, seq_len, head_dim_v) query, key, value = make_tensor(q_shape), make_tensor(k_shape), make_tensor(v_shape) - with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION], set_priority=True): + with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH], set_priority=True): actual = torch.nn.functional.scaled_dot_product_attention( query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) actual.backward(torch.randn_like(actual)) @@ -2677,7 +2677,7 @@ def test_fused_attention_different_dk_dv(self, device): @skipIfRocm # No cuDNN Attention - @unittest.skipIf(True, "broken as of cuDNN 9.10") + @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system") def test_cudnn_attention_fail_d128(self, device): # Test that cuDNN attention dispatching correctly bails out on d > 128 b, h = 1, 2 @@ -2692,6 +2692,7 @@ def test_cudnn_attention_fail_d128(self, device): ISSM90 = device_cap == (9, 0) ISSM100 = device_cap == (10, 0) with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION]): + # SM90/100 support d <= 256 as of cuDNN 9.5.1+ if (ISSM90 or ISSM100) and torch.backends.cudnn.version() >= 90501: torch.nn.functional.scaled_dot_product_attention(q, k, v) else: @@ -3127,19 +3128,15 @@ def test_fused_sdp_choice(self, device, type: str): value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) - device_capability = None - if "cuda" in str(device): - device_capability = torch.cuda.get_device_capability() - prefer_cudnn = "TORCH_CUDNN_SDPA_PREFERRED" in os.environ - prefer_cudnn = prefer_cudnn and device_capability and (device_capability == (9, 0) or device_capability == (10, 0)) - # TODO we are currently disabling this by default, lets assert that this returns # FlashAttention, we need to change when we make remove opt-in for cudnn - if type != "nested" and PLATFORM_SUPPORTS_CUDNN_ATTENTION and prefer_cudnn: - self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.CUDNN_ATTENTION.value) + if type != "nested" and PLATFORM_SUPPORTS_CUDNN_ATTENTION and SM90OrLater: + self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.FLASH_ATTENTION.value) + with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION]): + self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.CUDNN_ATTENTION.value) elif PLATFORM_SUPPORTS_FLASH_ATTENTION: self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.FLASH_ATTENTION.value) - elif type != "nested" and PLATFORM_SUPPORTS_CUDNN_ATTENTION and not prefer_cudnn: # e.g., we're on Windows + elif type != "nested" and PLATFORM_SUPPORTS_CUDNN_ATTENTION: # e.g., we're on Windows self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.EFFICIENT_ATTENTION.value) with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION]): self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.CUDNN_ATTENTION.value) diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 7af47591bd08b..e2419aab268b1 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -2896,10 +2896,6 @@ output_differentiability: [True, False, False, False, False, False] query, key, value, bias: _efficient_attention_backward_symint(grad, query, key, value, bias, output, cu_seqlens_q, cu_seqlens_k, max_seqlen_batch_q, max_seqlen_batch_k, logsumexp, dropout_p, philox_seed, philox_offset, custom_mask_type, bias.requires_grad(), scale) -- name: _cudnn_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) - output_differentiability: [True, False, False, False, False, False, False, False, False] - query, key, value: _cudnn_attention_backward_symint(grad, query, key, value, output, logsumexp, philox_seed, philox_offset, attn_bias, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, scale) - - name: _scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) output_differentiability: [True, False, False, False, False, False, False, False, False] query, key, value: _scaled_dot_product_cudnn_attention_backward_symint(grad, query, key, value, output, logsumexp, philox_seed, philox_offset, attn_bias, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, scale) From 6d31d38965ef0bc81f3a5a49882d200c69218ccf Mon Sep 17 00:00:00 2001 From: Songhao Jia Date: Thu, 17 Jul 2025 17:00:14 +0000 Subject: [PATCH 194/457] recovering node source from dict (#158373) (#158473) Summary: this diff recovers NodeSource object from its dict representation, which is crucial for NodeSource serde. Test Plan: ci Rollback Plan: Differential Revision: D78434648 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158473 Approved by: https://github.com/angelayi --- test/fx/test_fx_traceback.py | 28 +++++++++++++++---- torch/fx/traceback.py | 54 ++++++++++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+), 6 deletions(-) diff --git a/test/fx/test_fx_traceback.py b/test/fx/test_fx_traceback.py index f02bc5a2e1592..05369d17078ba 100644 --- a/test/fx/test_fx_traceback.py +++ b/test/fx/test_fx_traceback.py @@ -32,6 +32,8 @@ def test_node_source(self): dummy_source_dict, ) + self.assertEqual(node_source, NodeSource._from_dict(node_source.to_dict())) + # Dummy node node = torch.fx.Node( graph=torch.fx.Graph(), @@ -179,14 +181,28 @@ def forward(self, x): if node_name_1 in same_ancestor_nodes else None, }: - self.assertTrue( - node_name_to_from_node[node_name_1] - == node_name_to_from_node[node_name_2] + self.assertEqual( + node_name_to_from_node[node_name_1], + node_name_to_from_node[node_name_2], + ) + self.assertEqual( + [ + NodeSource._from_dict(ns.to_dict()) + for ns in node_name_to_from_node[node_name_1] + ], + node_name_to_from_node[node_name_2], ) else: - self.assertTrue( - node_name_to_from_node[node_name_1] - != node_name_to_from_node[node_name_2] + self.assertNotEqual( + node_name_to_from_node[node_name_1], + node_name_to_from_node[node_name_2], + ) + self.assertNotEqual( + [ + NodeSource._from_dict(ns.to_dict()) + for ns in node_name_to_from_node[node_name_1] + ], + node_name_to_from_node[node_name_2], ) gm = ep.module() diff --git a/torch/fx/traceback.py b/torch/fx/traceback.py index 59187fedccfaa..648a80b87b681 100644 --- a/torch/fx/traceback.py +++ b/torch/fx/traceback.py @@ -153,6 +153,60 @@ def _make_hashable(obj): return hash(_make_hashable(self.to_dict())) + @classmethod + def _from_dict(cls, d: Optional[dict]) -> Optional["NodeSource"]: + """ + Recursively deserialize from_node metadata from dictionary data. + It is used to deserialize the from_node field from serialized metadata. + Please use constructor NodeSource(node, ...) to create a NodeSource object. + """ + if d is None: + return None + + assert isinstance(d, dict), f"Expected a dict, got {type(d)}" + + # Create a NodeSource object directly without going through the constructor + # to avoid issues with graph ID and node creation + node_source = NodeSource.__new__(NodeSource) + + # Reset the cached properties + node_source._action_string = None + node_source._dict = None + + # Set the basic attributes + node_source.pass_name = d.get("pass_name", "") + + # Parse action string back to NodeSourceAction enum list + action_str = d.get("action", "") + actions = [] + if action_str: + for action_name in action_str.split("+"): + if action_name.upper() == "CREATE": + actions.append(NodeSourceAction.CREATE) + elif action_name.upper() == "REPLACE": + actions.append(NodeSourceAction.REPLACE) + node_source.action = actions + + # Create the NodeInfo object directly + if "name" in d and "target" in d and "graph_id" in d: + node_info = NodeSource.NodeInfo( + d.get("name", ""), d.get("target", ""), d.get("graph_id", -1) + ) + node_source.node_info = node_info + else: + node_source.node_info = None + + # Recursively deserialize nested from_node + if d.get("from_node", None) is not None: + node_source.from_node = [ + result + for fn in d.get("from_node", []) + if (result := cls._from_dict(fn)) is not None + ] + else: + node_source.from_node = [] + return node_source + @compatibility(is_backward_compatible=False) @contextmanager From bff69f25c2e98adc2e4a765d9fa47f230e2fef45 Mon Sep 17 00:00:00 2001 From: Sam Larsen Date: Wed, 16 Jul 2025 10:24:22 -0700 Subject: [PATCH 195/457] [BE][testing] fix test/dynamo/test_repros:test_longtensor_list (#158458) Summary: This test is failing internally because the number of underlying calls to the rng differ by virtue of various library initializations that get sucked in with an internal build. Test Plan: `buck test '@fbcode//mode/opt' fbcode//caffe2/test/dynamo:test_dynamo -- --exact 'caffe2/test/dynamo:test_dynamo - test_repros.py::ReproTests::test_longtensor_list' --run-disabled` Pull Request resolved: https://github.com/pytorch/pytorch/pull/158458 Approved by: https://github.com/jansel --- test/dynamo/test_repros.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 8636b3496d1b5..bdb297789dc6b 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -2036,8 +2036,13 @@ def fn(x): ref0 = fn(x) ref1 = fn(x) - random.seed(0) opt_fn = torch.compile(fn, backend="eager") + # Especially for internal usage, there are many calls to random functions + # on first compile, e.g., from various library initializations. Run once + # to get that out of the way before resetting the seed: + opt_fn(x) + + random.seed(0) res0 = opt_fn(x) res1 = opt_fn(x) From f92a2035e41699b026abc25a1a8dde6971bfe477 Mon Sep 17 00:00:00 2001 From: Eli Uriegas Date: Thu, 17 Jul 2025 10:39:17 -0700 Subject: [PATCH 196/457] ci: Update lint workflow to only run on changed files for PRs (#158518) This modifies the lint workflow to use the new get-changed-files workflow to optimize lint execution by only running on files that have actually changed in pull requests. This more closely mirrors the type of behavior that users expect when running lint locally on their PRs. This also leaves the default behavior as a fallback for when you're not running on a pull request. Since lint runs on the pull_request event I'm not really worried about any type of ciflow shenanigans in this. This also splits mypy into its own job since mypy needs to run on all-files all the time. Signed-off-by: Eli Uriegas Pull Request resolved: https://github.com/pytorch/pytorch/pull/158518 Approved by: https://github.com/huydhn ghstack dependencies: #158517 --- .github/workflows/lint.yml | 45 +++++++++++++++++++++++++++++++++----- 1 file changed, 40 insertions(+), 5 deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 66cd5f653446b..1a21a68a865da 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -26,9 +26,15 @@ jobs: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} curr_branch: ${{ github.head_ref || github.ref_name }} + + get-changed-files: + if: github.repository_owner == 'pytorch' + name: Get changed files + uses: ./.github/workflows/_get-changed-files.yml + lintrunner-clang: uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main - needs: get-label-type + needs: [get-label-type, get-changed-files] with: timeout: 120 runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" @@ -39,13 +45,37 @@ jobs: submodules: true ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} script: | - export ADDITIONAL_LINTRUNNER_ARGS="--take CLANGTIDY,CLANGFORMAT --all-files" + CHANGED_FILES="${{ needs.get-changed-files.outputs.changed-files }}" + if [ "$CHANGED_FILES" = "*" ]; then + export ADDITIONAL_LINTRUNNER_ARGS="--take CLANGTIDY,CLANGFORMAT --all-files" + else + export ADDITIONAL_LINTRUNNER_ARGS="--take CLANGTIDY,CLANGFORMAT $CHANGED_FILES" + fi export CLANG=1 .github/scripts/lintrunner.sh + # NOTE: mypy needs its own job because it depends on --all-files, without assessing all files it sometimes + # fails to find types when it should + lintrunner-mypy: + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + needs: [get-label-type, get-changed-files] + with: + timeout: 120 + runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" + docker-image: ci-image:pytorch-linux-jammy-linter + # NB: A shallow checkout won't work here because calculate-docker-image requires a full checkout + # to run git rev-parse HEAD~:.ci/docker when a new image is needed + fetch-depth: 0 + submodules: true + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + script: | + CHANGED_FILES="${{ needs.get-changed-files.outputs.changed-files }}" + echo "Running mypy" + ADDITIONAL_LINTRUNNER_ARGS="--take MYPY --all-files" .github/scripts/lintrunner.sh + lintrunner-noclang: uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main - needs: get-label-type + needs: [get-label-type, get-changed-files] with: timeout: 120 runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" @@ -56,8 +86,13 @@ jobs: submodules: true ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} script: | - export ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT --all-files" - .github/scripts/lintrunner.sh + CHANGED_FILES="${{ needs.get-changed-files.outputs.changed-files }}" + echo "Running all other linters" + if [ "$CHANGED_FILES" = '*' ]; then + ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY --all-files" .github/scripts/lintrunner.sh + else + ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY ${CHANGED_FILES}" .github/scripts/lintrunner.sh + fi quick-checks: uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main From ad223a6c5fec7e143f3c0fd56a492f0a79f61711 Mon Sep 17 00:00:00 2001 From: Kevin Fu Date: Thu, 17 Jul 2025 18:09:56 +0000 Subject: [PATCH 197/457] Add FP8 Types (#158430) Summary: Add FP8 Types Test Plan: sandcastle Rollback Plan: Differential Revision: D78395110 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158430 Approved by: https://github.com/henryoier --- torch/_export/serde/export_schema.thrift | 4 +++- torch/_export/serde/schema.py | 4 +++- torch/_export/serde/schema.yaml | 6 ++++-- torch/_export/serde/serialize.py | 2 ++ torch/csrc/utils/generated_serialization_types.h | 8 +++++++- torch/nativert/graph/TensorMeta.cpp | 4 ++++ 6 files changed, 23 insertions(+), 5 deletions(-) diff --git a/torch/_export/serde/export_schema.thrift b/torch/_export/serde/export_schema.thrift index c4f5809af0b47..50472c02375cc 100644 --- a/torch/_export/serde/export_schema.thrift +++ b/torch/_export/serde/export_schema.thrift @@ -1,5 +1,5 @@ // @generated by update_schema.py -// checksum<> +// checksum<<31664e4faa0eacd6f538ffed163078e190d9d2b98d762dd45b68eb1b7b12f0d1>> namespace py3 torch._export namespace cpp2 torch._export.schema @@ -50,6 +50,8 @@ enum ScalarType { UINT16 = 28, FLOAT8E4M3FN = 29, FLOAT8E5M2 = 30, + FLOAT8E4M3FNUZ = 31, + FLOAT8E5M2FNUZ = 32, } diff --git a/torch/_export/serde/schema.py b/torch/_export/serde/schema.py index 32c69140807b7..933d30310b72c 100644 --- a/torch/_export/serde/schema.py +++ b/torch/_export/serde/schema.py @@ -9,7 +9,7 @@ # NOTE: Please update this value if any modifications are made to the schema -SCHEMA_VERSION = (8, 8) +SCHEMA_VERSION = (8, 9) TREESPEC_VERSION = 1 @@ -33,6 +33,8 @@ class ScalarType(IntEnum): UINT16 = 28 FLOAT8E4M3FN = 29 FLOAT8E5M2 = 30 + FLOAT8E4M3FNUZ = 31 + FLOAT8E5M2FNUZ = 32 class Layout(IntEnum): diff --git a/torch/_export/serde/schema.yaml b/torch/_export/serde/schema.yaml index 19145e7f8e326..9167a6820ef40 100644 --- a/torch/_export/serde/schema.yaml +++ b/torch/_export/serde/schema.yaml @@ -1,5 +1,5 @@ # @generated by update_schema.py -# checksum<<110c364974d3b0f7dcbdf6862781212bdcc7178925c43c894c336fc2b6ca6628>> +# checksum<<5c990535d373dcaa291a4f994b4d7b025e0f8e806ca5268085ef699d0e4d3000>> AOTInductorModelPickleData: kind: struct fields: @@ -420,6 +420,8 @@ ScalarType: UINT16: 28 FLOAT8E4M3FN: 29 FLOAT8E5M2: 30 + FLOAT8E4M3FNUZ: 31 + FLOAT8E5M2FNUZ: 32 SchemaVersion: kind: struct fields: @@ -532,5 +534,5 @@ UserOutputSpec: type: Argument SCHEMA_VERSION: - 8 -- 8 +- 9 TREESPEC_VERSION: 1 diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index 5c688b2a14d24..710311d31f6e3 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -143,6 +143,8 @@ def _reverse_map(d: dict[Any, Enum]): torch.bfloat16: ScalarType.BFLOAT16, torch.float8_e4m3fn: ScalarType.FLOAT8E4M3FN, torch.float8_e5m2: ScalarType.FLOAT8E5M2, + torch.float8_e4m3fnuz: ScalarType.FLOAT8E4M3FNUZ, + torch.float8_e5m2fnuz: ScalarType.FLOAT8E5M2FNUZ, } diff --git a/torch/csrc/utils/generated_serialization_types.h b/torch/csrc/utils/generated_serialization_types.h index 67fd1ecf05c02..98803390e5104 100644 --- a/torch/csrc/utils/generated_serialization_types.h +++ b/torch/csrc/utils/generated_serialization_types.h @@ -1,5 +1,5 @@ // @generated by update_schema.py -// checksum<<110c364974d3b0f7dcbdf6862781212bdcc7178925c43c894c336fc2b6ca6628>> +// checksum<<5c990535d373dcaa291a4f994b4d7b025e0f8e806ca5268085ef699d0e4d3000>> // clang-format off #pragma once @@ -283,6 +283,8 @@ enum class ScalarType { UINT16 = 28, FLOAT8E4M3FN = 29, FLOAT8E5M2 = 30, + FLOAT8E4M3FNUZ = 31, + FLOAT8E5M2FNUZ = 32, }; inline std::string_view printEnum(const ScalarType& e) { @@ -304,6 +306,8 @@ inline std::string_view printEnum(const ScalarType& e) { case ScalarType::UINT16: return "UINT16"; case ScalarType::FLOAT8E4M3FN: return "FLOAT8E4M3FN"; case ScalarType::FLOAT8E5M2: return "FLOAT8E5M2"; + case ScalarType::FLOAT8E4M3FNUZ: return "FLOAT8E4M3FNUZ"; + case ScalarType::FLOAT8E5M2FNUZ: return "FLOAT8E5M2FNUZ"; default: throw std::runtime_error("Unknown enum value"); } @@ -327,6 +331,8 @@ inline void parseEnum(std::string_view s, ScalarType& t) { if (s == "UINT16") { t = ScalarType::UINT16; return; } if (s == "FLOAT8E4M3FN") { t = ScalarType::FLOAT8E4M3FN; return; } if (s == "FLOAT8E5M2") { t = ScalarType::FLOAT8E5M2; return; } + if (s == "FLOAT8E4M3FNUZ") { t = ScalarType::FLOAT8E4M3FNUZ; return; } + if (s == "FLOAT8E5M2FNUZ") { t = ScalarType::FLOAT8E5M2FNUZ; return; } throw std::runtime_error("Unknown enum value: " + std::string{s}); } diff --git a/torch/nativert/graph/TensorMeta.cpp b/torch/nativert/graph/TensorMeta.cpp index 81625dca116f9..97afbc9f095e6 100644 --- a/torch/nativert/graph/TensorMeta.cpp +++ b/torch/nativert/graph/TensorMeta.cpp @@ -41,6 +41,10 @@ c10::ScalarType convertJsonScalarType( return c10::ScalarType::Float8_e4m3fn; case torch::_export::ScalarType::FLOAT8E5M2: return c10::ScalarType::Float8_e5m2; + case torch::_export::ScalarType::FLOAT8E4M3FNUZ: + return c10::ScalarType::Float8_e4m3fnuz; + case torch::_export::ScalarType::FLOAT8E5M2FNUZ: + return c10::ScalarType::Float8_e5m2fnuz; default: TORCH_CHECK(false, "unknown scalar type", static_cast(scalarType)); } From 25f4d7e48271eb4d2f1dbdb4a6380b2c00339b5e Mon Sep 17 00:00:00 2001 From: albanD Date: Thu, 17 Jul 2025 18:46:39 +0000 Subject: [PATCH 198/457] Use new type statement to fix public API of types (#158487) Since type statement breaks older python version, trying to find equivalent behavior without the type mechanics. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158487 Approved by: https://github.com/andrewor14 --- torch/ao/quantization/__init__.py | 10 +++++++-- torch/ao/quantization/qconfig.py | 8 +++++-- torch/ao/quantization/utils.py | 35 +++++++++++++++++++++++++------ 3 files changed, 43 insertions(+), 10 deletions(-) diff --git a/torch/ao/quantization/__init__.py b/torch/ao/quantization/__init__.py index cf5a8b99a8941..f50b9d6cd137e 100644 --- a/torch/ao/quantization/__init__.py +++ b/torch/ao/quantization/__init__.py @@ -33,9 +33,15 @@ # ensure __module__ is set correctly for public APIs -ObserverOrFakeQuantize = Union[ObserverBase, FakeQuantizeBase] -if sys.version_info < (3, 14): +if sys.version_info < (3, 12): + ObserverOrFakeQuantize = Union[ObserverBase, FakeQuantizeBase] ObserverOrFakeQuantize.__module__ = "torch.ao.quantization" +else: + from typing import TypeAliasType + + ObserverOrFakeQuantize = TypeAliasType( + "ObserverOrFakeQuantize", Union[ObserverBase, FakeQuantizeBase] + ) for _f in [ compare_results, diff --git a/torch/ao/quantization/qconfig.py b/torch/ao/quantization/qconfig.py index d9a8fc78bab4a..94dfdb8c7626a 100644 --- a/torch/ao/quantization/qconfig.py +++ b/torch/ao/quantization/qconfig.py @@ -568,9 +568,13 @@ def _assert_valid_qconfig(qconfig: Optional[QConfig], mod: torch.nn.Module) -> N ) -QConfigAny = Optional[QConfig] -if sys.version_info < (3, 14): +if sys.version_info < (3, 12): + QConfigAny = Optional[QConfig] QConfigAny.__module__ = "torch.ao.quantization.qconfig" +else: + from typing import TypeAliasType + + QConfigAny = TypeAliasType("QConfigAny", Optional[QConfig]) def _add_module_to_qconfig_obs_ctr( diff --git a/torch/ao/quantization/utils.py b/torch/ao/quantization/utils.py index a80ae1d8e3de1..e93cd3fdb7cbd 100644 --- a/torch/ao/quantization/utils.py +++ b/torch/ao/quantization/utils.py @@ -16,9 +16,16 @@ from torch.nn.utils.parametrize import is_parametrized -NodePattern = Union[tuple[Node, Node], tuple[Node, tuple[Node, Node]], Any] -if sys.version_info < (3, 14): +if sys.version_info < (3, 12): + NodePattern = Union[tuple[Node, Node], tuple[Node, tuple[Node, Node]], Any] NodePattern.__module__ = "torch.ao.quantization.utils" +else: + from typing import TypeAliasType + + NodePattern = TypeAliasType( + "NodePattern", Union[tuple[Node, Node], tuple[Node, tuple[Node, Node]], Any] + ) + # This is the Quantizer class instance from torch/quantization/fx/quantize.py. # Define separately to prevent circular imports. @@ -30,11 +37,27 @@ # Type for fusion patterns, it can be more complicated than the following actually, # see pattern.md for docs # TODO: not sure if typing supports recursive data types -Pattern = Union[ - Callable, tuple[Callable, Callable], tuple[Callable, tuple[Callable, Callable]], Any -] -if sys.version_info < (3, 14): + +if sys.version_info < (3, 12): + Pattern = Union[ + Callable, + tuple[Callable, Callable], + tuple[Callable, tuple[Callable, Callable]], + Any, + ] Pattern.__module__ = "torch.ao.quantization.utils" +else: + from typing import TypeAliasType + + Pattern = TypeAliasType( + "Pattern", + Union[ + Callable, + tuple[Callable, Callable], + tuple[Callable, tuple[Callable, Callable]], + Any, + ], + ) # TODO: maybe rename this to MatchInputNode From 7e34f9c292940e16e06f0b85fce99c14af708569 Mon Sep 17 00:00:00 2001 From: Divyansh Khanna Date: Thu, 17 Jul 2025 19:01:49 +0000 Subject: [PATCH 199/457] Add torch._C._log_api_usage_once to datapipes (mapper) (#155489) This is to get a better understanding of how datapipes is used right now. Pull Request resolved: https://github.com/pytorch/pytorch/pull/155489 Approved by: https://github.com/ramanishsingh --- torch/utils/data/datapipes/iter/callable.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch/utils/data/datapipes/iter/callable.py b/torch/utils/data/datapipes/iter/callable.py index c96ce82cf139a..718e728c9389d 100644 --- a/torch/utils/data/datapipes/iter/callable.py +++ b/torch/utils/data/datapipes/iter/callable.py @@ -4,6 +4,7 @@ from collections.abc import Iterator, Sized from typing import Any, Callable, Optional, TypeVar, Union +import torch from torch.utils.data._utils.collate import default_collate from torch.utils.data.datapipes._decorator import functional_datapipe from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper @@ -74,6 +75,7 @@ def __init__( input_col=None, output_col=None, ) -> None: + torch._C._log_api_usage_once("python.data_pipes.map") super().__init__() self.datapipe = datapipe From 8dcebaa7b088f9ae8c08975310e63c81a154153f Mon Sep 17 00:00:00 2001 From: Xu Han Date: Thu, 17 Jul 2025 19:22:56 +0000 Subject: [PATCH 200/457] [AOTI] add WIN32 implement for create_temp_dir (#158570) add Windows implement for `create_temp_dir`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158570 Approved by: https://github.com/angelayi --- .../inductor/aoti_package/model_package_loader.cpp | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/torch/csrc/inductor/aoti_package/model_package_loader.cpp b/torch/csrc/inductor/aoti_package/model_package_loader.cpp index cca993406f7f3..8c674764d9dc4 100644 --- a/torch/csrc/inductor/aoti_package/model_package_loader.cpp +++ b/torch/csrc/inductor/aoti_package/model_package_loader.cpp @@ -47,7 +47,16 @@ bool file_exists(const std::string& path) { std::string create_temp_dir() { #ifdef _WIN32 - throw std::runtime_error("Not implemented"); + try { + fs::path temp_dir = fs::temp_directory_path(); + return temp_dir.string(); + } catch (const fs::filesystem_error& e) { + throw std::runtime_error( + "Failed to get temporary directory: " + std::string(e.what())); + } catch (...) { + throw std::runtime_error( + "Unknown error occurred while getting temporary directory"); + } #else std::string temp_dir = "/tmp/XXXXXX"; if (mkdtemp(temp_dir.data()) == nullptr) { From 7ebbf2cae7e55d5f64a15a1e8912e55ff0a6c9a4 Mon Sep 17 00:00:00 2001 From: Jack Taylor Date: Thu, 17 Jul 2025 19:47:41 +0000 Subject: [PATCH 201/457] Revert "[PT2][fusion] ban fusions with large accumulated reads (#157563) (#158550) This reverts commit 8554c8007ddaa8029e7e01bb1af12f358bf597c2 #157563 due to causing a few breakages on ROCm Reverted expected_results.csv to 26807dcf277feb2d99ab88d7b6da526488baea93 > @xuanzhang816 Sorry, but I have to revert this PR yet again because it clearly reintroduced failures on ROCm after the remerge: https://hud.pytorch.org/hud/pytorch/pytorch/f4d8bc46c7706f872abcb4ec41f0b32207d5d826/2?per_page=50&name_filter=rocm-mi300&mergeEphemeralLF=true and the failures are still showing up on tip-of-tree on HUD Context https://github.com/pytorch/pytorch/pull/157563#issuecomment-3083350857 Needs to be relanded in non bc-breaking way, or sanity checked for correctness. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158550 Approved by: https://github.com/jithunnair-amd, https://github.com/jeffdaily --- .../pr_time_benchmarks/expected_results.csv | 46 ++++++++-------- test/inductor/test_inplace_padding.py | 2 - test/inductor/test_memory.py | 53 ------------------- test/inductor/test_online_softmax.py | 8 +-- torch/_inductor/choices.py | 4 -- torch/_inductor/config.py | 1 - torch/_inductor/graph.py | 21 -------- torch/_inductor/ir.py | 11 ---- torch/_inductor/memory.py | 13 ++++- torch/_inductor/scheduler.py | 29 +++++----- 10 files changed, 53 insertions(+), 135 deletions(-) diff --git a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv index 24f0b2af088c2..edc9d0f73d161 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv +++ b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv @@ -1,89 +1,89 @@ -add_loop_eager,compile_time_instruction_count,3051000000,0.015 +add_loop_eager,compile_time_instruction_count,3017000000,0.015 -add_loop_eager_dynamic,compile_time_instruction_count,4405000000,0.025 +add_loop_eager_dynamic,compile_time_instruction_count,4352000000,0.025 -add_loop_inductor,compile_time_instruction_count,33810000000,0.015 +add_loop_inductor,compile_time_instruction_count,29490000000,0.015 -add_loop_inductor_dynamic_gpu,compile_time_instruction_count,43470000000,0.025 +add_loop_inductor_dynamic_gpu,compile_time_instruction_count,38760000000,0.025 -add_loop_inductor_gpu,compile_time_instruction_count,30390000000,0.015 +add_loop_inductor_gpu,compile_time_instruction_count,26000000000,0.015 -basic_modules_ListOfLinears_eager,compile_time_instruction_count,965100000,0.015 +basic_modules_ListOfLinears_eager,compile_time_instruction_count,947600000,0.015 -basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18300000000,0.015 +basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18490000000,0.015 -basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17630000000,0.015 +basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17020000000,0.015 -basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,10980000000,0.2 +basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,10297683772,0.2 -update_hint_regression,compile_time_instruction_count,1717000000,0.02 +update_hint_regression,compile_time_instruction_count,1673000000,0.02 -sum_floordiv_regression,compile_time_instruction_count,965000000,0.015 +sum_floordiv_regression,compile_time_instruction_count,986800000,0.015 -symint_sum,compile_time_instruction_count,3239000000,0.015 +symint_sum,compile_time_instruction_count,3166000000,0.015 -symint_sum_loop,compile_time_instruction_count,4305000000,0.015 +symint_sum_loop,compile_time_instruction_count,4202000000,0.015 -aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2146000000,0.015 +aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2103000000,0.015 -aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,6119000000,0.015 +aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,6004000000,0.015 -aotdispatcher_partitioner_cpu,compile_time_instruction_count,8976000000,0.015 +aotdispatcher_partitioner_cpu,compile_time_instruction_count,8783000000,0.015 -aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1988000000,0.015 +aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1940000000,0.015 -aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3951000000,0.015 +aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3885000000,0.015 -aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10640000000,0.015 +aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10470000000,0.015 -mm_loop_inductor_gpu,compile_time_instruction_count,4468000000,0.015 +mm_loop_inductor_gpu,compile_time_instruction_count,4324000000,0.015 -mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,8400000000,0.015 +mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,8116000000,0.015 -basic_NestedModule_eager,compile_time_instruction_count,8357000000,0.015 +basic_NestedModule_eager,compile_time_instruction_count,8152524390,0.015 -basic_InlineMod_eager,compile_time_instruction_count,7443000000,0.015 +basic_InlineMod_eager,compile_time_instruction_count,7255000000,0.015 diff --git a/test/inductor/test_inplace_padding.py b/test/inductor/test_inplace_padding.py index 0207134dc7013..80cb86ec417d4 100644 --- a/test/inductor/test_inplace_padding.py +++ b/test/inductor/test_inplace_padding.py @@ -9,7 +9,6 @@ from torch._inductor.test_case import run_tests, TestCase from torch._inductor.utils import run_and_get_code from torch.testing import FileCheck -from torch.testing._internal.common_utils import serialTest from torch.testing._internal.inductor_utils import ( GPU_TYPE, HAS_GPU, @@ -210,7 +209,6 @@ def f(x, y): self.assertEqual(num_inplace_padding(), 0) - @serialTest() @requires_cuda_with_enough_memory(2e10) @inductor_config.patch(force_shape_pad=True) def test_linear_and_cel(self): diff --git a/test/inductor/test_memory.py b/test/inductor/test_memory.py index 3e23442b38ec7..eaff539f7a493 100644 --- a/test/inductor/test_memory.py +++ b/test/inductor/test_memory.py @@ -8,7 +8,6 @@ from torch._inductor import config, memory from torch._inductor.test_case import TestCase from torch._inductor.utils import run_and_get_triton_code -from torch.testing._internal.common_utils import serialTest from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU @@ -307,58 +306,6 @@ def f(a, b, c): expected_bound = a.size(0) * c.size(1) * a.dtype.itemsize * 2 self.assertLess(peak_mem, expected_bound) - @serialTest() - def test_fusion_acc_large_reads(self): - def f(x, y, z): - res = torch.zeros_like(x[0]) - for i in range(4): - temp = torch.matmul(x, y) + z - res = res + temp - return res - - N = 128 - x = torch.rand(N, N, dtype=torch.float32, device=GPU_TYPE) - y = torch.rand(N, N, dtype=torch.float32, device=GPU_TYPE) - z = torch.rand(N, N, dtype=torch.float32, device=GPU_TYPE) - - # CASE 1: no restriction on the amount of accumulation - with config.patch({"realize_acc_reads_size_threshold": float("inf")}): - f_compiled = torch.compile(f) - code = run_and_get_triton_code(f_compiled, x, y, z) - ( - FileCheck() - .check("triton_poi_fused_add_0.run(buf4, arg2_1, buf1, buf2, buf3") - .run(code) - ) - - # CASE 2: for tensors with the same size as x (which is 4 * N**2 bytes) - # at most 12 / 4 = 3 reads can be accumulated during fusion - with config.patch({"realize_acc_reads_size_threshold": 12 * N**2}): - f_compiled = torch.compile(f) - code = run_and_get_triton_code(f_compiled, x, y, z) - ( - FileCheck() - .check("triton_poi_fused_add_0.run(buf3, arg2_1, buf1, buf2,") - .check("triton_poi_fused_add_1.run(buf5, buf4, arg2_1,") - .run(code) - ) - - # CASE 3: no such fusion allowed - with config.patch({"realize_acc_reads_size_threshold": N**2}): - f_compiled = torch.compile(f) - code = run_and_get_triton_code(f_compiled, x, y, z) - ( - FileCheck() - .check("triton_poi_fused_add_0.run(buf1, arg2_1,") - .check("triton_poi_fused_add_0.run(buf3, arg2_1,") - .check("triton_poi_fused_add_0.run(buf4, buf3,") - .check("triton_poi_fused_add_0.run(buf6, arg2_1,") - .check("triton_poi_fused_add_0.run(buf7, buf6,") - .check("triton_poi_fused_add_0.run(buf9, arg2_1,") - .check("triton_poi_fused_add_0.run(buf10, buf9,") - .run(code) - ) - if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/test/inductor/test_online_softmax.py b/test/inductor/test_online_softmax.py index 37959c241113f..798d86b0dd617 100644 --- a/test/inductor/test_online_softmax.py +++ b/test/inductor/test_online_softmax.py @@ -13,7 +13,6 @@ instantiate_parametrized_tests, IS_LINUX, parametrize, - serialTest, ) from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA @@ -78,17 +77,12 @@ def f(x): out, source_codes = run_and_get_code(f, x) return source_codes[0] - @serialTest() def test_codegen_3pass_softmax_due_to_disable(self): - with inductor_config.patch( - online_softmax=False, - realize_acc_reads_size_threshold=float("inf"), - ): + with inductor_config.patch(online_softmax=False): wrapper_code = self.get_softmax_wrapper() self.assertEqual(wrapper_code.count("for r0_offset in"), 3) - @serialTest() @parametrize("V", [2048, 50304]) @parametrize("use_log_softmax", [False, True]) def test_codegen_online_softmax(self, use_log_softmax, V): diff --git a/torch/_inductor/choices.py b/torch/_inductor/choices.py index 9096ba6dd0393..b7bab02da5e4b 100644 --- a/torch/_inductor/choices.py +++ b/torch/_inductor/choices.py @@ -365,10 +365,6 @@ def can_fuse( WhyNoFuse(node1, node2)("Fusion will increase peak memory") return False - if scheduler.fusion_accumulate_large_reads(node1, node2): - WhyNoFuse(node1, node2)("Fusion accumulate large amount of reads") - return False - return True @staticmethod diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index f4c54e2812674..5eb2b57a225a4 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -570,7 +570,6 @@ def use_autoheuristic(name: str) -> bool: # Threshold to prevent excessive accumulation of ops in one buffer during lowering realize_acc_reads_threshold = 8 -realize_acc_reads_size_threshold = 3 * (1024**3) # fallback to eager for random/dropout, this is slow but useful for debugging fallback_random = False diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index ac299d5b0c2d0..e2cc101533f28 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -123,7 +123,6 @@ from torch.fx.graph import Graph from .codegen.wrapper import PythonWrapperCodegen - from .dependencies import Dep from .scheduler import BaseSchedulerNode CompiledModule = Union[ModuleType, FileBackedGraphModule] @@ -486,9 +485,6 @@ def __init__( self.bw_donated_idxs = get_donated_idxs() - # Cache for dep size hints to avoid expensive recomputation - self.dep_size_hint_cache: dict[Dep, int] = {} - def freeze_runtime_asserts(self) -> None: self._shape_env.freeze_runtime_asserts() @@ -574,23 +570,6 @@ def has_feature( assert isinstance(feature, BackendFeature), feature return feature in self.get_backend_features(get_device_type(device)) - def get_dep_size_hint(self, dep: Dep) -> int: - """ - Get the size hint for a dependency with caching to avoid expensive recomputation. - """ - if dep not in self.dep_size_hint_cache: - res = 0 - try: - if not dep.has_unbacked_symbols(): - res = dep.numbytes_hint() - except KeyError: - # In at least one test (test/inductor/test_torchbind.py) we - # create a StarDep that doesn't exist in the graph and calling - # `has_unbacked_symbols()` throws an error. - pass - self.dep_size_hint_cache[dep] = res - return self.dep_size_hint_cache[dep] - def get_current_device_or_throw(self) -> torch.device: if device := self.current_device: return device diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index d6dd82aa52f2d..1edbb214ae2ad 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -7829,10 +7829,6 @@ def create(data: IRNode) -> Union[TensorBox, ShapeAsConstantBuffer]: class StorageBox(MutableBox): - """ - StorageBox allow in-place mutation of Tensors - """ - def is_input_buffer(self) -> bool: if isinstance(self.data, (InputBuffer, ReinterpretView)): return self.data.get_name() in V.graph.graph_inputs @@ -7882,17 +7878,10 @@ def realize_hint(self) -> None: ): self.realize() - def has_accumulated_enough_reads_by_size(self) -> bool: - return ( - sum(V.graph.get_dep_size_hint(dep) for dep in self.get_reads()) - > config.realize_acc_reads_size_threshold - ) - def has_exceeded_max_reads(self) -> bool: return isinstance(self.data, Pointwise) and ( self.num_reads() > config.realize_acc_reads_threshold or self.has_large_inner_fn() - or self.has_accumulated_enough_reads_by_size() ) def should_realize_on_reuse(self, users: int) -> bool: diff --git a/torch/_inductor/memory.py b/torch/_inductor/memory.py index d287208419a9f..5601bc4adcee4 100644 --- a/torch/_inductor/memory.py +++ b/torch/_inductor/memory.py @@ -78,8 +78,19 @@ def get_freeable_input_buf( A dictionary containing all freeble input buffers, keyed by their names. """ + # this function is copied from torch/_inductor/scheduler.py + # TODO: would be nice to remove the try/except block for both places def _dep_size_hint(dep: Dep) -> int: - return V.graph.get_dep_size_hint(dep) + res = 0 + try: + if not dep.has_unbacked_symbols(): + res = dep.numbytes_hint() + except KeyError: + # In at least one test (test/inductor/test_torchbind.py) we + # create a StarDep that doesn't exist in the graph and calling + # `has_unbacked_symbols()` throws an error. + pass + return res # get freeable input buffers' successor nodes and their sizes # note that different deps can have the same name, so we use name as keys diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 34f15869085f0..5c7a16d25bc64 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -2051,12 +2051,15 @@ class Scheduler: optimizations such as fusion, reorder, and graph partition. """ + __dep_size_hint_cache: dict[Dep, int] + def __init__(self, nodes: list[ir.Operation]) -> None: with dynamo_timed("Scheduler.__init__"): self._init(nodes) def _init(self, nodes: list[ir.Operation]) -> None: super().__init__() + self.__dep_size_hint_cache = {} V.graph.scheduler = self self.backends: dict[torch.device, BaseScheduling] = {} self.post_grad_graph_id = next(_post_grad_graph_counter) @@ -3502,17 +3505,6 @@ def _find_single_user_inputs( return True return False - def fusion_accumulate_large_reads( - self, node1: BaseSchedulerNode, node2: BaseSchedulerNode - ) -> bool: - all_reads = (node1.read_writes.reads | node2.read_writes.reads) - ( - node1.read_writes.writes | node2.read_writes.writes - ) - return ( - sum(self.dep_size_hint(dep) for dep in all_reads) - > config.realize_acc_reads_size_threshold - ) - def are_long_distant_nodes( self, node1: BaseSchedulerNode, node2: BaseSchedulerNode ) -> bool: @@ -4018,7 +4010,20 @@ def fusable_read_and_write(self, read: Dep, write: MemoryDep) -> bool: return False def dep_size_hint(self, dep: Dep) -> int: - return V.graph.get_dep_size_hint(dep) + res = 0 + if dep not in self.__dep_size_hint_cache: + try: + if not dep.has_unbacked_symbols(): + res = dep.numbytes_hint() + except KeyError: + # In at least one test (test/inductor/test_torchbind.py) we + # create a StarDep that doesn't exist in the graph and calling + # `has_unbacked_symbols()` throws an error. + pass + self.__dep_size_hint_cache[dep] = res + else: + res = self.__dep_size_hint_cache[dep] + return res def score_fusion_memory( self, node1: BaseSchedulerNode, node2: BaseSchedulerNode From fd51bcdd21683c715a9b4ef9340c90753964c76f Mon Sep 17 00:00:00 2001 From: Wouter Devriendt Date: Thu, 17 Jul 2025 19:48:22 +0000 Subject: [PATCH 202/457] check if USE_ROCM is defined (#158571) Summary: check if USE_ROCM is defined D78424375 broke some builds: see T231304402 Test Plan: rerunning failed builds Rollback Plan: Reviewed By: Camyll Differential Revision: D78493019 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158571 Approved by: https://github.com/huydhn, https://github.com/malfet --- c10/cuda/CUDAFunctions.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/c10/cuda/CUDAFunctions.h b/c10/cuda/CUDAFunctions.h index 9379e626d2cfd..543c866027464 100644 --- a/c10/cuda/CUDAFunctions.h +++ b/c10/cuda/CUDAFunctions.h @@ -90,7 +90,7 @@ C10_CUDA_API void __inline__ memcpy_and_sync( (*interp)->trace_gpu_stream_synchronization( c10::kCUDA, reinterpret_cast(stream)); } -#if USE_ROCM +#if defined(USE_ROCM) && USE_ROCM // As of ROCm 6.4.1, HIP runtime does not raise an error during capture of // hipMemcpyWithStream which is a synchronous call. Thus, we add a check // here explicitly. From ef256ad17b7b4fd9b79c4b580b4023f2c50eef11 Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Thu, 17 Jul 2025 12:42:48 -0400 Subject: [PATCH 203/457] Make Inductor imports TYPE_CHECKING only (#158524) Signed-off-by: Edward Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/158524 Approved by: https://github.com/cyyever, https://github.com/albanD --- torch/_functorch/_aot_autograd/schemas.py | 37 ++++++++++++++++------- 1 file changed, 26 insertions(+), 11 deletions(-) diff --git a/torch/_functorch/_aot_autograd/schemas.py b/torch/_functorch/_aot_autograd/schemas.py index 78f8e506e07e1..f8b60d4f7060c 100644 --- a/torch/_functorch/_aot_autograd/schemas.py +++ b/torch/_functorch/_aot_autograd/schemas.py @@ -4,23 +4,28 @@ input/output types, metadata, config, function signatures etc. """ +from __future__ import annotations + import collections -import contextlib import dataclasses import functools import itertools -from collections.abc import Iterable, Sequence from dataclasses import dataclass, field from enum import Enum -from typing import Any, Callable, NewType, Optional, Protocol, TypeVar, Union +from typing import ( + Any, + Callable, + NewType, + Optional, + Protocol, + TYPE_CHECKING, + TypeVar, + Union, +) import torch import torch.utils._pytree as pytree from torch import Tensor -from torch._guards import Source -from torch._inductor.output_code import OutputCode -from torch._inductor.utils import InputType -from torch._ops import OpOverload from torch._subclasses import FakeTensor from torch._subclasses.fake_tensor import is_fake from torch.utils._python_dispatch import is_traceable_wrapper_subclass @@ -33,6 +38,16 @@ from .utils import strict_zip +if TYPE_CHECKING: + import contextlib + from collections.abc import Iterable, Sequence + + from torch._guards import Source + from torch._inductor.output_code import OutputCode + from torch._inductor.utils import InputType + from torch._ops import OpOverload + + zip = strict_zip @@ -170,7 +185,7 @@ class MemoryFormatMeta: memory_format: Optional[torch.memory_format] = None @staticmethod - def from_tensor(t: torch.Tensor) -> Optional["MemoryFormatMeta"]: + def from_tensor(t: torch.Tensor) -> Optional[MemoryFormatMeta]: # We only memorize expected memory format for # 1. Traceable wrapper subclasses # We can not create restrided subclass tensor, as torch.empty_strided works only with dense tensors. @@ -236,7 +251,7 @@ class SubclassCreationMeta: # meta and attrs are produced by the subclass's __tensor_flatten__. # We need to keep them around along with outer_size / outer_stride to plumb them # into __tensor_unflatten__ - attrs: dict[str, Union["SubclassCreationMeta", PlainTensorMeta]] + attrs: dict[str, Union[SubclassCreationMeta, PlainTensorMeta]] outer_size: Iterable[Union[None, int, torch.SymInt]] outer_stride: Iterable[Union[None, int, torch.SymInt]] meta: Any @@ -832,7 +847,7 @@ def from_tracing_metadata( num_user_outputs: int, loss_index: Optional[int], backward_signature: Optional[BackwardSignature], - ) -> "GraphSignature": + ) -> GraphSignature: graph_inputs = graph_input_names graph_outputs = graph_output_names parameters = list(named_parameters) @@ -1103,7 +1118,7 @@ class AOTGraphCapture: # Produced by aot_stage1_graph_capture FakifiedFlatArgs = NewType("FakifiedFlatArgs", list[Any]) -TOutputCode = TypeVar("TOutputCode", bound=OutputCode) +TOutputCode = TypeVar("TOutputCode", bound="OutputCode") class AOTDispatchCompiler(Protocol): From eeb0783fe6357fd59a91b65c5dba0a00b21506b7 Mon Sep 17 00:00:00 2001 From: IvanKobzarev Date: Thu, 17 Jul 2025 05:48:35 -0700 Subject: [PATCH 204/457] [simple_fsdp][inductor_collectives] rewrite reorder_collectives, sink_waits_iterative (#158062) Differential Revision: [D78159013](https://our.internmc.facebook.com/intern/diff/D78159013) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158062 Approved by: https://github.com/wconstab --- test/distributed/test_inductor_collectives.py | 76 ++- torch/_inductor/comms.py | 474 +++++++++++------- torch/_inductor/dependencies.py | 8 +- 3 files changed, 353 insertions(+), 205 deletions(-) diff --git a/test/distributed/test_inductor_collectives.py b/test/distributed/test_inductor_collectives.py index fad2f8195600c..1f09d72ea2b1a 100644 --- a/test/distributed/test_inductor_collectives.py +++ b/test/distributed/test_inductor_collectives.py @@ -19,6 +19,7 @@ from torch._inductor.comms import ( _reorder_communication_preserving_peak_memory_internal, ReorderInfo, + sink_waits_iterative, ) from torch._inductor.compile_fx import compile_fx as inductor_compile_fx from torch._inductor.scheduler import BaseSchedulerNode @@ -1621,7 +1622,7 @@ def test_reorder_peak_memory_bucketed(self): comm from moving due to data dependency. """ - def func(x, w, ag_0, ag_1, *, tag, ranks, group_size): + def func(x, w, ag_0, ag_1, ag_2, ag_3, *, tag, ranks, group_size): # do some unrelated matmuls y = torch.mm(x, w) @@ -1654,14 +1655,52 @@ def func(x, w, ag_0, ag_1, *, tag, ranks, group_size): # wait op rs_0_out = torch.ops.c10d_functional.wait_tensor(rs_0_out) rs_1_out = torch.ops.c10d_functional.wait_tensor(rs_1_out) + y += torch.mm(2 * x, 2 * w) + + # cast the inputs + ag_2_cast = ag_2.to(torch.bfloat16) + ag_3_cast = ag_3.to(torch.bfloat16) + ag_2_out = torch.ops._c10d_functional.all_gather_into_tensor( + ag_2_cast, group_size, group_name + ) + ag_3_out = torch.ops._c10d_functional.all_gather_into_tensor( + ag_3_cast, group_size, group_name + ) + + # wait op + ag_2_out = torch.ops.c10d_functional.wait_tensor(ag_2_out) + ag_3_out = torch.ops.c10d_functional.wait_tensor(ag_3_out) + + # + rs_2_out = torch.ops._c10d_functional.reduce_scatter_tensor( + ag_2_cast, "sum", group_size, group_name + ) + rs_3_out = torch.ops._c10d_functional.reduce_scatter_tensor( + ag_3_cast, "sum", group_size, group_name + ) - return y, ag_0_out, ag_1_out, rs_0_out, rs_1_out + # wait op + rs_2_out = torch.ops.c10d_functional.wait_tensor(rs_2_out) + rs_3_out = torch.ops.c10d_functional.wait_tensor(rs_3_out) + return ( + y, + ag_0_out, + ag_1_out, + ag_2_out, + ag_3_out, + rs_0_out, + rs_1_out, + rs_2_out, + rs_3_out, + ) x = torch.ones(4, 384, device="cuda", dtype=torch.float32) w = torch.ones(384, 512, device="cuda", dtype=torch.float32) - ag_0 = torch.ones(384, 512, device="cuda", dtype=torch.float32) - ag_1 = torch.ones(512, device="cuda", dtype=torch.float32) - inputs = [x, w, ag_0, ag_1] + ag_0 = torch.ones(1024, 512, device="cuda", dtype=torch.float32) + ag_1 = torch.ones(512, 1024, device="cuda", dtype=torch.float32) + ag_2 = torch.ones(1024, 512, device="cuda", dtype=torch.float32) + ag_3 = torch.ones(512, 1024, device="cuda", dtype=torch.float32) + inputs = [x, w, ag_0, ag_1, ag_2, ag_3] # get stats directly from the internal helper without affecting the real pass's signature node_stats: Optional[dict[BaseSchedulerNode, ReorderInfo]] = None @@ -1679,11 +1718,15 @@ def _reorder_communication_preserving_peak_memory( with torch._inductor.config.patch( { "bucket_all_gathers_fx": "all", + "bucket_all_gathers_fx_bucket_size_determinator": lambda _: 2, "bucket_reduce_scatters_fx": "all", + "bucket_reduce_scatters_fx_bucket_size_determinator": lambda _: 2, "reorder_for_compute_comm_overlap": True, "reorder_for_compute_comm_overlap_passes": [ + sink_waits_iterative, _reorder_communication_preserving_peak_memory, ], + "allow_buffer_reuse": False, } ): compiled = torch.compile(func) @@ -1694,30 +1737,29 @@ def _reorder_communication_preserving_peak_memory( FileCheck() .check_count( "torch.ops._c10d_functional.all_gather_into_tensor_out.default(", - count=1, + count=2, exactly=True, ) + .check( + "extern_kernels.mm", + ) + .check( + "extern_kernels.addmm", + ) .run(code) ) ( FileCheck() .check_count( "torch.ops._c10d_functional.reduce_scatter_tensor.default(", - count=1, + count=2, exactly=True, ) - .run(code) - ) - ( - FileCheck() - .check( - "torch.ops._c10d_functional.all_gather_into_tensor_out.default(", - ) .check( - "torch.ops._c10d_functional.reduce_scatter_tensor.default(", + "extern_kernels.mm", ) .check( - "extern_kernels.mm", + "extern_kernels.addmm", ) .run(code) ) @@ -1726,7 +1768,7 @@ def _reorder_communication_preserving_peak_memory( assert same(out, correct), f"{out} va {correct}" assert node_stats is not None self.assertTrue(isinstance(node_stats, dict)) - self.assertEqual(len(node_stats), 2) + self.assertEqual(len(node_stats), 4) it = iter(node_stats.values()) node_stat0 = next(it) self.assertTrue(node_stat0.moves > 0) diff --git a/torch/_inductor/comms.py b/torch/_inductor/comms.py index caaf43dba5904..f93485333d303 100644 --- a/torch/_inductor/comms.py +++ b/torch/_inductor/comms.py @@ -4,7 +4,6 @@ import heapq import importlib -import itertools import logging import operator import sys @@ -149,9 +148,8 @@ def is_gemm_like(node: Optional[Union[IRNode, Operation]]) -> bool: return True if ( - hasattr(node, "python_kernel_name") - and node.python_kernel_name == "extern_kernels.mm" - ): + python_kernel_name := getattr(node, "python_kernel_name", None) + ) and "extern_kernels" in python_kernel_name: return True return False @@ -189,15 +187,23 @@ def _group_name(snode, with_bufs=False) -> str: def _reorder_communication_preserving_peak_memory_internal( snodes: list[BaseSchedulerNode], ) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, ReorderInfo]]: - from torch._inductor.scheduler import GroupedSchedulerNode, init_group_node - - original_snodes_num = len(snodes) """ Internal testing helper that also returns debug info. Returns: - reordered snodes list - dict {snode: ReorderInfo} """ + has_collectives = False + for snode in snodes: + if contains_collective(snode): + has_collectives = True + break + if not has_collectives: + return snodes, {} + + from torch._inductor.scheduler import GroupedSchedulerNode + + original_snodes_num = len(snodes) # heuristic to avoid degenerating to quadratic time graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys()) graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names()) @@ -208,7 +214,8 @@ def _reorder_communication_preserving_peak_memory_internal( snodes, name_to_freeable_input_buf, graph_outputs ) runtimes = {snode: estimate_op_runtime(snode) for snode in snodes} - snode_to_curr_memory = dict(zip(snodes, curr_memory)) + _curr_memory = dict(zip(snodes, curr_memory)) + _curr_memory[None] = 0 # type: ignore[index] # debug stats stats: dict[BaseSchedulerNode, ReorderInfo] = {} @@ -232,153 +239,151 @@ def accumulate_time(_snode): _temp_group_visit_leaves(snode, accumulate_time) return max(0, comm_time - compute_time) - MOVE_LIMIT = len(snodes) * 100 total_moves = 0 - # TODO - experiment with whether this limit is useful, setting `len(snodes)` disables it - PER_COLLECTIVE_PREFETCH_LIMIT = len(snodes) - if config.reorder_prefetch_limit is not None: - PER_COLLECTIVE_PREFETCH_LIMIT = config.reorder_prefetch_limit # Dicts to keep track of "next" and "previous" as double-linked structure during grouping - _prev: dict[BaseSchedulerNode, Optional[BaseSchedulerNode]] = {} - _next: dict[BaseSchedulerNode, Optional[BaseSchedulerNode]] = {} + _prev: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {} + _next: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {} for i, snode in enumerate(snodes): _prev[snode] = snodes[i - 1] if i > 0 else None _next[snode] = snodes[i + 1] if i < len(snodes) - 1 else None - - gsnodes: list[GroupedSchedulerNode] = [ - GroupedSchedulerNode(snode.scheduler, [snode], temp_grouping=True) - for snode in snodes - ] - for i, gsnode in enumerate(gsnodes): - snode = gsnode.snodes[0] # type: ignore[attr-defined] - if contains_collective(snode): - reorder_info = stats[snode] = ReorderInfo() + _curr_memory = dict(zip(snodes, curr_memory)) + _curr_memory[None] = 0 # type: ignore[index] + + _head = snodes[0] + + def _group_nodes(head, tail): + ret = [] + n = head + while True: + if n is not None: + ret.append(n) + if n == tail: + break + n = _next[n] + return ret + + def _group_names(head, tail): + ret = "" + for n in _group_nodes(head, tail): + if ret: + ret += "~" + ret += n.get_name() + return ret + + curr = _head + while _next[curr] is not None: + if contains_collective(curr): + reorder_info = stats[curr] = ReorderInfo() reorder_info.initial_exposed = reorder_info.final_exposed = ( - exposed_communication_time(snode, snodes[i + 1 :]) + exposed_communication_time(curr, _group_nodes(_next[curr], None)) ) - if total_moves >= MOVE_LIMIT: - reorder_info.limiting_factor = "move limit" - continue - for j in range(i - 1, -1, -1): - prev_gsnode = gsnodes[j] - if len(prev_gsnode.snodes) == 0: - continue - - if j < max(0, i - PER_COLLECTIVE_PREFETCH_LIMIT): - reorder_info.limiting_factor = "prefetch limit" - break - if contains_collective(prev_gsnode): + candidate = _prev[curr] + group_head = curr + group_tail = curr + group_peak_memory = _curr_memory[curr] + while candidate is not None: + if contains_collective(candidate): reorder_info.limiting_factor = "collective ordering" break - dep_names = OrderedSet([s.name for s in snode.unmet_dependencies]) - prev_outs = prev_gsnode.get_outputs() + group = GroupedSchedulerNode( + curr.scheduler, + _group_nodes(group_head, group_tail), + temp_grouping=True, + ) + + data_deps = {s.name: s for s in group.unmet_dependencies} + candidate_outs = candidate.get_outputs() data_dep = None - for o in prev_outs: - if o.get_name() in dep_names: - data_dep = o.get_name() + for o in candidate_outs: + if d := data_deps.get(o.get_name(), None): + if isinstance(d, WeakDep) and d.is_fake: + continue + data_dep = d break if data_dep is not None: - def is_groupable(prev_gsnode): + def is_groupable(candidate): # preserve ordering - if contains_collective(prev_gsnode): - return False - - if contains_gemm_like(prev_gsnode): - return False - return True - - if is_groupable(prev_gsnode): - new_snodes = prev_gsnode.snodes + gsnode.snodes - init_group_node(gsnode, gsnode.scheduler, new_snodes) - prev_gsnode.snodes = [] + if contains_collective(candidate): + return False, "contains_collective" + + if contains_gemm_like(candidate): + return False, "contains_gemm_like" + return True, None + + is_grp, grp_reason = is_groupable(candidate) + if is_grp: + group_head = candidate + group_peak_memory = max( + group_peak_memory, _curr_memory[candidate] + ) reorder_info.grouped += 1 - reorder_info.grouped_info = gsnode.get_name() + reorder_info.grouped_info = _group_names(group_head, group_tail) + candidate = _prev[candidate] continue else: msg = ( - f"data dependency {data_dep}(dep_names:{dep_names})" - f" prev_gsnode.outputs:{[o.get_name() for o in prev_outs]}" + f"data dependency {data_dep}(dep_names:{list(data_deps.keys())})" + f"\n candidate:{candidate.get_name()}(os:{[candidate.get_buffer_names()]})" + f"dep on {_group_names(group_head, group_tail)}" + f"\n non_group_reason:{grp_reason}" ) reorder_info.limiting_factor = msg break - if peak_memory - curr_memory[j] < curr_memory[j - 1] - curr_memory[j]: + delta_memory_candidate = ( + _curr_memory[candidate] - _curr_memory[_prev[candidate]] # type: ignore[index] + ) + + if group_peak_memory - delta_memory_candidate > peak_memory: reorder_info.limiting_factor = "peak memory" break - if reorder_info.final_exposed > runtimes[snode]: - reorder_info.limiting_factor = "sufficient overlapping" - break + reorder_info.moves += 1 total_moves += 1 - # swapping nodes j and j+1 affects curr memory at j only - # j_plus_one_alloc = curr_memory[j + 1] - curr_memory[j] - # j_alloc = curr_memory[j] - curr_memory[j - 1] - # curr_memory[j] = curr_memory[j] - j_alloc + j_plus_one_alloc - def swap_curr_memory_with_previous( - snode_j_plus_one, snode_j, snode_j_minus_one - ): - curr_memory_j_plus_one = snode_to_curr_memory[snode_j_plus_one] - curr_memory_j = snode_to_curr_memory[snode_j] - curr_memory_j_minus_one = ( - snode_to_curr_memory[snode_j_minus_one] - if snode_j_minus_one is not None - else 0 - ) - j_plus_one_alloc = curr_memory_j_plus_one - curr_memory_j - j_alloc = curr_memory_j - curr_memory_j_minus_one - snode_to_curr_memory[snode_j] = ( - curr_memory_j - j_alloc + j_plus_one_alloc - ) - - # Recompuing curr_mem for swapping grouped nodes j (group A) and j + 1 (group B) - # swap([A0, A1, A2], [B0, B1]) --> [B0, B1], [A0, A1, A2] - # decomposing to: - # swap(A2, B0) -> A0, A1, B0, A2, B1 - # swap(A2, B1) -> A0, A1, B0, B1, A2 - # swap(A1, B0) -> A0, B0, A1, B1, A2 - # swap(A1, B1) -> A0, B0, B1, A1, A2 - # swap(A0, B0) -> B0, A0, B1, A1, A2 - # swap(A0, B1) -> B0, B1, A0, A1, A2 - for _j in range(len(gsnodes[j].snodes) - 1, -1, -1): # group A - snode_j = gsnodes[j].snodes[_j] - for _i, snode_i in enumerate(gsnode.snodes): # group B - swap_curr_memory_with_previous( - snode_j_plus_one=snode_i, - snode_j=snode_j, - snode_j_minus_one=_prev[snode_j], - ) + mem_deltas = {} + for n in [candidate, *_group_nodes(group_head, group_tail)]: + mem_deltas[n] = _curr_memory[n] - _curr_memory[_prev[n]] # type: ignore[index] + # swap (candidate, group_head...group_tail) + # Before: + # candidate_prev -0-> candidate -1-> group_head...group_tail -2-> group_tail_next + # After: + # candidate_prev -0-> group_head...group_tail -1-> candidate -2-> group_tail_next + # 0 + candidate_prev = _prev[candidate] + if candidate_prev: + _next[candidate_prev] = group_head + _prev[group_head] = candidate_prev + + # 2 + group_tail_next = _next[group_tail] + if group_tail_next: + _prev[group_tail_next] = candidate + _next[candidate] = group_tail_next + + # 1 + _prev[candidate] = group_tail + _next[group_tail] = candidate + + if _head == candidate: + _head = group_head - # Update _next and _prev for swap [snode_j, snode_i] -> [snode_i, snode_j] - first = snode_j - second = snode_i - first_prev = _prev[first] - second_next = _next[second] - if first_prev: - _next[first_prev] = second - _prev[second] = first_prev - - if second_next: - _prev[second_next] = first - _next[first] = second_next - - _next[second] = first - _prev[first] = second - - tmp = gsnodes[j] - gsnodes[j] = gsnodes[j + 1] - gsnodes[j + 1] = tmp reorder_info.final_exposed = exposed_communication_time( - snode, - itertools.chain( - gsnode.snodes[1:], *[n.snodes for n in gsnodes[j + 1 :]] - ), + curr, _group_nodes(_next[curr], None) ) + # Recompute curr_memory + _prev_curr_memory = _curr_memory[_prev[group_head]] # type: ignore[index] + for n in _group_nodes(group_head, candidate): + _curr_memory[n] = _prev_curr_memory = ( + _prev_curr_memory + mem_deltas[n] + ) + candidate = _prev[group_head] + curr = _next[curr] # type: ignore[assignment] node_stats = stats improvement = {snode: node_stats[snode].improvement for snode in node_stats} @@ -426,17 +431,13 @@ def swap_curr_memory_with_previous( reorder_log_str += str(headers) + "\n" reorder_log_str += "\n".join(map(str, rows)) - grouping_logs: list[str] = [] - flatten_gsnodes: list[BaseSchedulerNode] = [] - for i, gsnode in enumerate(gsnodes): - if isinstance(gsnode, GroupedSchedulerNode) and gsnode.temp_grouping: - flatten_gsnodes.extend(gsnode.snodes) - else: - flatten_gsnodes.append(gsnode) - - grouping_log_str = "\n".join(grouping_logs) - reorder_log_str += "\n" - reorder_log_str += grouping_log_str + new_snodes = _group_nodes(_head, None) + assert len(new_snodes) == original_snodes_num + new_peak_memory, curr_memory = estimate_peak_memory( + new_snodes, name_to_freeable_input_buf, graph_outputs + ) + reorder_log_str += f"\n peak_memory_before:{peak_memory}" + reorder_log_str += f"\n peak_memory_after:{new_peak_memory}" overlap_log.info(reorder_log_str) trace_structured( @@ -448,8 +449,7 @@ def swap_curr_memory_with_previous( payload_fn=lambda: reorder_log_str, ) - assert len(flatten_gsnodes) == original_snodes_num - return flatten_gsnodes, stats + return new_snodes, stats def _schedule_for_comm( @@ -623,7 +623,9 @@ def decide_global_ordering_of_comms( # Enforce ordering by making previous comm a `WeakDep` dependency of the next comm mutating_buf = next(iter(comm_nodes[i].get_buffer_names())) for buf in comm_nodes[i - 1].get_buffer_names(): - comm_nodes[i].add_fake_dep(WeakDep(buf, mutating_buf=mutating_buf)) + comm_nodes[i].add_fake_dep( + WeakDep(buf, mutating_buf=mutating_buf, is_fake=True) + ) return nodes @@ -640,66 +642,166 @@ class SinkWaitInfo: def _sink_waits_iterative_internal( snodes: list[BaseSchedulerNode], ) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, SinkWaitInfo]]: - from torch._inductor.scheduler import GroupedSchedulerNode, init_group_node + from torch._inductor.scheduler import GroupedSchedulerNode + + original_snodes_num = len(snodes) + if original_snodes_num == 0: + return snodes, {} + graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys()) + graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names()) + name_to_freeable_input_buf: dict[str, FreeableInputBuffer] = get_freeable_input_buf( + snodes, graph_inputs + ) + peak_memory, curr_memory = estimate_peak_memory( + snodes, name_to_freeable_input_buf, graph_outputs + ) - n = len(snodes) stats: dict[BaseSchedulerNode, SinkWaitInfo] = {} - gsnodes: list[GroupedSchedulerNode] = [ - GroupedSchedulerNode(snode.scheduler, [snode], temp_grouping=True) - for snode in snodes - ] - for i in range(n - 1, -1, -1): - gsnode = gsnodes[i] - if contains_wait(gsnode): - info = stats[gsnode.snodes[0]] = SinkWaitInfo() - for j in range(i + 1, n): - wait_gsnode = gsnodes[j - 1] - wait_outs = wait_gsnode.get_outputs() - next_gsnode = gsnodes[j] - dep_names = OrderedSet([s.name for s in next_gsnode.unmet_dependencies]) + _prev: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {} + _next: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {} + _head = snodes[0] + for i, snode in enumerate(snodes): + _prev[snode] = snodes[i - 1] if i > 0 else None + _next[snode] = snodes[i + 1] if i < len(snodes) - 1 else None + _curr_memory = dict(zip(snodes, curr_memory)) + _curr_memory[None] = 0 # type: ignore[index] + + def _group_nodes(head, tail): + ret = [] + n = head + while True: + if n is not None: + ret.append(n) + if n == tail: + break + n = _next[n] + return ret + + def _group_names(head, tail): + ret = "" + for n in _group_nodes(head, tail): + if ret: + ret += "~" + ret += n.get_name() + return ret + + curr = snodes[-1] + + processed_waits = OrderedSet() # type: ignore[var-annotated] + while _prev[curr] is not None: + if contains_wait(curr) and curr not in processed_waits: + processed_waits.add(curr) + info = stats[curr] = SinkWaitInfo() + candidate = _next[curr] + wait_snode = curr + group_head = curr + group_tail = curr + group_peak_memory = _curr_memory[curr] + while candidate is not None: + group = GroupedSchedulerNode( + wait_snode.scheduler, + _group_nodes(group_head, group_tail), + temp_grouping=True, + ) + group_outs = group.get_outputs() + + data_deps = {s.name: s for s in candidate.unmet_dependencies} data_dep = None - for o in wait_outs: - if o.get_name() in dep_names: - data_dep = o.get_name() + for o in group_outs: + if d := data_deps.get(o.get_name(), None): + if isinstance(d, WeakDep) and d.is_fake: + continue + data_dep = d break # 1. If we have data_dep - we can not swap => trying to group # 2. If swap candidate and current node both contain collectives => trying to group if data_dep is not None or ( both_contain_comms := ( - contains_collective(wait_gsnode) - and contains_collective(next_gsnode) + contains_collective(group) and contains_collective(candidate) ) ): def is_groupable(snode): - return not contains_gemm_like(snode) - - if is_groupable(next_gsnode): - new_snodes = wait_gsnode.snodes + next_gsnode.snodes - init_group_node(next_gsnode, gsnode.scheduler, new_snodes) - wait_gsnode.snodes = [] + # We do not want to group with collectives to not reorder them forward. + if contains_collective(snode): + return ( + False, + f"candidate contains collective {snode.get_name()}", + ) + if contains_gemm_like(snode): + return ( + False, + f"candidate contains gemm_like {snode.get_name()}", + ) + return True, None + + is_grp, grp_reason = is_groupable(candidate) + if is_grp: + group_tail = candidate + group_peak_memory = max( + group_peak_memory, _curr_memory[candidate] + ) info.grouped += 1 - info.grouped_info = _group_name(next_gsnode) + info.grouped_info = _group_names(group_head, group_tail) + candidate = _next[candidate] continue elif (data_dep is None) and both_contain_comms: info.limiting_factor = ( - f"collective ordering {_group_name(wait_gsnode)}" - f" with candidate:{_group_name(next_gsnode)}" + f"collective ordering {_group_names(group_head, group_tail)}" + f" with candidate:{candidate.get_name()}" ) + break else: info.limiting_factor = ( - f"data dependency {data_dep}(dep_names:{dep_names})" - f" candidate:{_group_name(next_gsnode)} dep on {_group_name(wait_gsnode)}" - f" outs:{[o.get_name() for o in wait_outs]}" + f"data dependency {data_dep}(dep_names:{list(data_deps.keys())})" + f"\n candidate:{candidate.get_name()}(os:{[candidate.get_buffer_names()]})" + f"dep on {_group_names(group_head, group_tail)}" + f"\n outs:{[o.get_name() for o in group_outs]}" + f"\n non_group_reason:{grp_reason}" ) break + candidate_delta_memory = ( + _curr_memory[candidate] - _curr_memory[_prev[candidate]] # type: ignore[index] + ) + if group_peak_memory + candidate_delta_memory > peak_memory: + info.limiting_factor = "peak_memory" + break + info.moves += 1 - info.moves_info += f"+{_group_name(next_gsnode)}" + info.moves_info += f"+{candidate.get_name()}" + + # group_head_prev -0-> candidate -1-> group_head...group_tail -2-> candidate_next + mem_deltas = {} + for n in [candidate, *_group_nodes(group_head, group_tail)]: + mem_deltas[n] = _curr_memory[n] - _curr_memory[_prev[n]] # type: ignore[index] + # 0: + group_head_prev = _prev[group_head] + if group_head_prev: + _next[group_head_prev] = candidate + _prev[candidate] = group_head_prev + + # 2: + candidate_next = _next[candidate] + if candidate_next: + _prev[candidate_next] = group_tail + _next[group_tail] = candidate_next + + # 1: + _prev[group_head] = candidate + _next[candidate] = group_head + if group_head == _head: + _head = candidate + + # Recompute curr_memory + _prev_curr_memory = _curr_memory[_prev[candidate]] # type: ignore[index] + for n in _group_nodes(candidate, group_tail): + _curr_memory[n] = _prev_curr_memory = ( + _prev_curr_memory + mem_deltas[n] + ) + + candidate = _next[group_tail] + curr = _prev[curr] # type: ignore[assignment] - # Swapping snodes j and j - 1 - tmp = gsnodes[j - 1] - gsnodes[j - 1] = gsnodes[j] - gsnodes[j] = tmp headers = [ "Wait node", "grouped", @@ -732,16 +834,13 @@ def is_groupable(snode): log_str += str(headers) + "\n" log_str += "\n".join(map(str, rows)) overlap_log.info(log_str) - grouping_logs = [] - flatten_snodes = [] - for i, gsnode in enumerate(gsnodes): - grouping_logs.append(f"gsnode[{i}]:{_group_name(gsnode, with_bufs=True)}") - if isinstance(gsnode, GroupedSchedulerNode) and gsnode.temp_grouping: - flatten_snodes.extend(gsnode.snodes) - else: - flatten_snodes.append(gsnode) - grouping_log_str = "\n".join(grouping_logs) - log_str += grouping_log_str + new_snodes = _group_nodes(_head, None) + assert len(new_snodes) == original_snodes_num + new_peak_memory, curr_memory = estimate_peak_memory( + new_snodes, name_to_freeable_input_buf, graph_outputs + ) + log_str += f"\n peak_memory_before:{peak_memory}" + log_str += f"\n peak_memory_after:{new_peak_memory}" trace_structured( "artifact", metadata_fn=lambda: { @@ -750,8 +849,7 @@ def is_groupable(snode): }, payload_fn=lambda: log_str, ) - assert len(flatten_snodes) == n - return flatten_snodes, stats + return new_snodes, stats def sink_waits_iterative( @@ -777,7 +875,9 @@ def node_summary(snode): if len(snodes) == 1: detail = "" if isinstance(snode.node, (ir.ExternKernelOut, ir._CollectiveKernel)): - detail = f" ({snode.node.python_kernel_name})" + outs_str = f"outs:{[o.get_name() for o in snode.get_outputs()]}" + ins_str = f"ins:{[d.name for d in snode.unmet_dependencies]}" + detail = f" {snode.get_name()} ({snode.node.python_kernel_name})\n {outs_str}\n ({ins_str})" layouts = [child.node.get_output_spec() for child in snode.get_nodes()] out_tensor_info = ",".join( [ @@ -1352,7 +1452,7 @@ def _create_group_node(snodes_to_group): mutating_buf = next(iter(ag_group_node.get_buffer_names())) for o in prev_ag_wait.get_outputs(): ag_group_node.add_fake_dep( - WeakDep(o.get_name(), mutating_buf=mutating_buf) + WeakDep(o.get_name(), mutating_buf=mutating_buf, is_fake=True) ) prev_ag_wait = wait_group_node @@ -1364,7 +1464,7 @@ def _create_group_node(snodes_to_group): mutating_buf = next(iter(rs_group_node.get_buffer_names())) for o in prev_rs_wait.get_outputs(): rs_group_node.add_fake_dep( - WeakDep(o.get_name(), mutating_buf=mutating_buf) + WeakDep(o.get_name(), mutating_buf=mutating_buf, is_fake=True) ) prev_rs_wait = wait_group_node diff --git a/torch/_inductor/dependencies.py b/torch/_inductor/dependencies.py index 9de52061c6489..8a374f5bab35c 100644 --- a/torch/_inductor/dependencies.py +++ b/torch/_inductor/dependencies.py @@ -342,6 +342,12 @@ class WeakDep(Dep): name: str # Buffer that is doing the mutation mutating_buf: str + # WeakDep's are also used to add dependencies to prevent some specific reordering, + # E.g. collectives global ordering. + # But if other pass guarantees proper ordering by its logic, + # This additional "fake" deps will be holding optimizations. + # This flag is used to identify those additional deps. + is_fake: bool = False @property def index(self) -> sympy.Expr: @@ -352,7 +358,7 @@ def get_numel(self) -> sympy.Expr: def rename(self, renames: dict[str, str]) -> "WeakDep": if self.name in renames: - return WeakDep(renames[self.name], self.mutating_buf) + return WeakDep(renames[self.name], self.mutating_buf, self.is_fake) return self def numbytes_hint(self) -> int: From 80ac73c0575d860993beab58bb718e727c82bc22 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Tue, 15 Jul 2025 22:33:00 -0700 Subject: [PATCH 205/457] [ca] reset between tests (#158418) CA reset is much faster than dynamo reset, so it's probably okay to run it every time. I'm not sure if this will fix the flaky autograd tests. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158418 Approved by: https://github.com/jansel --- torch/testing/_internal/common_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 692b71660071d..1b4f03da3dfc1 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -3376,6 +3376,8 @@ def wrapper(*args, **kwargs): if strict_mode or should_reset_dynamo: torch._dynamo.reset() + elif torch._dynamo.config.compiled_autograd: + torch._dynamo.compiled_autograd.reset() # Early terminate test if necessary. If using pytest, use the -x flag instead if using_unittest and self._should_stop_test_suite(): From 66c9bc5062503da58991a1fb0a9eab5d501b2891 Mon Sep 17 00:00:00 2001 From: angelayi Date: Thu, 17 Jul 2025 20:15:19 +0000 Subject: [PATCH 206/457] [export] Add runnable code to export docs (#158506) Preview: https://docs-preview.pytorch.org/pytorch/pytorch/158506/export.html Yay I can add runnable code to export docs now Also moved export API reference to a different file. With these changes, we can start to consolidate the [export tutorial](https://docs.pytorch.org/tutorials/intermediate/torch_export_tutorial.html) with the docs on pytorch docs. We just need to move the section on DDE and 0/1 specialization, and then I think we can delete the export tutorial. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158506 Approved by: https://github.com/pianpwk, https://github.com/svekars --- .ci/docker/requirements-docs.txt | 1 + docs/source/conf.py | 2 +- docs/source/export.md | 460 ++++++++++++++-------------- docs/source/export/api_reference.md | 69 +++++ 4 files changed, 294 insertions(+), 238 deletions(-) create mode 100644 docs/source/export/api_reference.md diff --git a/.ci/docker/requirements-docs.txt b/.ci/docker/requirements-docs.txt index 54e9dbdfca266..73ec471c88464 100644 --- a/.ci/docker/requirements-docs.txt +++ b/.ci/docker/requirements-docs.txt @@ -59,3 +59,4 @@ sphinx-copybutton==0.5.0 sphinx-design==0.4.0 sphinxcontrib-mermaid==1.0.0 myst-parser==0.18.1 +myst-nb diff --git a/docs/source/conf.py b/docs/source/conf.py index a19d6b7102a3e..d19d9ec21ef8e 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -62,7 +62,7 @@ "sphinxcontrib.katex", "sphinx_copybutton", "sphinx_design", - "myst_parser", + "myst_nb", "sphinx.ext.linkcode", "sphinxcontrib.mermaid", "sphinx_sitemap", diff --git a/docs/source/export.md b/docs/source/export.md index 0f0deebc65108..fcebcc6d49620 100644 --- a/docs/source/export.md +++ b/docs/source/export.md @@ -1,3 +1,13 @@ +--- +file_format: mystnb +kernelspec: + name: python3 +mystnb: + execution_timeout: 30 + execution_show_tb: True + merge_streams: True +--- + (torch.export)= # torch.export @@ -9,9 +19,9 @@ representing only the Tensor computation of the function in an Ahead-of-Time (AOT) fashion, which can subsequently be executed with different outputs or serialized. -```python +```{code-cell} import torch -from torch.export import export +from torch.export import export, ExportedProgram class Mod(torch.nn.Module): def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: @@ -21,53 +31,10 @@ class Mod(torch.nn.Module): example_args = (torch.randn(10, 10), torch.randn(10, 10)) -exported_program: torch.export.ExportedProgram = export( - Mod(), args=example_args -) +exported_program: ExportedProgram = export(Mod(), args=example_args) print(exported_program) ``` -```python -ExportedProgram: - class GraphModule(torch.nn.Module): - def forward(self, x: "f32[10, 10]", y: "f32[10, 10]"): - # code: a = torch.sin(x) - sin: "f32[10, 10]" = torch.ops.aten.sin.default(x) - - # code: b = torch.cos(y) - cos: "f32[10, 10]" = torch.ops.aten.cos.default(y) - - # code: return a + b - add: f32[10, 10] = torch.ops.aten.add.Tensor(sin, cos) - return (add,) - - Graph signature: - ExportGraphSignature( - input_specs=[ - InputSpec( - kind=, - arg=TensorArgument(name='x'), - target=None, - persistent=None - ), - InputSpec( - kind=, - arg=TensorArgument(name='y'), - target=None, - persistent=None - ) - ], - output_specs=[ - OutputSpec( - kind=, - arg=TensorArgument(name='add'), - target=None - ) - ] - ) - Range constraints: {} -``` - `torch.export` produces a clean intermediate representation (IR) with the following invariants. More specifications about the IR can be found {ref}`here `. @@ -142,9 +109,9 @@ The main entrypoint is through {func}`torch.export.export`, which takes a captures the computation graph into an {class}`torch.export.ExportedProgram`. An example: -```python +```{code-cell} import torch -from torch.export import export +from torch.export import export, ExportedProgram # Simple module for demonstration class M(torch.nn.Module): @@ -164,64 +131,13 @@ class M(torch.nn.Module): example_args = (torch.randn(1, 3, 256, 256),) example_kwargs = {"constant": torch.ones(1, 16, 256, 256)} -exported_program: torch.export.ExportedProgram = export( +exported_program: ExportedProgram = export( M(), args=example_args, kwargs=example_kwargs ) print(exported_program) -``` -```python -ExportedProgram: - class GraphModule(torch.nn.Module): - def forward(self, p_conv_weight: "f32[16, 3, 3, 3]", p_conv_bias: "f32[16]", x: "f32[1, 3, 256, 256]", constant: "f32[1, 16, 256, 256]"): - # code: a = self.conv(x) - conv2d: "f32[1, 16, 256, 256]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias, [1, 1], [1, 1]) - - # code: a.add_(constant) - add_: "f32[1, 16, 256, 256]" = torch.ops.aten.add_.Tensor(conv2d, constant) - - # code: return self.maxpool(self.relu(a)) - relu: "f32[1, 16, 256, 256]" = torch.ops.aten.relu.default(add_) - max_pool2d: "f32[1, 16, 85, 85]" = torch.ops.aten.max_pool2d.default(relu, [3, 3], [3, 3]) - return (max_pool2d,) - -Graph signature: - ExportGraphSignature( - input_specs=[ - InputSpec( - kind=, - arg=TensorArgument(name='p_conv_weight'), - target='conv.weight', - persistent=None - ), - InputSpec( - kind=, - arg=TensorArgument(name='p_conv_bias'), - target='conv.bias', - persistent=None - ), - InputSpec( - kind=, - arg=TensorArgument(name='x'), - target=None, - persistent=None - ), - InputSpec( - kind=, - arg=TensorArgument(name='constant'), - target=None, - persistent=None - ) - ], - output_specs=[ - OutputSpec( - kind=, - arg=TensorArgument(name='max_pool2d'), - target=None - ) - ] - ) -Range constraints: {} +# To run the exported program, we can use the `module()` method +print(exported_program.module()(torch.randn(1, 3, 256, 256), constant=torch.ones(1, 16, 256, 256))) ``` Inspecting the `ExportedProgram`, we can note the following: @@ -242,17 +158,15 @@ Inspecting the `ExportedProgram`, we can note the following: ## Expressing Dynamism By default `torch.export` will trace the program assuming all input shapes are -**static**, and specializing the exported program to those dimensions. However, -some dimensions, such as a batch dimension, can be dynamic and vary from run to -run. Such dimensions must be specified by using the -{func}`torch.export.Dim` API to create them and by passing them into -{func}`torch.export.export` through the `dynamic_shapes` argument. +**static**, and specializing the exported program to those dimensions. One +consequence of this is that at runtime, the program won’t work on inputs with +different shapes, even if they’re valid in eager mode. An example: -```python +```{code-cell} import torch -from torch.export import Dim, export +import traceback as tb class M(torch.nn.Module): def __init__(self): @@ -273,43 +187,64 @@ class M(torch.nn.Module): example_args = (torch.randn(32, 64), torch.randn(32, 128)) -# Create a dynamic batch size -batch = Dim("batch") -# Specify that the first dimension of each input is that batch size -dynamic_shapes = {"x1": {0: dim}, "x2": {0: batch}} +ep = torch.export.export(M(), example_args) +print(ep) -exported_program: torch.export.ExportedProgram = export( - M(), args=example_args, dynamic_shapes=dynamic_shapes -) -print(exported_program) +example_args2 = (torch.randn(64, 64), torch.randn(64, 128)) +try: + ep.module()(*example_args2) # fails +except Exception: + tb.print_exc() ``` -```python -ExportedProgram: -class GraphModule(torch.nn.Module): - def forward(self, p_branch1_0_weight: "f32[32, 64]", p_branch1_0_bias: "f32[32]", p_branch2_0_weight: "f32[64, 128]", p_branch2_0_bias: "f32[64]", c_buffer: "f32[32]", x1: "f32[s0, 64]", x2: "f32[s0, 128]"): - # code: out1 = self.branch1(x1) - linear: "f32[s0, 32]" = torch.ops.aten.linear.default(x1, p_branch1_0_weight, p_branch1_0_bias) - relu: "f32[s0, 32]" = torch.ops.aten.relu.default(linear) +However, some dimensions, such as a batch dimension, can be dynamic and vary +from run to run. Such dimensions must be specified by using the +{func}`torch.export.Dim()` API to create them and by passing them into +{func}`torch.export.export()` through the `dynamic_shapes` argument. + +```{code-cell} +import torch + +class M(torch.nn.Module): + def __init__(self): + super().__init__() + + self.branch1 = torch.nn.Sequential( + torch.nn.Linear(64, 32), torch.nn.ReLU() + ) + self.branch2 = torch.nn.Sequential( + torch.nn.Linear(128, 64), torch.nn.ReLU() + ) + self.buffer = torch.ones(32) + + def forward(self, x1, x2): + out1 = self.branch1(x1) + out2 = self.branch2(x2) + return (out1 + self.buffer, out2) - # code: out2 = self.branch2(x2) - linear_1: "f32[s0, 64]" = torch.ops.aten.linear.default(x2, p_branch2_0_weight, p_branch2_0_bias) - relu_1: "f32[s0, 64]" = torch.ops.aten.relu.default(linear_1) +example_args = (torch.randn(32, 64), torch.randn(32, 128)) - # code: return (out1 + self.buffer, out2) - add: "f32[s0, 32]" = torch.ops.aten.add.Tensor(relu, c_buffer) - return (add, relu_1) +# Create a dynamic batch size +batch = torch.export.Dim("batch") +# Specify that the first dimension of each input is that batch size +dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}} -Range constraints: {s0: VR[0, int_oo]} +ep = torch.export.export( + M(), args=example_args, dynamic_shapes=dynamic_shapes +) +print(ep) + +example_args2 = (torch.randn(64, 64), torch.randn(64, 128)) +ep.module()(*example_args2) # success ``` Some additional things to note: - Through the {func}`torch.export.Dim` API and the `dynamic_shapes` argument, we specified the first dimension of each input to be dynamic. Looking at the inputs `x1` and - `x2`, they have a symbolic shape of (s0, 64) and (s0, 128), instead of - the (32, 64) and (32, 128) shaped tensors that we passed in as example inputs. + `x2`, they have a symbolic shape of `(s0, 64)` and `(s0, 128)`, instead of + the `(32, 64)` and `(32, 128)` shaped tensors that we passed in as example inputs. `s0` is a symbol representing that this dimension can be a range of values. - `exported_program.range_constraints` describes the ranges of each symbol @@ -419,13 +354,29 @@ utility {class}`torch.export.ShapesCollection`, where instead of specifying the dynamism of every single input, we can just assign directly which input dimensions are dynamic. -```python -dim = torch.export.Dim(...) -dynamic_shapes = torch.export.ShapesCollection() -dynamic_shapes[tensor_x] = (dim, dim + 1, 8) -dynamic_shapes[tensor_y] = {0: dim * 2} +```{code-cell} +import torch -torch.export(..., args, dynamic_shapes=dynamic_shapes) +class M(torch.nn.Module): + def forward(self, inp): + x = inp["x"] * 1 + y = inp["others"][0] * 2 + z = inp["others"][1] * 3 + return x, y, z + +tensor_x = torch.randn(3, 4, 8) +tensor_y = torch.randn(6) +tensor_z = torch.randn(6) +args = {"x": tensor_x, "others": [tensor_y, tensor_z]} + +dim = torch.export.Dim("dim") +sc = torch.export.ShapesCollection() +sc[tensor_x] = (dim, dim + 1, 8) +sc[tensor_y] = {0: dim * 2} + +print(sc.dynamic_shapes(M(), (args,))) +ep = torch.export.export(M(), (args,), dynamic_shapes=sc) +print(ep) ``` ### AdditionalInputs @@ -440,16 +391,33 @@ shapes are changing. Example: -```python -args0, kwargs0 = ... # example inputs for export +```{code-cell} +import dataclasses +import torch +import torch.utils._pytree as pytree -# other representative inputs that the exported program will run on -dynamic_shapes = torch.export.AdditionalInputs() -dynamic_shapes.add(args1, kwargs1) -... -dynamic_shapes.add(argsN, kwargsN) +@dataclasses.dataclass +class D: + b: bool + i: int + f: float + t: torch.Tensor -torch.export(..., args0, kwargs0, dynamic_shapes=dynamic_shapes) +pytree.register_dataclass(D) + +class M(torch.nn.Module): + def forward(self, d: D): + return d.i + d.f + d.t + +input1 = (D(True, 3, 3.0, torch.ones(3)),) +input2 = (D(True, 4, 3.0, torch.ones(4)),) +ai = torch.export.AdditionalInputs() +ai.add(input1) +ai.add(input2) + +print(ai.dynamic_shapes(M(), input1)) +ep = torch.export.export(M(), input1, dynamic_shapes=ai) +print(ep) ``` ## Serialization @@ -463,13 +431,12 @@ An example: ```python import torch -import io class MyModule(torch.nn.Module): def forward(self, x): return x + 10 -exported_program = torch.export.export(MyModule(), torch.randn(5)) +exported_program = torch.export.export(MyModule(), (torch.randn(5),)) torch.export.save(exported_program, 'exported_program.pt2') saved_exported_program = torch.export.load('exported_program.pt2') @@ -479,30 +446,109 @@ saved_exported_program = torch.export.load('exported_program.pt2') ## Export IR, Decompositions -The graph produced by `torch.export` returns a graph containing only ATen -operators, which are the basic unit of computation in PyTorch. As there are over +The graph produced by `torch.export` returns a graph containing only +[ATen operators](https://pytorch.org/cppdocs/#aten), which are the basic unit of +computation in PyTorch. As there are over 3000 ATen operators, export provides a way to narrow down the operator set used in the graph based on certain characteristics, creating different IRs. By default, export produces the most generic IR which contains all ATen operators, including both functional and non-functional operators. A functional operator is one that does not contain any mutations or aliasing of the inputs. -This operator set also allows you to train in eager PyTorch Autograd. +You can find a list of all ATen operators +[here](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml) +and you can inspect if an operator is functional by checking +`op._schema.is_mutable`. + +This generic IR can be used to train in eager PyTorch Autograd. + +```{code-cell} +import torch + +class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(1, 3, 1, 1) + self.bn = torch.nn.BatchNorm2d(3) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return (x,) + +ep_for_training = torch.export.export(M(), (torch.randn(1, 1, 3, 3),)) +print(ep_for_training.graph_module.print_readable(print_output=False)) +``` However, if you want to use the IR for inference, or decrease the amount of -operators being used, you can lower the graph through the {func}`ExportedProgram.run_decompositions` API. +operators being used, you can lower the graph through the +{func}`ExportedProgram.run_decompositions` API. This method decomposes the +ATen operators into the ones specified in the decomposition table, and +functionalizes the graph. -* By specifying an empty set to the `decomp_table` argument, we get rid of all - non-functional operators, reducing the operator set to ~2000 operators. This - is ideal for inference cases as there are no mutations or aliasing, making - it easy to write optimization passes. -* By specifying None to `decomp_table` argument, we can reduce the operator set - to just the {ref}`Core ATen Operator Set `, which is a - collection of only ~180 operators. This IR is optimal for backends who do - not want to reimplement all ATen operators. +By specifying an empty set, we're only performing functionalization, and does +not do any additional decompositions. This results in an IR which contains ~2000 +operators (instead of the 3000 operators above), and is ideal for inference cases. -```python -class ConvBatchnorm(torch.nn.Module): +```{code-cell} +import torch + +class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(1, 3, 1, 1) + self.bn = torch.nn.BatchNorm2d(3) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return (x,) + +ep_for_training = torch.export.export(M(), (torch.randn(1, 1, 3, 3),)) +with torch.no_grad(): + ep_for_inference = ep_for_training.run_decompositions(decomp_table={}) +print(ep_for_inference.graph_module.print_readable(print_output=False)) +``` + +As we can see, the previously in-place operator, +`torch.ops.aten.add_.default` has now been replaced with +`torch.ops.aten.add.default`, a functional operator. + +We can also further lower this exported program to an operator set which only +contains the +`Core ATen Operator Set `__, +which is a collection of only ~180 operators. This IR is optimal for backends +who do not want to reimplement all ATen operators. + +```{code-cell} +import torch + +class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(1, 3, 1, 1) + self.bn = torch.nn.BatchNorm2d(3) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return (x,) + +ep_for_training = torch.export.export(M(), (torch.randn(1, 1, 3, 3),)) +with torch.no_grad(): + core_aten_ir = ep_for_training.run_decompositions(decomp_table=None) +print(core_aten_ir.graph_module.print_readable(print_output=False)) +``` + +We now see that `torch.ops.aten.conv2d.default` has been decomposed +into `torch.ops.aten.convolution.default`. This is because `convolution` +is a more "core" operator, as operations like `conv1d` and `conv2d` can be +implemented using the same op. + +We can also specify our own decomposition behaviors: + +```{code-cell} +class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv = torch.nn.Conv2d(1, 3, 1, 1) @@ -513,15 +559,22 @@ class ConvBatchnorm(torch.nn.Module): x = self.bn(x) return (x,) -mod = ConvBatchnorm() -inp = torch.randn(1, 1, 3, 3) +ep_for_training = torch.export.export(M(), (torch.randn(1, 1, 3, 3),)) -ep_for_training = torch.export.export(mod, (inp,)) -ep_for_inference = ep_for_training.run_decompositions(decomp_table={}) +my_decomp_table = torch.export.default_decompositions() + +def my_awesome_custom_conv2d_function(x, weight, bias, stride=[1, 1], padding=[0, 0], dilation=[1, 1], groups=1): + return 2 * torch.ops.aten.convolution(x, weight, bias, stride, padding, dilation, False, [0, 0], groups) + +my_decomp_table[torch.ops.aten.conv2d.default] = my_awesome_custom_conv2d_function +my_ep = ep_for_training.run_decompositions(my_decomp_table) +print(my_ep.graph_module.print_readable(print_output=False)) ``` -A tutorial on how to use this API can be found -[here](https://docs.pytorch.org/tutorials/intermediate/torch_export_tutorial.html#ir-decompositions). +Notice that instead of `torch.ops.aten.conv2d.default` being decomposed +into `torch.ops.aten.convolution.default`, it is now decomposed into +`torch.ops.aten.convolution.default` and `torch.ops.aten.mul.Tensor`, +which matches our custom decomposition rule. (limitations-of-torch-export)= @@ -587,6 +640,7 @@ have a FakeTensor kernel implementation yet, please file an issue. :caption: Additional Links for Export Users :maxdepth: 1 +export/api_reference export/programming_model export/ir_spec export/pt2_archive @@ -605,71 +659,3 @@ torch.compiler_dynamic_shapes torch.compiler_fake_tensor torch.compiler_transformations ``` - -## API Reference - -```{eval-rst} -.. automodule:: torch.export - -.. autofunction:: torch.export.export - -.. autoclass:: torch.export.ExportedProgram - :members: - :exclude-members: __init__ - -.. automodule:: torch.export.dynamic_shapes - :members: Dim, ShapesCollection, AdditionalInputs, refine_dynamic_shapes_from_suggested_fixes - -.. autofunction:: torch.export.save - -.. autofunction:: torch.export.load - -.. autofunction:: torch.export.pt2_archive._package.package_pt2 - -.. autofunction:: torch.export.pt2_archive._package.load_pt2 - -.. autofunction:: torch.export.draft_export - -.. automodule:: torch.export.unflatten - :members: - -.. autofunction:: torch.export.register_dataclass - -.. automodule:: torch.export.decomp_utils - :members: - :ignore-module-all: - :undoc-members: - -.. automodule:: torch.export.experimental - :members: - :ignore-module-all: - -.. automodule:: torch.export.passes - :members: - -.. automodule:: torch.export.pt2_archive - :members: - :ignore-module-all: - -.. automodule:: torch.export.pt2_archive.constants - :members: - :ignore-module-all: - -.. automodule:: torch.export.exported_program - :members: - :ignore-module-all: - :exclude-members: ExportedProgram - -.. automodule:: torch.export.custom_ops - :members: - :ignore-module-all: - -.. automodule:: torch.export.custom_obj - :members: - :ignore-module-all: - -.. automodule:: torch.export.graph_signature - :members: - :ignore-module-all: - :undoc-members: -``` diff --git a/docs/source/export/api_reference.md b/docs/source/export/api_reference.md new file mode 100644 index 0000000000000..f729e84e261d0 --- /dev/null +++ b/docs/source/export/api_reference.md @@ -0,0 +1,69 @@ +(export.api_reference)= + +# torch.export API Reference + +```{eval-rst} +.. automodule:: torch.export + +.. autofunction:: torch.export.export + +.. autoclass:: torch.export.ExportedProgram + :members: + :exclude-members: __init__ + +.. automodule:: torch.export.dynamic_shapes + :members: Dim, ShapesCollection, AdditionalInputs, refine_dynamic_shapes_from_suggested_fixes + +.. autofunction:: torch.export.save + +.. autofunction:: torch.export.load + +.. autofunction:: torch.export.pt2_archive._package.package_pt2 + +.. autofunction:: torch.export.pt2_archive._package.load_pt2 + +.. autofunction:: torch.export.draft_export + +.. automodule:: torch.export.unflatten + :members: + +.. autofunction:: torch.export.register_dataclass + +.. automodule:: torch.export.decomp_utils + :members: + :ignore-module-all: + :undoc-members: + +.. automodule:: torch.export.experimental + :members: + :ignore-module-all: + +.. automodule:: torch.export.passes + :members: + +.. automodule:: torch.export.pt2_archive + :members: + :ignore-module-all: + +.. automodule:: torch.export.pt2_archive.constants + :members: + :ignore-module-all: + +.. automodule:: torch.export.exported_program + :members: + :ignore-module-all: + :exclude-members: ExportedProgram + +.. automodule:: torch.export.custom_ops + :members: + :ignore-module-all: + +.. automodule:: torch.export.custom_obj + :members: + :ignore-module-all: + +.. automodule:: torch.export.graph_signature + :members: + :ignore-module-all: + :undoc-members: +``` From 1b88da1cac30dec473cbdca4d9efb9b117cb8cdb Mon Sep 17 00:00:00 2001 From: Kurt Mohler Date: Tue, 15 Jul 2025 17:04:00 -0500 Subject: [PATCH 207/457] [MPS] Improve performance of max_pool3d (#157875) To check how the changes from this PR affect performance, I wrote a script here: https://github.com/kurtamohler/pytorch-perf-test-scripts/blob/55ef32a127c746d13d7310375068a6b300bda92d/max_pool_mps/perf.py. Before this PR, I get this: ``` =================== max_pool3d =================== 0: 0.013105 ms, max_pool3d, (3, 2, 2, 2), {'kernel_size': 2} 1: 0.038003 ms, max_pool3d, (3, 10, 10, 10), {'kernel_size': 5} 2: 0.212963 ms, max_pool3d, (3, 100, 100, 100), {'kernel_size': 5} 3: 1.224645 ms, max_pool3d, (3, 200, 200, 200), {'kernel_size': 5} 4: 7.317867 ms, max_pool3d, (10, 10, 100, 100, 100), {'kernel_size': 4, 'padding': 1} 5: 34.679233 ms, max_pool3d, (10, 10, 100, 100, 100), {'kernel_size': 50, 'padding': 20} 6: 34.626383 ms, max_pool3d, (10, 10, 100, 100, 100), {'kernel_size': 50, 'padding': 20, 'dilation': 1} 7: 44.835892 ms, max_pool3d, (10, 10, 100, 100, 100), {'kernel_size': 50, 'padding': 20, 'dilation': 1, 'stride': 40} 8: 0.083579 ms, max_pool3d, (10, 10, 10, 10, 10), {'kernel_size': 2} 9: 0.936575 ms, max_pool3d, (10, 10, 30, 30, 30), {'kernel_size': 2} 10: 5.329883 ms, max_pool3d, (10, 10, 50, 50, 50), {'kernel_size': 2} 11: 11.713617 ms, max_pool3d, (10, 10, 70, 70, 70), {'kernel_size': 2} 12: 25.450454 ms, max_pool3d, (10, 10, 90, 90, 90), {'kernel_size': 2} 13: 0.058375 ms, max_pool3d, (10, 10, 10, 10, 10), {'kernel_size': 2, 'dilation': 2} 14: 3.757558 ms, max_pool3d, (10, 10, 50, 50, 50), {'kernel_size': 2, 'dilation': 2} 15: 33.451588 ms, max_pool3d, (10, 10, 100, 100, 100), {'kernel_size': 2, 'dilation': 2} ``` After this PR, I get this: ``` =================== max_pool3d =================== 0: 0.007202 ms, max_pool3d, (3, 2, 2, 2), {'kernel_size': 2} 1: 0.018596 ms, max_pool3d, (3, 10, 10, 10), {'kernel_size': 5} 2: 0.130717 ms, max_pool3d, (3, 100, 100, 100), {'kernel_size': 5} 3: 0.966795 ms, max_pool3d, (3, 200, 200, 200), {'kernel_size': 5} 4: 4.095804 ms, max_pool3d, (10, 10, 100, 100, 100), {'kernel_size': 4, 'padding': 1} 5: 12.833446 ms, max_pool3d, (10, 10, 100, 100, 100), {'kernel_size': 50, 'padding': 20} 6: 12.859346 ms, max_pool3d, (10, 10, 100, 100, 100), {'kernel_size': 50, 'padding': 20, 'dilation': 1} 7: 14.080529 ms, max_pool3d, (10, 10, 100, 100, 100), {'kernel_size': 50, 'padding': 20, 'dilation': 1, 'stride': 40} 8: 0.029283 ms, max_pool3d, (10, 10, 10, 10, 10), {'kernel_size': 2} 9: 0.175700 ms, max_pool3d, (10, 10, 30, 30, 30), {'kernel_size': 2} 10: 0.742750 ms, max_pool3d, (10, 10, 50, 50, 50), {'kernel_size': 2} 11: 1.939596 ms, max_pool3d, (10, 10, 70, 70, 70), {'kernel_size': 2} 12: 4.074821 ms, max_pool3d, (10, 10, 90, 90, 90), {'kernel_size': 2} 13: 0.028425 ms, max_pool3d, (10, 10, 10, 10, 10), {'kernel_size': 2, 'dilation': 2} 14: 0.384375 ms, max_pool3d, (10, 10, 50, 50, 50), {'kernel_size': 2, 'dilation': 2} 15: 2.623346 ms, max_pool3d, (10, 10, 100, 100, 100), {'kernel_size': 2, 'dilation': 2} ``` Every case is improved. Pull Request resolved: https://github.com/pytorch/pytorch/pull/157875 Approved by: https://github.com/malfet --- aten/src/ATen/native/mps/kernels/Pooling.h | 35 +-- .../src/ATen/native/mps/kernels/Pooling.metal | 287 ++++++++++-------- .../src/ATen/native/mps/operations/Pooling.mm | 77 ++--- 3 files changed, 222 insertions(+), 177 deletions(-) diff --git a/aten/src/ATen/native/mps/kernels/Pooling.h b/aten/src/ATen/native/mps/kernels/Pooling.h index 1d366f9620db4..d72131bd40874 100644 --- a/aten/src/ATen/native/mps/kernels/Pooling.h +++ b/aten/src/ATen/native/mps/kernels/Pooling.h @@ -5,29 +5,30 @@ // maximum allowed pooling dimensions is N-2, because the input may have up to 2 // leading dimensions that are not pooled. To support up to 3-D pooling, N=5 is // the default. -template +template struct PoolingParams { int32_t dims; int32_t pooling_dims; - ::c10::metal::array input_sizes; - ::c10::metal::array input_strides; - ::c10::metal::array output_sizes; - ::c10::metal::array output_strides; - ::c10::metal::array indices_sizes; - ::c10::metal::array indices_strides; - ::c10::metal::array kernel_size; - ::c10::metal::array stride; - ::c10::metal::array padding; - ::c10::metal::array dilation; + ::c10::metal::array input_sizes; + ::c10::metal::array input_strides; + ::c10::metal::array output_sizes; + ::c10::metal::array output_strides; + ::c10::metal::array indices_sizes; + ::c10::metal::array indices_strides; + ::c10::metal::array kernel_size; + ::c10::metal::array stride; + ::c10::metal::array padding; + ::c10::metal::array dilation; + bool return_indices; }; -template +template struct PoolingBackwardParams { int32_t dims; int32_t pooling_dims; - ::c10::metal::array grad_input_sizes; - ::c10::metal::array grad_input_strides; - ::c10::metal::array grad_output_sizes; - ::c10::metal::array grad_output_strides; - ::c10::metal::array indices_strides; + ::c10::metal::array grad_input_sizes; + ::c10::metal::array grad_input_strides; + ::c10::metal::array grad_output_sizes; + ::c10::metal::array grad_output_strides; + ::c10::metal::array indices_strides; }; diff --git a/aten/src/ATen/native/mps/kernels/Pooling.metal b/aten/src/ATen/native/mps/kernels/Pooling.metal index 92a22c97f017b..18982559a34b8 100644 --- a/aten/src/ATen/native/mps/kernels/Pooling.metal +++ b/aten/src/ATen/native/mps/kernels/Pooling.metal @@ -6,6 +6,28 @@ using namespace metal; using namespace c10::metal; +template +struct IterBounds { + T start; + T end; +}; + +template +IterBounds get_input_iter_bounds( + constant int32_t* input_sizes, + thread int32_t (&pooling_dim_indices)[3], + constant int32_t* kernel_size, + constant int32_t* stride, + constant int32_t* padding, + constant int32_t* dilation) { + auto d = dilation[dim]; + auto start = stride[dim] * pooling_dim_indices[dim] - padding[dim]; + auto end = min(start + kernel_size[dim] * d, input_sizes[dim]); + auto start_correction = d * ((-start - 1 + d) / d); + start += start < 0 ? start_correction : 0; + return IterBounds{start, end}; +} + // Iterates through all the input elements that this kernel needs to // apply max to. Specialized for 3 pooling dimensions. // TODO: Support any number of pooling dims @@ -14,82 +36,62 @@ void max_pool_3d_input_iter( constant T* input, device T* output, device int64_t* indices, - constant int64_t* input_sizes, - constant int64_t* input_strides, - device int64_t* work_pooling_dim_indices, - constant int64_t* kernel_size, - constant int64_t* stride, - constant int64_t* padding, - constant int64_t* dilation) { - int64_t o0 = work_pooling_dim_indices[0]; - int64_t o1 = work_pooling_dim_indices[1]; - int64_t o2 = work_pooling_dim_indices[2]; - - int64_t k0 = kernel_size[0]; - int64_t k1 = kernel_size[1]; - int64_t k2 = kernel_size[2]; - - int64_t s0 = stride[0]; - int64_t s1 = stride[1]; - int64_t s2 = stride[2]; - - int64_t d0 = dilation[0]; - int64_t d1 = dilation[1]; - int64_t d2 = dilation[2]; - - T max_value = 0; - int64_t max_index = -1; - - int64_t size12 = input_sizes[1] * input_sizes[2]; - - for (int64_t i0 = (s0 * o0) - padding[0]; - i0 < (s0 * o0 - padding[0] + k0 * d0) && i0 < input_sizes[0]; - i0 += d0) { - if (i0 < 0) { - continue; - } - int64_t offset0 = input_strides[0] * i0; - - for (int64_t i1 = (s1 * o1) - padding[1]; - i1 < (s1 * o1 - padding[1] + k1 * d1) && i1 < input_sizes[1]; - i1 += d1) { - if (i1 < 0) { - continue; - } - int64_t offset1 = input_strides[1] * i1; - - for (int64_t i2 = (s2 * o2) - padding[2]; - i2 < (s2 * o2 - padding[2] + k2 * d2) && i2 < input_sizes[2]; - i2 += d2) { - if (i2 < 0) { - continue; + constant int32_t* input_sizes, + constant int32_t* input_strides, + thread int32_t (&pooling_dim_indices)[3], + constant int32_t* kernel_size, + constant int32_t* stride, + constant int32_t* padding, + constant int32_t* dilation, + bool return_indices) { + auto bounds0 = get_input_iter_bounds<0>( + input_sizes, pooling_dim_indices, kernel_size, stride, padding, dilation); + auto bounds1 = get_input_iter_bounds<1>( + input_sizes, pooling_dim_indices, kernel_size, stride, padding, dilation); + auto bounds2 = get_input_iter_bounds<2>( + input_sizes, pooling_dim_indices, kernel_size, stride, padding, dilation); + + auto d0 = dilation[0]; + auto d1 = dilation[1]; + auto d2 = dilation[2]; + + T max_value = input + [input_strides[0] * bounds0.start + input_strides[1] * bounds1.start + + input_strides[2] * bounds2.start]; + auto size12 = input_sizes[1] * input_sizes[2]; + auto max_index = + bounds0.start * size12 + bounds1.start * input_sizes[2] + bounds2.start; + + for (auto i0 = bounds0.start; i0 < bounds0.end; i0 += d0) { + auto offset0 = input_strides[0] * i0; + + for (auto i1 = bounds1.start; i1 < bounds1.end; i1 += d1) { + auto offset1 = input_strides[1] * i1; + + for (auto i2 = bounds2.start; i2 < bounds2.end; i2 += d2) { + auto offset2 = input_strides[2] * i2; + auto input_value = input[offset0 + offset1 + offset2]; + bool is_greater = input_value > max_value; + + max_value = is_greater ? input_value : max_value; + + if (return_indices) { + auto input_index = i0 * size12 + i1 * input_sizes[2] + i2; + max_index = is_greater ? input_index : max_index; } - int64_t offset2 = input_strides[2] * i2; - - const T input_value = input[offset0 + offset1 + offset2]; - int64_t input_index = i0 * size12 + i1 * input_sizes[2] + i2; - - T new_max_value = (max_index == -1 || input_value > max_value) - ? input_value - : max_value; - int64_t new_max_index = (max_index == -1 || input_value > max_value) - ? input_index - : max_index; - - max_value = new_max_value; - max_index = new_max_index; } } } - *output = max_value; - *indices = max_index; + if (return_indices) { + *indices = max_index; + } } struct PoolOffsets { - int64_t output; - int64_t indices; - int64_t input_leading; + int32_t output; + int32_t indices; + int32_t input_leading; PoolOffsets() : output(0), indices(0), input_leading(0) {} }; @@ -98,30 +100,35 @@ struct PoolOffsets { // calculate, `output[N, C, d, h, w]`. Also, find the offset of the input for // the leading dim indices, `input[N, C]`. Optionally, keep track of the output // pooling dimension indices, `[d, h , w]`. -PoolOffsets find_pool_offsets( - constant int64_t* output_sizes, - constant int64_t* output_strides, - constant int64_t* indices_strides, - constant int64_t* input_strides, - device int64_t* work_pooling_dim_indices, - int32_t dims, +// NOTE: This is templated per number of dimensions so that the compiler can +// unroll the loop, giving better performance. +template +PoolOffsets find_pool_offsets_dim_specific( + constant int32_t* output_sizes, + constant int32_t* output_strides, + constant int32_t* indices_strides, + constant int32_t* input_strides, + int32_t pooling_dim_indices[3], int32_t leading_dims, + bool return_indices, uint tid) { - int64_t output_idx = static_cast(tid); + auto output_idx = static_cast(tid); PoolOffsets offsets; - for (int64_t dim = dims - 1; dim >= 0; dim--) { - int64_t dim_idx = output_idx % (output_sizes[dim]); + for (auto dim = dims - 1; dim >= 0; dim--) { + auto dim_idx = output_idx % (output_sizes[dim]); offsets.output += output_strides[dim] * dim_idx; - offsets.indices += indices_strides[dim] * dim_idx; + if (return_indices) { + offsets.indices += indices_strides[dim] * dim_idx; + } if (dim < leading_dims) { offsets.input_leading += input_strides[dim] * dim_idx; } else { // Keep track of pooling dimension indices of the output element, so we // can use them in the input iteration later on. - if (work_pooling_dim_indices != nullptr) { - work_pooling_dim_indices[dim - leading_dims] = dim_idx; + if (pooling_dim_indices != nullptr) { + pooling_dim_indices[dim - leading_dims] = dim_idx; } } output_idx = output_idx / output_sizes[dim]; @@ -130,45 +137,76 @@ PoolOffsets find_pool_offsets( return offsets; } +PoolOffsets find_pool_offsets( + constant int32_t* output_sizes, + constant int32_t* output_strides, + constant int32_t* indices_strides, + constant int32_t* input_strides, + int32_t pooling_dim_indices[3], + int32_t dims, + int32_t leading_dims, + bool return_indices, + uint tid) { + switch (dims) { + case 5: + return find_pool_offsets_dim_specific<5>( + output_sizes, + output_strides, + indices_strides, + input_strides, + pooling_dim_indices, + leading_dims, + return_indices, + tid); + case 4: + return find_pool_offsets_dim_specific<4>( + output_sizes, + output_strides, + indices_strides, + input_strides, + pooling_dim_indices, + leading_dims, + return_indices, + tid); + } +} + // Kernel computes one element of the output per kernel call. template kernel void max_pool( - constant void* input_ [[buffer(0)]], - device void* output_ [[buffer(1)]], - device void* indices_ [[buffer(2)]], - device int64_t* work_pooling_dim_indices_ [[buffer(3)]], - constant PoolingParams<5>& params [[buffer(4)]], + constant T* input [[buffer(0)]], + device T* output [[buffer(1)]], + device int64_t* indices [[buffer(2)]], + constant PoolingParams<5>& params [[buffer(3)]], uint tid [[thread_position_in_grid]]) { - int32_t pooling_dims = params.pooling_dims; - int32_t dims = params.dims; - constant int64_t* input_sizes = params.input_sizes.data(); - constant int64_t* input_strides = params.input_strides.data(); - constant int64_t* output_sizes = params.output_sizes.data(); - constant int64_t* output_strides = params.output_strides.data(); - constant int64_t* indices_strides = params.indices_strides.data(); - constant int64_t* kernel_size = params.kernel_size.data(); - constant int64_t* stride = params.stride.data(); - constant int64_t* padding = params.padding.data(); - constant int64_t* dilation = params.dilation.data(); - - int32_t leading_dims = dims - pooling_dims; - constant T* input = reinterpret_cast(input_); - device T* output = reinterpret_cast(output_); - device int64_t* indices = reinterpret_cast(indices_); + bool return_indices = params.return_indices; + auto pooling_dims = params.pooling_dims; + auto dims = params.dims; + auto input_sizes = params.input_sizes.data(); + auto input_strides = params.input_strides.data(); + auto output_sizes = params.output_sizes.data(); + auto output_strides = params.output_strides.data(); + auto indices_strides = params.indices_strides.data(); + auto kernel_size = params.kernel_size.data(); + auto stride = params.stride.data(); + auto padding = params.padding.data(); + auto dilation = params.dilation.data(); + + auto leading_dims = dims - pooling_dims; // This buffer keeps track of the pooling dimension indices of this thread's // element of the output. We need to fill it with the proper values below. - device int64_t* work_pooling_dim_indices = - work_pooling_dim_indices_ + tid * pooling_dims; + int32_t pooling_dim_indices[3]; PoolOffsets offsets = find_pool_offsets( output_sizes, output_strides, indices_strides, input_strides, - work_pooling_dim_indices, + pooling_dim_indices, dims, leading_dims, + return_indices, tid); output += offsets.output; @@ -181,11 +219,12 @@ kernel void max_pool( indices, input_sizes + leading_dims, input_strides + leading_dims, - work_pooling_dim_indices, + pooling_dim_indices, kernel_size, stride, padding, - dilation); + dilation, + return_indices); } // Finds the element in the grad input which corresponds to the index into the @@ -195,15 +234,15 @@ void max_pool_backward_impl( device AtomicType_t* grad_input, T grad_output_element, int32_t input_index, - constant int64_t* grad_input_sizes, - constant int64_t* grad_input_strides, + constant int32_t* grad_input_sizes, + constant int32_t* grad_input_strides, int32_t grad_input_leading_offset, int32_t pooling_dims) { int32_t size_prod = 1; int32_t pool_offset = 0; - for (int32_t dim = pooling_dims - 1; dim >= 0; dim--) { - int32_t next_size_prod = grad_input_sizes[dim] * size_prod; + for (auto dim = pooling_dims - 1; dim >= 0; dim--) { + auto next_size_prod = grad_input_sizes[dim] * size_prod; pool_offset += grad_input_strides[dim] * ((input_index % next_size_prod) / size_prod); size_prod *= grad_input_sizes[dim]; @@ -221,15 +260,15 @@ kernel void max_pool_backward( constant int64_t* indices [[buffer(2)]], constant PoolingBackwardParams<5>& params [[buffer(3)]], uint tid [[thread_position_in_grid]]) { - int32_t pooling_dims = params.pooling_dims; - int32_t dims = params.dims; - constant int64_t* grad_input_sizes = params.grad_input_sizes.data(); - constant int64_t* grad_input_strides = params.grad_input_strides.data(); - constant int64_t* grad_output_sizes = params.grad_output_sizes.data(); - constant int64_t* grad_output_strides = params.grad_output_strides.data(); - constant int64_t* indices_strides = params.indices_strides.data(); + auto pooling_dims = params.pooling_dims; + auto dims = params.dims; + auto grad_input_sizes = params.grad_input_sizes.data(); + auto grad_input_strides = params.grad_input_strides.data(); + auto grad_output_sizes = params.grad_output_sizes.data(); + auto grad_output_strides = params.grad_output_strides.data(); + auto indices_strides = params.indices_strides.data(); - int32_t leading_dims = dims - pooling_dims; + auto leading_dims = dims - pooling_dims; PoolOffsets offsets = find_pool_offsets( grad_output_sizes, @@ -239,6 +278,7 @@ kernel void max_pool_backward( nullptr, dims, leading_dims, + /*return_indices=*/true, tid); max_pool_backward_impl( @@ -253,11 +293,10 @@ kernel void max_pool_backward( #define REGISTER_MAX_POOL_OP(DTYPE) \ template [[host_name("max_pool_" #DTYPE)]] kernel void max_pool( \ - constant void* input_ [[buffer(0)]], \ - device void* output_ [[buffer(1)]], \ - device void* indices_ [[buffer(2)]], \ - device int64_t* work_pooling_dim_indices_ [[buffer(3)]], \ - constant PoolingParams<5>& params [[buffer(4)]], \ + constant DTYPE * input [[buffer(0)]], \ + device DTYPE * output [[buffer(1)]], \ + device int64_t* indices [[buffer(2)]], \ + constant PoolingParams<5>& params [[buffer(3)]], \ uint tid [[thread_position_in_grid]]); #define REGISTER_MAX_POOL_BACKWARD_OP(DTYPE) \ diff --git a/aten/src/ATen/native/mps/operations/Pooling.mm b/aten/src/ATen/native/mps/operations/Pooling.mm index efc77360bb993..d5e6500194f8a 100644 --- a/aten/src/ATen/native/mps/operations/Pooling.mm +++ b/aten/src/ATen/native/mps/operations/Pooling.mm @@ -252,22 +252,20 @@ static void pool2d_template(const Tensor& input, } } -static std::vector copy_and_maybe_expand(IntArrayRef a, int32_t pooling_dims) { - std::vector b; - if (a.size() == 1) { - b.assign(pooling_dims, a[0]); - } else { - b.assign(a.data(), a.data() + pooling_dims); +static std::vector copy_and_maybe_expand(IntArrayRef a, int32_t pooling_dims) { + std::vector b(pooling_dims); + for (const auto dim : c10::irange(pooling_dims)) { + b[dim] = safe_downcast(a[a.size() == 1 ? 0 : dim]); } return b; } using PoolSizes = std::tuple, - std::vector, - std::vector, - std::vector, - std::vector>; + std::vector, + std::vector, + std::vector, + std::vector>; static PoolSizes process_pool_sizes(const Tensor& input, IntArrayRef kernel_size, @@ -368,7 +366,7 @@ static PoolSizes process_pool_sizes(const Tensor& input, } static void max_pool_with_indices_out_mps_template(const Tensor& output, - const Tensor& indices, + const std::optional& indices_opt, const Tensor& input, IntArrayRef _kernel_size, IntArrayRef _stride, @@ -379,10 +377,14 @@ static void max_pool_with_indices_out_mps_template(const Tensor& output, const std::string& op_name) { auto [dims, output_size, kernel_size, stride, padding, dilation] = process_pool_sizes(input, _kernel_size, _stride, _padding, _dilation, ceil_mode, pooling_dims, op_name); + const Tensor& indices = *(at::borrow_from_optional_tensor(indices_opt)); + const bool return_indices = indices.defined(); const auto memory_format = input.suggest_memory_format(); output.resize_(output_size, memory_format); - indices.resize_(output_size, memory_format); + if (return_indices) { + indices.resize_(output_size, memory_format); + } auto iter = TensorIteratorConfig().add_output(output).resize_outputs(false).check_all_same_dtype(false).build(); @@ -395,33 +397,33 @@ static void max_pool_with_indices_out_mps_template(const Tensor& output, params.dims = dims; params.pooling_dims = pooling_dims; - memcpy(params.input_sizes.data(), input.sizes().data(), dims * sizeof(int64_t)); - memcpy(params.input_strides.data(), input.strides().data(), dims * sizeof(int64_t)); - memcpy(params.output_strides.data(), output.strides().data(), dims * sizeof(int64_t)); - memcpy(params.output_sizes.data(), output.sizes().data(), dims * sizeof(int64_t)); - memcpy(params.indices_strides.data(), indices.strides().data(), dims * sizeof(int64_t)); - memcpy(params.indices_sizes.data(), indices.sizes().data(), dims * sizeof(int64_t)); - memcpy(params.kernel_size.data(), kernel_size.data(), pooling_dims * sizeof(int64_t)); - memcpy(params.stride.data(), stride.data(), pooling_dims * sizeof(int64_t)); - memcpy(params.padding.data(), padding.data(), pooling_dims * sizeof(int64_t)); - memcpy(params.dilation.data(), dilation.data(), pooling_dims * sizeof(int64_t)); + params.return_indices = return_indices; + + for (const auto dim : c10::irange(dims)) { + params.input_sizes[dim] = safe_downcast(input.size(dim)); + params.input_strides[dim] = safe_downcast(input.stride(dim)); + params.output_sizes[dim] = safe_downcast(output.size(dim)); + params.output_strides[dim] = safe_downcast(output.stride(dim)); + if (return_indices) { + params.indices_sizes[dim] = safe_downcast(indices.size(dim)); + params.indices_strides[dim] = safe_downcast(indices.stride(dim)); + } + } + + memcpy(params.kernel_size.data(), kernel_size.data(), pooling_dims * sizeof(int32_t)); + memcpy(params.stride.data(), stride.data(), pooling_dims * sizeof(int32_t)); + memcpy(params.padding.data(), padding.data(), pooling_dims * sizeof(int32_t)); + memcpy(params.dilation.data(), dilation.data(), pooling_dims * sizeof(int32_t)); dispatch_sync_with_rethrow(mpsStream->queue(), ^() { @autoreleasepool { id computeEncoder = mpsStream->commandEncoder(); auto maxPoolPSO = lib.getPipelineStateForFunc("max_pool_" + scalarToMetalTypeString(input)); - // Each thread needs to keep track of the indices into the pooling - // dimensions for the element of the output that it calculates. In other - // words, if the thread calculates `output[N, C, d, h, w]` for a 3D pool, - // the kernel needs to keep track of the indices `[d, h, w]`. So we create - // a device-side buffer for the threads to store these indices. - id work_pooling_dim_indices = [[device newBufferWithLength:numThreads * pooling_dims * sizeof(int64_t) - options:0] autorelease]; - getMPSProfiler().beginProfileKernel(maxPoolPSO, op_name, {input}); [computeEncoder setComputePipelineState:maxPoolPSO]; - mtl_setArgs(computeEncoder, input, output, indices, work_pooling_dim_indices, params); + mtl_setArgs( + computeEncoder, input, output, return_indices ? std::optional(indices) : std::nullopt, params); mtl_dispatch1DJob(computeEncoder, maxPoolPSO, numThreads); getMPSProfiler().endProfileKernel(maxPoolPSO); @@ -454,11 +456,14 @@ static void max_pool_with_indices_backward_out_mps_template(Tensor& grad_input, params.dims = dims; params.pooling_dims = pooling_dims; - memcpy(params.grad_input_sizes.data(), grad_input.sizes().data(), dims * sizeof(int64_t)); - memcpy(params.grad_input_strides.data(), grad_input.strides().data(), dims * sizeof(int64_t)); - memcpy(params.grad_output_strides.data(), grad_output.strides().data(), dims * sizeof(int64_t)); - memcpy(params.grad_output_sizes.data(), grad_output.sizes().data(), dims * sizeof(int64_t)); - memcpy(params.indices_strides.data(), indices.strides().data(), dims * sizeof(int64_t)); + + for (const auto dim : c10::irange(dims)) { + params.grad_input_sizes[dim] = safe_downcast(grad_input.size(dim)); + params.grad_input_strides[dim] = safe_downcast(grad_input.stride(dim)); + params.grad_output_sizes[dim] = safe_downcast(grad_output.size(dim)); + params.grad_output_strides[dim] = safe_downcast(grad_output.stride(dim)); + params.indices_strides[dim] = safe_downcast(indices.stride(dim)); + } dispatch_sync_with_rethrow(mpsStream->queue(), ^() { @autoreleasepool { From ced5cf042de1d4b573f258c9f770581d9574b990 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 17 Jul 2025 20:58:34 +0000 Subject: [PATCH 208/457] Revert "Cleanup old caffe2 scripts (#158475)" This reverts commit 94d7f0c1ef9a4cb4db0eb5d6b1ffc55941cbeab1. Reverted https://github.com/pytorch/pytorch/pull/158475 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/158475#issuecomment-3085447409)) --- .github/workflows/pull.yml | 15 ++ scripts/README.md | 39 +++++ scripts/add_apache_header.sh | 1 + scripts/apache_header.txt | 15 ++ scripts/apache_python.txt | 14 ++ scripts/build_android.sh | 189 +++++++++++++++++++++++ scripts/build_android_gradle.sh | 102 ++++++++++++ scripts/build_host_protoc.sh | 59 +++++++ scripts/build_ios.sh | 155 +++++++++++++++++++ scripts/build_local.sh | 82 ++++++++++ scripts/build_mobile.sh | 107 +++++++++++++ scripts/build_pytorch_android.sh | 51 ++++++ scripts/build_raspbian.sh | 44 ++++++ scripts/build_tegra_x1.sh | 51 ++++++ scripts/build_tizen.sh | 118 ++++++++++++++ scripts/build_windows.bat | 80 ++++++++++ scripts/diagnose_protobuf.py | 92 +++++++++++ scripts/fbcode-dev-setup/ccache_setup.sh | 92 +++++++++++ scripts/get_python_cmake_flags.py | 24 +++ scripts/proto.ps1 | 18 +++ scripts/remove_apache_header.sh | 13 ++ scripts/temp.sh | 7 + scripts/xcode_build.rb | 76 +++++++++ 23 files changed, 1444 insertions(+) create mode 100755 scripts/add_apache_header.sh create mode 100644 scripts/apache_header.txt create mode 100644 scripts/apache_python.txt create mode 100755 scripts/build_android.sh create mode 100755 scripts/build_android_gradle.sh create mode 100755 scripts/build_host_protoc.sh create mode 100755 scripts/build_ios.sh create mode 100755 scripts/build_local.sh create mode 100755 scripts/build_mobile.sh create mode 100755 scripts/build_pytorch_android.sh create mode 100755 scripts/build_raspbian.sh create mode 100755 scripts/build_tegra_x1.sh create mode 100755 scripts/build_tizen.sh create mode 100644 scripts/build_windows.bat create mode 100644 scripts/diagnose_protobuf.py create mode 100755 scripts/fbcode-dev-setup/ccache_setup.sh create mode 100644 scripts/get_python_cmake_flags.py create mode 100644 scripts/proto.ps1 create mode 100755 scripts/remove_apache_header.sh create mode 100755 scripts/temp.sh create mode 100644 scripts/xcode_build.rb diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index be0bdc527cc11..59a7265173800 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -315,6 +315,21 @@ jobs: test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.test-matrix }} secrets: inherit + linux-jammy-py3-clang18-mobile-build: + name: linux-jammy-py3-clang18-mobile-build + uses: ./.github/workflows/_linux-build.yml + needs: get-label-type + with: + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build-environment: linux-jammy-py3-clang12-mobile-build + docker-image-name: ci-image:pytorch-linux-jammy-py3-clang18-asan + build-generates-artifacts: false + test-matrix: | + { include: [ + { config: "default", shard: 1, num_shards: 1 }, + ]} + secrets: inherit + linux-jammy-cuda12_8-cudnn9-py3_9-clang12-build: name: linux-jammy-cuda12.8-cudnn9-py3.9-clang12 uses: ./.github/workflows/_linux-build.yml diff --git a/scripts/README.md b/scripts/README.md index 367e7261f6a60..a1c5ae5f93e67 100644 --- a/scripts/README.md +++ b/scripts/README.md @@ -1 +1,40 @@ This directory contains the useful tools. + + +## build_android.sh +This script is to build PyTorch/Caffe2 library for Android. Take the following steps to start the build: + +- set ANDROID_NDK to the location of ndk + +```bash +export ANDROID_NDK=YOUR_NDK_PATH +``` + +- run build_android.sh +```bash +#in your PyTorch root directory +bash scripts/build_android.sh +``` +If succeeded, the libraries and headers would be generated to build_android/install directory. You can then copy these files from build_android/install to your Android project for further usage. + +You can also override the cmake flags via command line, e.g., following command will also compile the executable binary files: +```bash +bash scripts/build_android.sh -DBUILD_BINARY=ON +``` + +## build_ios.sh +This script is to build PyTorch/Caffe2 library for iOS, and can only be performed on macOS. Take the following steps to start the build: + +- Install Xcode from App Store, and configure "Command Line Tools" properly on Xcode. +- Install the dependencies: + +```bash +brew install cmake automake libtool +``` + +- run build_ios.sh +```bash +#in your PyTorch root directory +bash scripts/build_ios.sh +``` +If succeeded, the libraries and headers would be generated to build_ios/install directory. You can then copy these files to your Xcode project for further usage. diff --git a/scripts/add_apache_header.sh b/scripts/add_apache_header.sh new file mode 100755 index 0000000000000..a29a059d2d033 --- /dev/null +++ b/scripts/add_apache_header.sh @@ -0,0 +1 @@ +cat apache_header.txt $1 > _add_apache_header.txt && mv _add_apache_header.txt $1 diff --git a/scripts/apache_header.txt b/scripts/apache_header.txt new file mode 100644 index 0000000000000..b4eff258eb04d --- /dev/null +++ b/scripts/apache_header.txt @@ -0,0 +1,15 @@ +/** + * Copyright (c) 2016-present, Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ diff --git a/scripts/apache_python.txt b/scripts/apache_python.txt new file mode 100644 index 0000000000000..bc104d8845154 --- /dev/null +++ b/scripts/apache_python.txt @@ -0,0 +1,14 @@ +# Copyright (c) 2016-present, Facebook, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################## diff --git a/scripts/build_android.sh b/scripts/build_android.sh new file mode 100755 index 0000000000000..43f11b86828d4 --- /dev/null +++ b/scripts/build_android.sh @@ -0,0 +1,189 @@ +#!/bin/bash +############################################################################## +# Example command to build the android target. +############################################################################## +# +# This script shows how one can build a Caffe2 binary for the Android platform +# using android-cmake. A few notes: +# +# (1) This build also does a host build for protobuf. You will need autoconf +# to carry out this. If autoconf is not possible, you will need to provide +# a pre-built protoc binary that is the same version as the protobuf +# version under third_party. +# If you are building on Mac, you might need to install autotool and +# libtool. The easiest way is via homebrew: +# brew install automake +# brew install libtool +# (2) You will need to have android ndk installed. The current script assumes +# that you set ANDROID_NDK to the location of ndk. +# (3) The toolchain and the build target platform can be specified with the +# cmake arguments below. For more details, check out android-cmake's doc. + +set -e + +# Android specific flags +if [ -z "$ANDROID_ABI" ]; then + ANDROID_ABI="armeabi-v7a with NEON" +fi +ANDROID_NATIVE_API_LEVEL="21" +echo "Build with ANDROID_ABI[$ANDROID_ABI], ANDROID_NATIVE_API_LEVEL[$ANDROID_NATIVE_API_LEVEL]" + +CAFFE2_ROOT="$( cd "$(dirname "$0")"/.. ; pwd -P)" +if [ -z "$ANDROID_NDK" ]; then + echo "ANDROID_NDK not set; please set it to the Android NDK directory" + exit 1 +fi + +if [ ! -d "$ANDROID_NDK" ]; then + echo "ANDROID_NDK not a directory; did you install it under $ANDROID_NDK?" + exit 1 +fi + +if [ -z "$PYTHON" ]; then + PYTHON=python + PYTHON_VERSION_MAJOR=$($PYTHON -c 'import sys; print(sys.version_info[0])') + if [ "${PYTHON_VERSION_MAJOR}" -le 2 ]; then + echo "Default python executable is Python-2, trying to use python3 alias" + PYTHON=python3 + fi +fi + +ANDROID_NDK_PROPERTIES="$ANDROID_NDK/source.properties" +[ -f "$ANDROID_NDK_PROPERTIES" ] && ANDROID_NDK_VERSION=$(sed -n 's/^Pkg.Revision[^=]*= *\([0-9]*\)\..*$/\1/p' "$ANDROID_NDK_PROPERTIES") + +echo "Bash: $(/bin/bash --version | head -1)" +echo "Python: $($PYTHON -c 'import sys; print(sys.version)')" +echo "Caffe2 path: $CAFFE2_ROOT" +echo "Using Android NDK at $ANDROID_NDK" +echo "Android NDK version: $ANDROID_NDK_VERSION" + +CMAKE_ARGS=() + +# Build PyTorch mobile +CMAKE_ARGS+=("-DCMAKE_PREFIX_PATH=$($PYTHON -c 'import sysconfig; print(sysconfig.get_path("purelib"))')") +CMAKE_ARGS+=("-DPython_EXECUTABLE=$($PYTHON -c 'import sys; print(sys.executable)')") +CMAKE_ARGS+=("-DBUILD_CUSTOM_PROTOBUF=OFF") + +# custom build with selected ops +if [ -n "${SELECTED_OP_LIST}" ]; then + SELECTED_OP_LIST="$(cd $(dirname $SELECTED_OP_LIST); pwd -P)/$(basename $SELECTED_OP_LIST)" + echo "Choose SELECTED_OP_LIST file: $SELECTED_OP_LIST" + if [ ! -r ${SELECTED_OP_LIST} ]; then + echo "Error: SELECTED_OP_LIST file ${SELECTED_OP_LIST} not found." + exit 1 + fi + CMAKE_ARGS+=("-DSELECTED_OP_LIST=${SELECTED_OP_LIST}") +fi + +# If Ninja is installed, prefer it to Make +if [ -x "$(command -v ninja)" ]; then + CMAKE_ARGS+=("-GNinja") +fi + +# Use android-cmake to build Android project from CMake. +CMAKE_ARGS+=("-DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake") + +if [ -z "$BUILD_MOBILE_BENCHMARK" ]; then + BUILD_MOBILE_BENCHMARK=0 +fi + +if [ -z "$BUILD_MOBILE_TEST" ]; then + BUILD_MOBILE_TEST=0 +fi +# Don't build artifacts we don't need +CMAKE_ARGS+=("-DBUILD_TEST=OFF") +CMAKE_ARGS+=("-DBUILD_BINARY=OFF") + +# If there exists env variable and it equals to 0, build full jit interpreter. +# Default behavior is to build lite interpreter +# cmd: BUILD_LITE_INTERPRETER=0 ./scripts/build_android.sh +if [ "${BUILD_LITE_INTERPRETER}" == 0 ]; then + CMAKE_ARGS+=("-DBUILD_LITE_INTERPRETER=OFF") +else + CMAKE_ARGS+=("-DBUILD_LITE_INTERPRETER=ON") +fi +if [ "${TRACING_BASED}" == 1 ]; then + CMAKE_ARGS+=("-DTRACING_BASED=ON") +else + CMAKE_ARGS+=("-DTRACING_BASED=OFF") +fi +if [ "${USE_LIGHTWEIGHT_DISPATCH}" == 1 ]; then + CMAKE_ARGS+=("-DUSE_LIGHTWEIGHT_DISPATCH=ON") + CMAKE_ARGS+=("-DSTATIC_DISPATCH_BACKEND=CPU") +else + CMAKE_ARGS+=("-DUSE_LIGHTWEIGHT_DISPATCH=OFF") +fi + +CMAKE_ARGS+=("-DBUILD_MOBILE_BENCHMARK=$BUILD_MOBILE_BENCHMARK") +CMAKE_ARGS+=("-DBUILD_MOBILE_TEST=$BUILD_MOBILE_TEST") +CMAKE_ARGS+=("-DBUILD_PYTHON=OFF") +CMAKE_ARGS+=("-DBUILD_SHARED_LIBS=OFF") +if (( "${ANDROID_NDK_VERSION:-0}" < 18 )); then + CMAKE_ARGS+=("-DANDROID_TOOLCHAIN=gcc") +else + CMAKE_ARGS+=("-DANDROID_TOOLCHAIN=clang") +fi +# Disable unused dependencies +CMAKE_ARGS+=("-DUSE_CUDA=OFF") +CMAKE_ARGS+=("-DUSE_ITT=OFF") +CMAKE_ARGS+=("-DUSE_GFLAGS=OFF") +CMAKE_ARGS+=("-DUSE_OPENCV=OFF") +CMAKE_ARGS+=("-DUSE_MPI=OFF") +CMAKE_ARGS+=("-DUSE_OPENMP=OFF") +# Only toggle if VERBOSE=1 +if [ "${VERBOSE:-}" == '1' ]; then + CMAKE_ARGS+=("-DCMAKE_VERBOSE_MAKEFILE=1") +fi + +# Android specific flags +CMAKE_ARGS+=("-DANDROID_NDK=$ANDROID_NDK") +CMAKE_ARGS+=("-DANDROID_ABI=$ANDROID_ABI") +CMAKE_ARGS+=("-DANDROID_NATIVE_API_LEVEL=$ANDROID_NATIVE_API_LEVEL") +CMAKE_ARGS+=("-DANDROID_CPP_FEATURES=rtti exceptions") +if [ "${ANDROID_STL_SHARED:-}" == '1' ]; then + CMAKE_ARGS+=("-DANDROID_STL=c++_shared") +fi +if [ "${ANDROID_DEBUG_SYMBOLS:-}" == '1' ]; then + CMAKE_ARGS+=("-DANDROID_DEBUG_SYMBOLS=1") +fi + +if [ -n "${USE_VULKAN}" ]; then + CMAKE_ARGS+=("-DUSE_VULKAN=ON") + if [ -n "${USE_VULKAN_FP16_INFERENCE}" ]; then + CMAKE_ARGS+=("-DUSE_VULKAN_FP16_INFERENCE=ON") + fi + if [ -n "${USE_VULKAN_RELAXED_PRECISION}" ]; then + CMAKE_ARGS+=("-DUSE_VULKAN_RELAXED_PRECISION=ON") + fi +fi + +# Use-specified CMake arguments go last to allow overriding defaults +CMAKE_ARGS+=($@) + +# Patch pocketfft (as Android does not have aligned_alloc even if compiled with c++17 +if [ -f third_party/pocketfft/pocketfft_hdronly.h ]; then + sed -i -e "s/__cplusplus >= 201703L/0/" third_party/pocketfft/pocketfft_hdronly.h +fi + +# Now, actually build the Android target. +BUILD_ROOT=${BUILD_ROOT:-"$CAFFE2_ROOT/build_android"} +INSTALL_PREFIX=${BUILD_ROOT}/install +mkdir -p $BUILD_ROOT +cd $BUILD_ROOT +cmake "$CAFFE2_ROOT" \ + -DCMAKE_INSTALL_PREFIX=$INSTALL_PREFIX \ + -DCMAKE_BUILD_TYPE=Release \ + "${CMAKE_ARGS[@]}" + +# Cross-platform parallel build +if [ -z "$MAX_JOBS" ]; then + if [ "$(uname)" == 'Darwin' ]; then + MAX_JOBS=$(sysctl -n hw.ncpu) + else + MAX_JOBS=$(nproc) + fi +fi + +echo "Will install headers and libs to $INSTALL_PREFIX for further Android project usage." +cmake --build . --target install -- "-j${MAX_JOBS}" +echo "Installation completed, now you can copy the headers/libs from $INSTALL_PREFIX to your Android project directory." diff --git a/scripts/build_android_gradle.sh b/scripts/build_android_gradle.sh new file mode 100755 index 0000000000000..fc27c5dd2516b --- /dev/null +++ b/scripts/build_android_gradle.sh @@ -0,0 +1,102 @@ +#!/usr/bin/env bash +set -eux -o pipefail + +env +echo "BUILD_ENVIRONMENT:$BUILD_ENVIRONMENT" + +export ANDROID_NDK_HOME=/opt/ndk +export ANDROID_NDK=/opt/ndk +export ANDROID_HOME=/opt/android/sdk + +# Must be in sync with GRADLE_VERSION in docker image for android +# https://github.com/pietern/pytorch-dockerfiles/blob/master/build.sh#L155 +export GRADLE_VERSION=6.8.3 +export GRADLE_HOME=/opt/gradle/gradle-$GRADLE_VERSION +export GRADLE_PATH=$GRADLE_HOME/bin/gradle + +# touch gradle cache files to prevent expiration +while IFS= read -r -d '' file +do + touch "$file" || true +done < <(find /var/lib/jenkins/.gradle -type f -print0) + +# Patch pocketfft (as Android does not have aligned_alloc even if compiled with c++17 +if [ -f ~/workspace/third_party/pocketfft/pocketfft_hdronly.h ]; then + sed -i -e "s/__cplusplus >= 201703L/0/" ~/workspace/third_party/pocketfft/pocketfft_hdronly.h +fi + +export GRADLE_LOCAL_PROPERTIES=~/workspace/android/local.properties +rm -f $GRADLE_LOCAL_PROPERTIES +echo "sdk.dir=/opt/android/sdk" >> $GRADLE_LOCAL_PROPERTIES +echo "ndk.dir=/opt/ndk" >> $GRADLE_LOCAL_PROPERTIES +echo "cmake.dir=/usr/local" >> $GRADLE_LOCAL_PROPERTIES + +retry () { + $* || (sleep 1 && $*) || (sleep 2 && $*) || (sleep 4 && $*) || (sleep 8 && $*) +} + +# Run custom build script +if [[ "${BUILD_ENVIRONMENT}" == *-gradle-custom-build* ]]; then + # Install torch & torchvision - used to download & dump used ops from test model. + retry pip install torch torchvision --progress-bar off + + exec "$(dirname "${BASH_SOURCE[0]}")/../android/build_test_app_custom.sh" armeabi-v7a +fi + +# Run default build +BUILD_ANDROID_INCLUDE_DIR_x86=~/workspace/build_android/install/include +BUILD_ANDROID_LIB_DIR_x86=~/workspace/build_android/install/lib + +BUILD_ANDROID_INCLUDE_DIR_x86_64=~/workspace/build_android_install_x86_64/install/include +BUILD_ANDROID_LIB_DIR_x86_64=~/workspace/build_android_install_x86_64/install/lib + +BUILD_ANDROID_INCLUDE_DIR_arm_v7a=~/workspace/build_android_install_arm_v7a/install/include +BUILD_ANDROID_LIB_DIR_arm_v7a=~/workspace/build_android_install_arm_v7a/install/lib + +BUILD_ANDROID_INCLUDE_DIR_arm_v8a=~/workspace/build_android_install_arm_v8a/install/include +BUILD_ANDROID_LIB_DIR_arm_v8a=~/workspace/build_android_install_arm_v8a/install/lib + +PYTORCH_ANDROID_SRC_MAIN_DIR=~/workspace/android/pytorch_android/src/main + +JNI_INCLUDE_DIR=${PYTORCH_ANDROID_SRC_MAIN_DIR}/cpp/libtorch_include +mkdir -p $JNI_INCLUDE_DIR + +JNI_LIBS_DIR=${PYTORCH_ANDROID_SRC_MAIN_DIR}/jniLibs +mkdir -p $JNI_LIBS_DIR + +ln -s ${BUILD_ANDROID_INCLUDE_DIR_x86} ${JNI_INCLUDE_DIR}/x86 +ln -s ${BUILD_ANDROID_LIB_DIR_x86} ${JNI_LIBS_DIR}/x86 + +if [[ "${BUILD_ENVIRONMENT}" != *-gradle-build-only-x86_32* ]]; then +ln -s ${BUILD_ANDROID_INCLUDE_DIR_x86_64} ${JNI_INCLUDE_DIR}/x86_64 +ln -s ${BUILD_ANDROID_LIB_DIR_x86_64} ${JNI_LIBS_DIR}/x86_64 + +ln -s ${BUILD_ANDROID_INCLUDE_DIR_arm_v7a} ${JNI_INCLUDE_DIR}/armeabi-v7a +ln -s ${BUILD_ANDROID_LIB_DIR_arm_v7a} ${JNI_LIBS_DIR}/armeabi-v7a + +ln -s ${BUILD_ANDROID_INCLUDE_DIR_arm_v8a} ${JNI_INCLUDE_DIR}/arm64-v8a +ln -s ${BUILD_ANDROID_LIB_DIR_arm_v8a} ${JNI_LIBS_DIR}/arm64-v8a +fi + +GRADLE_PARAMS="-p android assembleRelease --debug --stacktrace" +if [[ "${BUILD_ENVIRONMENT}" == *-gradle-build-only-x86_32* ]]; then + GRADLE_PARAMS+=" -PABI_FILTERS=x86" +fi + +if [ -n "${GRADLE_OFFLINE:-}" ]; then + GRADLE_PARAMS+=" --offline" +fi + +$GRADLE_PATH $GRADLE_PARAMS + +find . -type f -name "*.a" -exec ls -lh {} \; + +while IFS= read -r -d '' file +do + echo + echo "$file" + ls -lah "$file" + zipinfo -l "$file" +done < <(find . -type f -name '*.aar' -print0) + +find . -type f -name *aar -print | xargs tar cfvz ~/workspace/android/artifacts.tgz diff --git a/scripts/build_host_protoc.sh b/scripts/build_host_protoc.sh new file mode 100755 index 0000000000000..cd37db3b31713 --- /dev/null +++ b/scripts/build_host_protoc.sh @@ -0,0 +1,59 @@ +#!/bin/bash +############################################################################## +# Build script to build the protoc compiler for the host platform. +############################################################################## +# This script builds the protoc compiler for the host platform, which is needed +# for any cross-compilation as we will need to convert the protobuf source +# files to cc files. +# +# --other-flags accepts flags that should be passed to cmake. Optional. +# +# After the execution of the file, one should be able to find the host protoc +# binary at build_host_protoc/bin/protoc. + +set -e + +CAFFE2_ROOT="$( cd "$(dirname -- "$0")"/.. ; pwd -P)" +BUILD_ROOT=${BUILD_ROOT:-"$CAFFE2_ROOT/build_host_protoc"} +mkdir -p $BUILD_ROOT/build +cd $BUILD_ROOT/build + +CMAKE_ARGS=() +CMAKE_ARGS+=("-DCMAKE_INSTALL_PREFIX=$BUILD_ROOT") +CMAKE_ARGS+=("-Dprotobuf_BUILD_TESTS=OFF") + +# If Ninja is installed, prefer it to Make +if [ -x "$(command -v ninja)" ]; then + CMAKE_ARGS+=("-GNinja") +fi + +while true; do + case "$1" in + --other-flags) + shift; + CMAKE_ARGS+=("$@") + break ;; + "") + break ;; + *) + echo "Unknown option passed as argument: $1" + break ;; + esac +done + +# Use ccache if available (this path is where Homebrew installs ccache symlinks) +if [ "$(uname)" == 'Darwin' ] && [ -d /usr/local/opt/ccache/libexec ]; then + CMAKE_ARGS+=("-DCMAKE_C_COMPILER=/usr/local/opt/ccache/libexec/gcc") + CMAKE_ARGS+=("-DCMAKE_CXX_COMPILER=/usr/local/opt/ccache/libexec/g++") +fi + +cmake "$CAFFE2_ROOT/third_party/protobuf/cmake" ${CMAKE_ARGS[@]} + +if [ -z "$MAX_JOBS" ]; then + if [ "$(uname)" == 'Darwin' ]; then + MAX_JOBS=$(sysctl -n hw.ncpu) + else + MAX_JOBS=$(nproc) + fi +fi +cmake --build . -- "-j${MAX_JOBS}" install diff --git a/scripts/build_ios.sh b/scripts/build_ios.sh new file mode 100755 index 0000000000000..ad16cb940dcb8 --- /dev/null +++ b/scripts/build_ios.sh @@ -0,0 +1,155 @@ +#!/bin/bash -xe +############################################################################## +# Example command to build the iOS target. +############################################################################## +# +# This script shows how one can build a Caffe2 binary for the iOS platform +# using ios-cmake. This is very similar to the android-cmake - see +# build_android.sh for more details. + +CAFFE2_ROOT="$( cd "$(dirname "$0")"/.. ; pwd -P)" + +if [ -z "$PYTHON" ]; then + PYTHON=python + PYTHON_VERSION_MAJOR=$($PYTHON -c 'import sys; print(sys.version_info[0])') + if [ "${PYTHON_VERSION_MAJOR}" -le 2 ]; then + echo "Default python executable is Python-2, trying to use python3 alias" + PYTHON=python3 + fi +fi + +echo "Bash: $(/bin/bash --version | head -1)" +echo "Python: $($PYTHON -c 'import sys; print(sys.version)')" +echo "Caffe2 path: $CAFFE2_ROOT" + +CMAKE_ARGS=() + +# Build PyTorch mobile +CMAKE_ARGS+=("-DCMAKE_PREFIX_PATH=$($PYTHON -c 'import sysconfig; print(sysconfig.get_path("purelib"))')") +CMAKE_ARGS+=("-DPython_EXECUTABLE=$($PYTHON -c 'import sys; print(sys.executable)')") +CMAKE_ARGS+=("-DBUILD_CUSTOM_PROTOBUF=OFF") + +# custom build with selected ops +if [ -n "${SELECTED_OP_LIST}" ]; then + SELECTED_OP_LIST="$(cd $(dirname $SELECTED_OP_LIST); pwd -P)/$(basename $SELECTED_OP_LIST)" + echo "Choose SELECTED_OP_LIST file: $SELECTED_OP_LIST" + if [ ! -r ${SELECTED_OP_LIST} ]; then + echo "Error: SELECTED_OP_LIST file ${SELECTED_OP_LIST} not found." + exit 1 + fi + CMAKE_ARGS+=("-DSELECTED_OP_LIST=${SELECTED_OP_LIST}") +fi + +# bitcode +if [ "${ENABLE_BITCODE:-}" == '1' ]; then + CMAKE_ARGS+=("-DCMAKE_C_FLAGS=-fembed-bitcode") + CMAKE_ARGS+=("-DCMAKE_CXX_FLAGS=-fembed-bitcode") +fi + +# Use ios-cmake to build iOS project from CMake. +# This projects sets CMAKE_C_COMPILER to /usr/bin/gcc and +# CMAKE_CXX_COMPILER to /usr/bin/g++. In order to use ccache (if it is available) we +# must override these variables via CMake arguments. +CMAKE_ARGS+=("-DCMAKE_TOOLCHAIN_FILE=$CAFFE2_ROOT/cmake/iOS.cmake") +if [ -n "${CCACHE_WRAPPER_PATH:-}"]; then + CCACHE_WRAPPER_PATH=/usr/local/opt/ccache/libexec +fi +if [ -d "$CCACHE_WRAPPER_PATH" ]; then + CMAKE_ARGS+=("-DCMAKE_C_COMPILER=$CCACHE_WRAPPER_PATH/gcc") + CMAKE_ARGS+=("-DCMAKE_CXX_COMPILER=$CCACHE_WRAPPER_PATH/g++") +fi + +# IOS_PLATFORM controls type of iOS platform (see ios-cmake) +if [ -n "${IOS_PLATFORM:-}" ]; then + CMAKE_ARGS+=("-DIOS_PLATFORM=${IOS_PLATFORM}") + if [ "${IOS_PLATFORM}" == "WATCHOS" ]; then + # enable bitcode by default for watchos + CMAKE_ARGS+=("-DCMAKE_C_FLAGS=-fembed-bitcode") + CMAKE_ARGS+=("-DCMAKE_CXX_FLAGS=-fembed-bitcode") + # disable the QNNPACK + CMAKE_ARGS+=("-DUSE_PYTORCH_QNNPACK=OFF") + fi +else + # IOS_PLATFORM is not set, default to OS, which builds iOS. + CMAKE_ARGS+=("-DIOS_PLATFORM=OS") +fi + +if [ -n "${IOS_ARCH:-}" ]; then + CMAKE_ARGS+=("-DIOS_ARCH=${IOS_ARCH}") +fi + +if [ "${BUILD_LITE_INTERPRETER}" == 0 ]; then + CMAKE_ARGS+=("-DBUILD_LITE_INTERPRETER=OFF") +else + CMAKE_ARGS+=("-DBUILD_LITE_INTERPRETER=ON") +fi +if [ "${TRACING_BASED}" == 1 ]; then + CMAKE_ARGS+=("-DTRACING_BASED=ON") +else + CMAKE_ARGS+=("-DTRACING_BASED=OFF") +fi +if [ "${USE_LIGHTWEIGHT_DISPATCH}" == 1 ]; then + CMAKE_ARGS+=("-DUSE_LIGHTWEIGHT_DISPATCH=ON") + CMAKE_ARGS+=("-DSTATIC_DISPATCH_BACKEND=CPU") +else + CMAKE_ARGS+=("-DUSE_LIGHTWEIGHT_DISPATCH=OFF") +fi + +CMAKE_ARGS+=("-DUSE_LITE_INTERPRETER_PROFILER=OFF") + +# Don't build binaries or tests (only the library) +CMAKE_ARGS+=("-DBUILD_TEST=OFF") +CMAKE_ARGS+=("-DBUILD_BINARY=OFF") +CMAKE_ARGS+=("-DBUILD_PYTHON=OFF") + +# Disable unused dependencies +CMAKE_ARGS+=("-DUSE_CUDA=OFF") +CMAKE_ARGS+=("-DUSE_ITT=OFF") +CMAKE_ARGS+=("-DUSE_GFLAGS=OFF") +CMAKE_ARGS+=("-DUSE_OPENCV=OFF") +CMAKE_ARGS+=("-DUSE_MPI=OFF") +CMAKE_ARGS+=("-DUSE_NUMPY=OFF") +CMAKE_ARGS+=("-DUSE_NNPACK=OFF") +CMAKE_ARGS+=("-DUSE_MKLDNN=OFF") + +# Metal +if [ "${USE_PYTORCH_METAL:-}" == "1" ]; then + CMAKE_ARGS+=("-DUSE_PYTORCH_METAL=ON") +fi + +# Core ML +if [ "${USE_COREML_DELEGATE}" == "1" ]; then + CMAKE_ARGS+=("-DUSE_COREML_DELEGATE=ON") +fi + +# pthreads +CMAKE_ARGS+=("-DCMAKE_THREAD_LIBS_INIT=-lpthread") +CMAKE_ARGS+=("-DCMAKE_HAVE_THREADS_LIBRARY=1") +CMAKE_ARGS+=("-DCMAKE_USE_PTHREADS_INIT=1") + +# Only toggle if VERBOSE=1 +if [ "${VERBOSE:-}" == '1' ]; then + CMAKE_ARGS+=("-DCMAKE_VERBOSE_MAKEFILE=1") +fi + +# enable ARC +CMAKE_ARGS+=("-DCMAKE_CXX_FLAGS=-fobjc-arc") + +# Now, actually build the iOS target. +BUILD_ROOT=${BUILD_ROOT:-"$CAFFE2_ROOT/build_ios"} +INSTALL_PREFIX=${BUILD_ROOT}/install +mkdir -p $BUILD_ROOT +cd $BUILD_ROOT +cmake "$CAFFE2_ROOT" \ + -DCMAKE_INSTALL_PREFIX=$INSTALL_PREFIX \ + -DCMAKE_BUILD_TYPE=MinSizeRel \ + -DBUILD_SHARED_LIBS=OFF \ + ${CMAKE_ARGS[@]} \ + $@ + +cmake --build . -- "-j$(sysctl -n hw.ncpu)" + +# copy headers and libs to install directory +echo "Will install headers and libs to $INSTALL_PREFIX for further Xcode project usage." +make install +echo "Installation completed, now you can copy the headers/libs from $INSTALL_PREFIX to your Xcode project directory." diff --git a/scripts/build_local.sh b/scripts/build_local.sh new file mode 100755 index 0000000000000..b843671501256 --- /dev/null +++ b/scripts/build_local.sh @@ -0,0 +1,82 @@ +#!/bin/bash +# +############################################################################## +# Example command to build Caffe2 +############################################################################## +# + +set -ex + +CAFFE2_ROOT="$( cd "$(dirname "$0")"/.. ; pwd -P)" + +CMAKE_ARGS=() + +# If Ninja is installed, prefer it to Make +if [ -x "$(command -v ninja)" ]; then + CMAKE_ARGS+=("-GNinja") +fi + +# Use ccache if available (this path is where Homebrew installs ccache symlinks) +if [ "$(uname)" == 'Darwin' ]; then + if [ -n "${CCACHE_WRAPPER_PATH:-}"]; then + CCACHE_WRAPPER_PATH=/usr/local/opt/ccache/libexec + fi + if [ -d "$CCACHE_WRAPPER_PATH" ]; then + CMAKE_ARGS+=("-DCMAKE_C_COMPILER=$CCACHE_WRAPPER_PATH/gcc") + CMAKE_ARGS+=("-DCMAKE_CXX_COMPILER=$CCACHE_WRAPPER_PATH/g++") + fi +fi + +# Use special install script with Anaconda +if [ -n "${USE_ANACONDA}" ]; then + export SKIP_CONDA_TESTS=1 + export CONDA_INSTALL_LOCALLY=1 + "${ROOT_DIR}/scripts/build_anaconda.sh" "$@" +else + # Make sure that pyyaml is installed for the codegen of building Aten to work + if [[ -n "$(python -c 'import yaml' 2>&1)" ]]; then + echo "Installing pyyaml with pip at $(which pip)" + pip install --user pyyaml + fi + + # Make sure that typing is installed for the codegen of building Aten to work + if [[ -n "$(python -c 'import typing' 2>&1)" ]]; then + echo "Installing typing with pip at $(which pip)" + pip install --user typing + fi + + # Build protobuf compiler from third_party if configured to do so + if [ -n "${USE_HOST_PROTOC:-}" ]; then + echo "USE_HOST_PROTOC is set; building protoc before building Caffe2..." + "$CAFFE2_ROOT/scripts/build_host_protoc.sh" + CUSTOM_PROTOC_EXECUTABLE="$CAFFE2_ROOT/build_host_protoc/bin/protoc" + echo "Built protoc $("$CUSTOM_PROTOC_EXECUTABLE" --version)" + CMAKE_ARGS+=("-DCAFFE2_CUSTOM_PROTOC_EXECUTABLE=$CUSTOM_PROTOC_EXECUTABLE") + fi + + # We are going to build the target into build. + BUILD_ROOT=${BUILD_ROOT:-"$CAFFE2_ROOT/build"} + mkdir -p "$BUILD_ROOT" + cd "$BUILD_ROOT" + echo "Building Caffe2 in: $BUILD_ROOT" + + cmake "$CAFFE2_ROOT" \ + -DCMAKE_BUILD_TYPE=Release \ + "${CMAKE_ARGS[@]}" \ + "$@" + + # Determine the number of CPUs to build with. + # If the `CAFFE_MAKE_NCPUS` variable is not specified, use them all. + if [ -n "${MAX_JOBS}" ]; then + CAFFE_MAKE_NCPUS="$MAX_JOBS" + elif [ -n "${CAFFE_MAKE_NCPUS}" ]; then + CAFFE_MAKE_NCPUS="$CAFFE_MAKE_NCPUS" + elif [ "$(uname)" == 'Darwin' ]; then + CAFFE_MAKE_NCPUS="$(sysctl -n hw.ncpu)" + else + CAFFE_MAKE_NCPUS="$(nproc)" + fi + + # Now, actually build the target. + cmake --build . -- "-j$CAFFE_MAKE_NCPUS" +fi diff --git a/scripts/build_mobile.sh b/scripts/build_mobile.sh new file mode 100755 index 0000000000000..7b1995a61ebc7 --- /dev/null +++ b/scripts/build_mobile.sh @@ -0,0 +1,107 @@ +#!/bin/bash +############################################################################## +# Example command to build the mobile target. +############################################################################## +# +# This script shows how one can build a libtorch library optimized for mobile +# devices using host toolchain. + +set -e + +export BUILD_PYTORCH_MOBILE_WITH_HOST_TOOLCHAIN=1 +CAFFE2_ROOT="$( cd "$(dirname "$0")"/.. ; pwd -P)" + +echo "Bash: $(/bin/bash --version | head -1)" +echo "Caffe2 path: $CAFFE2_ROOT" + +CMAKE_ARGS=() +CMAKE_ARGS+=("-DCMAKE_PREFIX_PATH=$(python -c 'import sysconfig; print(sysconfig.get_path("purelib"))')") +CMAKE_ARGS+=("-DPython_EXECUTABLE=$(python -c 'import sys; print(sys.executable)')") +CMAKE_ARGS+=("-DBUILD_CUSTOM_PROTOBUF=OFF") +CMAKE_ARGS+=("-DBUILD_SHARED_LIBS=OFF") + +# custom build with selected ops +if [ -n "${SELECTED_OP_LIST}" ]; then + SELECTED_OP_LIST="$(cd $(dirname $SELECTED_OP_LIST); pwd -P)/$(basename $SELECTED_OP_LIST)" + echo "Choose SELECTED_OP_LIST file: $SELECTED_OP_LIST" + if [ ! -r ${SELECTED_OP_LIST} ]; then + echo "Error: SELECTED_OP_LIST file ${SELECTED_OP_LIST} not found." + exit 1 + fi + CMAKE_ARGS+=("-DSELECTED_OP_LIST=${SELECTED_OP_LIST}") +fi + +# If Ninja is installed, prefer it to Make +if [ -x "$(command -v ninja)" ]; then + CMAKE_ARGS+=("-GNinja") +fi + +# Don't build artifacts we don't need +CMAKE_ARGS+=("-DBUILD_TEST=OFF") +CMAKE_ARGS+=("-DBUILD_BINARY=OFF") + +# If there exists env variable and it equals to 1, build lite interpreter. +# Default behavior is to build full jit interpreter. +# cmd: BUILD_LITE_INTERPRETER=1 ./scripts/build_mobile.sh +if [ "x${BUILD_LITE_INTERPRETER}" == "x1" ]; then + CMAKE_ARGS+=("-DBUILD_LITE_INTERPRETER=ON") +else + CMAKE_ARGS+=("-DBUILD_LITE_INTERPRETER=OFF") +fi +if [ "x${TRACING_BASED}" == "x1" ]; then + CMAKE_ARGS+=("-DTRACING_BASED=ON") +else + CMAKE_ARGS+=("-DTRACING_BASED=OFF") +fi + +# Lightweight dispatch bypasses the PyTorch Dispatcher. +if [ "${USE_LIGHTWEIGHT_DISPATCH}" == 1 ]; then + CMAKE_ARGS+=("-DUSE_LIGHTWEIGHT_DISPATCH=ON") + CMAKE_ARGS+=("-DSTATIC_DISPATCH_BACKEND=CPU") +else + CMAKE_ARGS+=("-DUSE_LIGHTWEIGHT_DISPATCH=OFF") +fi + +# Disable unused dependencies +CMAKE_ARGS+=("-DUSE_ROCM=OFF") +CMAKE_ARGS+=("-DUSE_CUDA=OFF") +CMAKE_ARGS+=("-DUSE_ITT=OFF") +CMAKE_ARGS+=("-DUSE_GFLAGS=OFF") +CMAKE_ARGS+=("-DUSE_OPENCV=OFF") +CMAKE_ARGS+=("-DUSE_MPI=OFF") +CMAKE_ARGS+=("-DUSE_OPENMP=OFF") +CMAKE_ARGS+=("-DUSE_MKLDNN=OFF") +CMAKE_ARGS+=("-DUSE_NNPACK=OFF") +CMAKE_ARGS+=("-DUSE_NUMPY=OFF") +CMAKE_ARGS+=("-DUSE_BLAS=OFF") + +# Only toggle if VERBOSE=1 +if [ "${VERBOSE:-}" == '1' ]; then + CMAKE_ARGS+=("-DCMAKE_VERBOSE_MAKEFILE=1") +fi + +# Use-specified CMake arguments go last to allow overriding defaults +CMAKE_ARGS+=("$@") + +# Now, actually build the Android target. +BUILD_ROOT=${BUILD_ROOT:-"$CAFFE2_ROOT/build_mobile"} +INSTALL_PREFIX=${BUILD_ROOT}/install +mkdir -p $BUILD_ROOT +cd $BUILD_ROOT +cmake "$CAFFE2_ROOT" \ + -DCMAKE_INSTALL_PREFIX=$INSTALL_PREFIX \ + -DCMAKE_BUILD_TYPE=Release \ + "${CMAKE_ARGS[@]}" + +# Cross-platform parallel build +if [ -z "$MAX_JOBS" ]; then + if [ "$(uname)" == 'Darwin' ]; then + MAX_JOBS=$(sysctl -n hw.ncpu) + else + MAX_JOBS=$(nproc) + fi +fi + +echo "Will install headers and libs to $INSTALL_PREFIX for further project usage." +cmake --build . --target install -- "-j${MAX_JOBS}" +echo "Installation completed, now you can copy the headers/libs from $INSTALL_PREFIX to your project directory." diff --git a/scripts/build_pytorch_android.sh b/scripts/build_pytorch_android.sh new file mode 100755 index 0000000000000..7b80965e34b5c --- /dev/null +++ b/scripts/build_pytorch_android.sh @@ -0,0 +1,51 @@ +#!/bin/bash +set -eux + +############################################################################## +# Master script to build PyTorch Android library with Java bindings. +############################################################################## +# Example usage: +# - Build default AARs: +# scripts/build_pytorch_android.sh +# +# - Build for specific ABI(s): +# scripts/build_pytorch_android.sh armeabi-v7a +# scripts/build_pytorch_android.sh arm64-v8a,x86,x86_64 +# +# Script's workflow: +# 1. Builds libtorch for android for specified android abisi (by default for all 4). +# Custom list of android abis can be specified as a bash argument as comma separated list. +# For example just for testing on android x86 emulator we need only x86 build. +# ./scripts/build_pytorch_android.sh x86 +# 2. Creates symbolic links to android/pytorch_android/src/main/jniLibs/${abi} for libtorch build output, +# android/pytorch_android/src/main/cpp/libtorch_include/${abi} for headers. +# 3. Runs pyotrch_android gradle build: +# gradle assembleRelease + +PYTORCH_DIR="$(cd $(dirname $0)/..; pwd -P)" +PYTORCH_ANDROID_DIR=$PYTORCH_DIR/android + +echo "PYTORCH_DIR:$PYTORCH_DIR" + +source "$PYTORCH_ANDROID_DIR/common.sh" + +check_android_sdk +check_gradle +parse_abis_list "$@" +build_android + +# To set proxy for gradle add following lines to ./gradle/gradle.properties: +# systemProp.http.proxyHost=... +# systemProp.http.proxyPort=8080 +# systemProp.https.proxyHost=... +# systemProp.https.proxyPort=8080 + +if [ "$CUSTOM_ABIS_LIST" = true ]; then + # Skipping clean task here as android gradle plugin 3.3.2 exteralNativeBuild has problems + # with it when abiFilters are specified. + $GRADLE_PATH -PABI_FILTERS=$ABIS_LIST -p $PYTORCH_ANDROID_DIR assembleRelease +else + $GRADLE_PATH -p $PYTORCH_ANDROID_DIR clean assembleRelease +fi + +find $PYTORCH_ANDROID_DIR -type f -name *aar | xargs ls -lah diff --git a/scripts/build_raspbian.sh b/scripts/build_raspbian.sh new file mode 100755 index 0000000000000..b1fe85926219e --- /dev/null +++ b/scripts/build_raspbian.sh @@ -0,0 +1,44 @@ +#!/bin/bash +############################################################################## +# Example command to build the Raspbian target. +############################################################################## +# +# This script shows how one can build a Caffe2 binary for raspbian. The build +# is essentially much similar to a host build, with one additional change +# which is to specify -mfpu=neon for optimized speed. + +CAFFE2_ROOT="$( cd "$(dirname -- "$0")"/.. ; pwd -P)" +echo "Caffe2 codebase root is: $CAFFE2_ROOT" +BUILD_ROOT=${BUILD_ROOT:-"$CAFFE2_ROOT/build"} +mkdir -p $BUILD_ROOT +echo "Build Caffe2 raspbian into: $BUILD_ROOT" + +# obtain dependencies. +echo "Installing dependencies." +sudo apt-get install \ + cmake \ + libgflags-dev \ + libgoogle-glog-dev \ + libprotobuf-dev \ + libpython-dev \ + python-pip \ + python-numpy \ + protobuf-compiler \ + python-protobuf +# python dependencies +sudo pip install hypothesis + +# Now, actually build the raspbian target. +echo "Building caffe2" +cd $BUILD_ROOT + +# Note: you can add more dependencies above if you need libraries such as +# leveldb, lmdb, etc. +cmake "$CAFFE2_ROOT" \ + -DCMAKE_VERBOSE_MAKEFILE=1 \ + -DCAFFE2_CPU_FLAGS="-mfpu=neon -mfloat-abi=hard" \ + || exit 1 + +# Note: while Raspberry pi has 4 cores, running too many builds in parallel may +# cause out of memory errors so we will simply run -j 2 only. +make -j 2 || exit 1 diff --git a/scripts/build_tegra_x1.sh b/scripts/build_tegra_x1.sh new file mode 100755 index 0000000000000..063e17dfe3514 --- /dev/null +++ b/scripts/build_tegra_x1.sh @@ -0,0 +1,51 @@ +#!/bin/bash +############################################################################## +# Example command to build Caffe2 on Tegra X1. +############################################################################## +# +# This script shows how one can build a Caffe2 binary for NVidia's TX1. +# The build script assumes that you have the most recent libraries installed +# via the JetPack toolkit available at +# https://developer.nvidia.com/embedded/jetpack +# and it assumes that we are starting from a fresh system after the jetpack +# installation. If you have already installed some of the dependencies, you +# may be able to skip quite a few of the apt-get installs. + +CAFFE2_ROOT="$( cd "$(dirname -- "$0")"/.. ; pwd -P)" +echo "Caffe2 codebase root is: $CAFFE2_ROOT" +BUILD_ROOT=${BUILD_ROOT:-"$CAFFE2_ROOT/build"} +mkdir -p $BUILD_ROOT +echo "Build Caffe2 raspbian into: $BUILD_ROOT" + +# obtain necessary dependencies +echo "Installing dependencies." +sudo apt-get install \ + cmake \ + libgflags-dev \ + libgoogle-glog-dev \ + libprotobuf-dev \ + protobuf-compiler + +# obtain optional dependencies that are usually useful to have. +echo "Installing optional dependencies." +sudo apt-get install \ + libpython-dev \ + python-numpy \ + python-pip \ + python-protobuf + +# Obtain python hypothesis, which Caffe2 uses for unit testing. Note that +# the one provided by apt-get is quite old so we install it via pip +sudo pip install hypothesis + +# Now, actually build the android target. +echo "Building caffe2" +cd $BUILD_ROOT + +# CUDA_USE_STATIC_CUDA_RUNTIME needs to be set to off so that opencv can be +# properly used. Otherwise, opencv will complain that opencv_dep_cudart cannot +# be found. +cmake "$CAFFE2_ROOT" -DCUDA_USE_STATIC_CUDA_RUNTIME=OFF \ + || exit 1 + +make -j 4 || exit 1 diff --git a/scripts/build_tizen.sh b/scripts/build_tizen.sh new file mode 100755 index 0000000000000..2262a2503c1d0 --- /dev/null +++ b/scripts/build_tizen.sh @@ -0,0 +1,118 @@ +#!/usr/bin/env bash +############################################################################## +# Example command to build the Tizen target (RPi3). +############################################################################## +# +# This script shows how one can build a Caffe2 binary for a Tizen device (RPi3). +# The build is essentially much similar to a host build, with one additional change +# which is to specify -mfpu=neon for optimized speed. + +setup_environment(){ +# The rootfs image for a Tizen target (RPi3)is located at the below webpage: +# https://cdn.download.tizen.org/archive/releases/milestone/tizen/4.0.m1/tizen-unified_20170529.1/images/ +# If you do not have a Tizen device, Please, run qemu-arm-static and chroot command. +# $ sudo chroot ~/tizen-rootfs qemu-arm-static /usr/bin/bash + +CAFFE2_ROOT="$( cd "$(dirname -- "$0")"/.. ; pwd -P)" +echo "Caffe2 codebase root is: $CAFFE2_ROOT" +BUILD_ROOT=${BUILD_ROOT:-"$CAFFE2_ROOT/build"} +mkdir -p $BUILD_ROOT +echo "Build Caffe2 Tizen into: $BUILD_ROOT" +} + +caffe2_lite_dep_packages(){ +# Obtain necessary dependencies +# You can set-up a rpm repository with zypper, yum, and dnf because Tizen +# software platform officially support rpm format such as Fedora, OpenSUSE. +# The official Tizen repository is as following: +# https://cdn.download.tizen.org/archive/releases/milestone/tizen/4.0.m1/ +echo "Installing dependencies." +sudo zypper install \ + make \ + strace \ + cmake \ + gcc* \ + binutils \ + glibc* \ + cpp \ + protobuf-devel \ + libstdc++* +} + +caffe2_lite_build(){ +# Now, actually build the android target. +echo "Building caffe2" +cd $BUILD_ROOT + +# Note: add more dependencies above if you need libraries such as leveldb, lmdb, etc. +# If you have to disable a specific package due to a package absence +# from https://git.tizen.org/cgit/, append -Dxxx_xxx=OFF option before executing cmake. +cmake .. \ + -DCMAKE_VERBOSE_MAKEFILE=1 \ + -DUSE_GFLAGS=OFF \ + -DUSE_GLOG=OFF -DUSE_NNPACK=OFF \ + -DRUN_HAVE_STD_REGEX=0 \ + -DRUN_HAVE_POSIX_REGEX=0 \ + -DHAVE_GNU_POSIX_REGEX=0 \ + -DUSE_MPI=OFF -DUSE_OPENMP=OFF \ + -DBUILD_PYTHON=OFF \ + -DUSE_GLOO=OFF \ + -DUSE_OPENCV=OFF \ + -DCAFFE2_CPU_FLAGS="-mfpu=neon -mfloat-abi=soft" \ + || exit 1 + +make -j`nproc` || exit 1 +} + +caffe2_full_dep_packages(){ +# Obtain necessary dependencies +# You can set-up a rpm repository with zypper, yum, and dnf because Tizen +# software platform officially support rpm format such as Fedora, OpenSUSE. +# The official Tizen repository is as following: +# https://cdn.download.tizen.org/archive/releases/milestone/tizen/4.0.m1/ +echo "Installing dependencies." +sudo zypper install \ + cmake \ + libgflags-dev \ + libgoogle-glog-dev \ + libprotobuf-dev \ + protobuf-compiler + +# Obtain optional dependencies that are usually useful to have. +echo "Installing optional dependencies." +sudo zypper install \ + libpython-dev \ + python-numpy \ + python-pip \ + python-protobuf + +# Obtain python hypothesis, which Caffe2 uses for unit testing. Note that +# the one provided by zypper is quite old so we install it via pip +sudo pip install hypothesis +} + +caffe2_full_build(){ +# Now, actually build the android target. +echo "Building caffe2" +cd $BUILD_ROOT + +# Note: add more dependencies above if you need libraries such as leveldb, lmdb, etc. +# If you have to disable a specific package due to a package absence +# from https://git.tizen.org/cgit/, append -Dxxx_xxx=OFF option before executing cmake. +cmake "$CAFFE2_ROOT" \ + -DCMAKE_VERBOSE_MAKEFILE=1 \ + -DUSE_CUDA=OFF \ + -DUSE_ITT=OFF \ + -DUSE_OPENCV=OFF \ + -DCAFFE2_CPU_FLAGS="-mfpu=neon -mfloat-abi=soft" \ + || exit 1 + +make -j`nproc` || exit 1 +} + +#### Main +# Setup a build environment to compile Caffe2 deeplearning framework in Tizen platform. +setup_environment +# There are two build options to support 'full' version and 'lite' version (by default). +caffe2_lite_dep_packages +caffe2_lite_build diff --git a/scripts/build_windows.bat b/scripts/build_windows.bat new file mode 100644 index 0000000000000..60bfebad08c01 --- /dev/null +++ b/scripts/build_windows.bat @@ -0,0 +1,80 @@ +:: ############################################################################# +:: Example command to build on Windows. +:: ############################################################################# + +:: This script shows how one can build a Caffe2 binary for windows. + +@echo off +setlocal + +SET ORIGINAL_DIR=%cd% +SET CAFFE2_ROOT=%~dp0%.. + +if NOT DEFINED BUILD_BINARY ( + set BUILD_BINARY=OFF +) + +if NOT DEFINED BUILD_SHARED_LIBS ( + :: On CI, we test with BUILD_SHARED_LIBS=OFF. + :: By default, it will be BUILD_SHARED_LIBS=ON. + if NOT DEFINED BUILD_ENVIRONMENT ( + set BUILD_SHARED_LIBS=OFF + ) +) + +if NOT DEFINED CAFFE2_STATIC_LINK_CUDA ( + set CAFFE2_STATIC_LINK_CUDA=OFF +) + +if NOT DEFINED CMAKE_BUILD_TYPE ( + set CMAKE_BUILD_TYPE=Release +) + +if NOT DEFINED ONNX_NAMESPACE ( + set ONNX_NAMESPACE=onnx_c2 +) + +if NOT DEFINED TORCH_CUDA_ARCH_LIST ( + set TORCH_CUDA_ARCH_LIST=5.0 +) + +if NOT DEFINED USE_CUDA ( + set USE_CUDA=OFF +) + +if NOT DEFINED USE_OBSERVERS ( + set USE_OBSERVERS=OFF +) + +if NOT DEFINED MSVC_Z7_OVERRIDE ( + set MSVC_Z7_OVERRIDE=OFF +) + +if NOT DEFINED CMAKE_GENERATOR ( + set CMAKE_GENERATOR=Ninja +) + +set CMAKE_VERBOSE_MAKEFILE=1 + +:: Install pyyaml for Aten codegen +pip install pyyaml ninja + +echo CAFFE2_ROOT=%CAFFE2_ROOT% +echo CMAKE_GENERATOR=%CMAKE_GENERATOR% +echo CMAKE_BUILD_TYPE=%CMAKE_BUILD_TYPE% + +:: Set up cmake. We will skip building the test files right now. +pushd %CAFFE2_ROOT% +python tools\build_libtorch.py || goto :label_error +popd + +echo "Caffe2 built successfully" +cd %ORIGINAL_DIR% +endlocal +exit /b 0 + +:label_error +echo "Caffe2 building failed" +cd %ORIGINAL_DIR% +endlocal +exit /b 1 diff --git a/scripts/diagnose_protobuf.py b/scripts/diagnose_protobuf.py new file mode 100644 index 0000000000000..65af4618228db --- /dev/null +++ b/scripts/diagnose_protobuf.py @@ -0,0 +1,92 @@ +## @package diagnose_protobuf +# Module scripts.diagnose_protobuf +"""Diagnoses the current protobuf situation. + +Protocol buffer needs to be properly installed for Caffe2 to work, and +sometimes it is rather tricky. Specifically, we will need to have a +consistent version between C++ and python simultaneously. This is a +convenience script for one to quickly check if this is so on one's local +machine. + +Usage: + [set your environmental variables like PATH and PYTHONPATH] + python scripts/diagnose_protobuf.py +""" + +import os +import re +from subprocess import PIPE, Popen + + +# Get python protobuf version. +try: + import google.protobuf + + python_version = google.protobuf.__version__ + python_protobuf_installed = True +except ImportError: + print("DEBUG: cannot find python protobuf install.") + python_protobuf_installed = False + +if os.name == "nt": + protoc_name = "protoc.exe" +else: + protoc_name = "protoc" + +try: + p = Popen([protoc_name, "--version"], stdout=PIPE, stderr=PIPE) + out, err = p.communicate() +except: + print("DEBUG: did not find protoc binary.") + print("DEBUG: out: " + out) + print("DEBUG: err: " + err) + native_protobuf_installed = False +else: + if p.returncode: + print("DEBUG: protoc returned a non-zero return code.") + print("DEBUG: out: " + out) + print("DEBUG: err: " + err) + native_protobuf_installed = False + else: + tmp = re.search(r"\d\.\d\.\d", out) + if tmp: + native_version = tmp.group(0) + native_protobuf_installed = True + else: + print("DEBUG: cannot parse protoc version string.") + print("DEBUG: out: " + out) + native_protobuf_installed = False + +PYTHON_PROTOBUF_NOT_INSTALLED = """ +You have not installed python protobuf. Protobuf is needed to run caffe2. You +can install protobuf via pip or conda (if you are using anaconda python). +""" + +NATIVE_PROTOBUF_NOT_INSTALLED = """ +You have not installed the protoc binary. Protoc is needed to compile Caffe2 +protobuf source files. Depending on the platform you are on, you can install +protobuf via: + (1) Mac: using homebrew and do brew install protobuf. + (2) Linux: use apt and do apt-get install libprotobuf-dev + (3) Windows: install from source, or from the releases here: + https://github.com/google/protobuf/releases/ +""" + +VERSION_MISMATCH = f""" +Your python protobuf is of version {python_version} but your native protoc version is of +version {native_version}. This will cause the installation to produce incompatible +protobuf files. This is bad in general - consider installing the same version. +""" + +# Now, give actual recommendations +if not python_protobuf_installed: + print(PYTHON_PROTOBUF_NOT_INSTALLED) + +if not native_protobuf_installed: + print(NATIVE_PROTOBUF_NOT_INSTALLED) + +if python_protobuf_installed and native_protobuf_installed: + if python_version != native_version: + print(VERSION_MISMATCH) + else: + print("All looks good.") diff --git a/scripts/fbcode-dev-setup/ccache_setup.sh b/scripts/fbcode-dev-setup/ccache_setup.sh new file mode 100755 index 0000000000000..cb461bee2dd27 --- /dev/null +++ b/scripts/fbcode-dev-setup/ccache_setup.sh @@ -0,0 +1,92 @@ +#!/bin/bash + +# This script installs CCache with CUDA support. +# Example usage: +# ./ccache_setup.sh --path /installed/folder + +set -e +shopt -s expand_aliases + +# Setup the proxy +alias with_proxy="HTTPS_PROXY=http://fwdproxy:8080 HTTP_PROXY=http://fwdproxy:8080 FTP_PROXY=http://fwdproxy:8080 https_proxy=http://fwdproxy:8080 http_proxy=http://fwdproxy:8080 ftp_proxy=http://fwdproxy:8080 http_no_proxy='*.facebook.com|*.tfbnw.net|*.fb.com'" + +# Parse options +path="$HOME/ccache" +force=false + +while [[ $# -gt 0 ]]; do + case "$1" in + --path) + shift + path="$1" + path=$(realpath "$path") + ;; + --force) # Force install + force=true + ;; + --help) + echo 'usage: ./ccache_setup.py --path /installed/folder [--force]' + exit 0 + ;; + *) + echo "Invalid option: $1" + exit 1 + ;; + esac + shift +done + +# Check whether you put nvcc in PATH +set +e +nvcc_path=$(which nvcc) +if [[ -z "$nvcc_path" ]]; then + nvcc_path="/usr/local/cuda/bin/nvcc" + export PATH="/usr/local/cuda/bin:$PATH" +fi +set -e +if [ ! -f "$nvcc_path" ] && ! $force; then + # shellcheck disable=SC2016 + echo 'nvcc is not detected in $PATH' + exit 1 +fi +echo "nvcc is detected at $nvcc_path" + +if [ -f "$CUDA_NVCC_EXECUTABLE" ] && [[ "$CUDA_NVCC_EXECUTABLE" == *"ccache"* ]]; then # Heuristic rule + if $CUDA_NVCC_EXECUTABLE --version; then + if ! $force; then + echo "CCache with nvcc support is already installed at $CUDA_NVCC_EXECUTABLE, please add --force" + exit 0 + fi + fi +fi + +# Installing CCache +echo "CCache will be installed at $path" +if [ -e "$path" ]; then + mv --backup=t -T "$path" "${path}.old" +fi + +with_proxy git clone https://github.com/colesbury/ccache.git "$path" -b ccbin +cd "$path" +./autogen.sh +./configure +make install prefix="$path" + +mkdir -p "$path/lib" +mkdir -p "$path/cuda" +ln -sf "$path/bin/ccache" "$path/lib/cc" +ln -sf "$path/bin/ccache" "$path/lib/c++" +ln -sf "$path/bin/ccache" "$path/lib/gcc" +ln -sf "$path/bin/ccache" "$path/lib/g++" +ln -sf "$path/bin/ccache" "$path/cuda/nvcc" +"$path/bin/ccache" -M 25Gi + +# Make sure the nvcc wrapped in CCache is runnable +"$path/cuda/nvcc" --version +echo 'Congrats! The CCache with nvcc support is installed!' +echo -e "Please add the following lines to your bash init script:\\n" +echo "################ Env Var for CCache with CUDA support ################" +# shellcheck disable=SC2016 +echo 'export PATH="'"$path"'/lib:$PATH"' +echo 'export CUDA_NVCC_EXECUTABLE="'"$path"'/cuda/nvcc"' +echo '######################################################################' diff --git a/scripts/get_python_cmake_flags.py b/scripts/get_python_cmake_flags.py new file mode 100644 index 0000000000000..a49debcc884ad --- /dev/null +++ b/scripts/get_python_cmake_flags.py @@ -0,0 +1,24 @@ +## @package get_python_cmake_flags +# Module scripts.get_python_cmake_flags +############################################################################## +# Use this script to find your preferred python installation. +############################################################################## +# +# You can use the following to build with your preferred version of python +# if your installation is not being properly detected by CMake. +# +# mkdir -p build && cd build +# cmake $(python ../scripts/get_python_cmake_flags.py) .. +# make +# + + +import sys +import sysconfig + + +flags = [ + f"-DPython_EXECUTABLE:FILEPATH={sys.executable}", +] + +print(" ".join(flags), end="") diff --git a/scripts/proto.ps1 b/scripts/proto.ps1 new file mode 100644 index 0000000000000..a6bce82ff682d --- /dev/null +++ b/scripts/proto.ps1 @@ -0,0 +1,18 @@ +param( + [string]$protoc, + [string]$srcdir, + [string]$unprocessed, + [string]$processed, + [string]$out +) +$ErrorActionPreference = "Stop" +Get-Content $unprocessed | % {$_ -Replace "caffe2/proto/caffe2.proto", "caffe2.proto"} | Set-Content $processed +Add-Content -Path $processed -Value "option optimize_for = LITE_RUNTIME;`n" -NoNewline +$dir = (Get-Item $processed).DirectoryName + +copy $srcdir/caffe2/proto/caffe2.proto $srcdir/caffe2.proto +Add-Content -Path $srcdir/caffe2.proto -Value "option optimize_for = LITE_RUNTIME;`n" -NoNewline + +$processed = (Get-Item $processed).Name +$cmd = "$protoc -I${dir} --cpp_out=$out $processed" +Invoke-Expression $cmd diff --git a/scripts/remove_apache_header.sh b/scripts/remove_apache_header.sh new file mode 100755 index 0000000000000..97980bfbb0ef6 --- /dev/null +++ b/scripts/remove_apache_header.sh @@ -0,0 +1,13 @@ +if [[ "$1" == *.py ]]; then + apache_header="apache_python.txt" +else + apache_header="apache_header.txt" +fi +apache_lines=$(wc -l < "${apache_header}") +apache_md5=$(cat "${apache_header}" | md5) +header_md5=$(head -n ${apache_lines} $1 | md5) +if [ "${header_md5}" == "${apache_md5}" ]; then + keep_lines=$(($(wc -l < $1) - ${apache_lines})) + tail -n ${keep_lines} $1 > _remove_apache_header.txt + mv _remove_apache_header.txt $1 +fi diff --git a/scripts/temp.sh b/scripts/temp.sh new file mode 100755 index 0000000000000..18eb2b4733816 --- /dev/null +++ b/scripts/temp.sh @@ -0,0 +1,7 @@ +find ../caffe2 -name "*.py" -exec ./remove_apache_header.sh {} \; +find ../caffe2 -name "*.h" -exec ./remove_apache_header.sh {} \; +find ../caffe2 -name "*.cc" -exec ./remove_apache_header.sh {} \; +find ../caffe2 -name "*.cpp" -exec ./remove_apache_header.sh {} \; +find ../caffe2 -name "*.cu" -exec ./remove_apache_header.sh {} \; +find ../caffe2 -name "*.mm" -exec ./remove_apache_header.sh {} \; +find ../caffe2 -name "*.m" -exec ./remove_apache_header.sh {} \; diff --git a/scripts/xcode_build.rb b/scripts/xcode_build.rb new file mode 100644 index 0000000000000..0734167bdda11 --- /dev/null +++ b/scripts/xcode_build.rb @@ -0,0 +1,76 @@ +require 'optparse' +require 'xcodeproj' + +options = {} +option_parser = OptionParser.new do |opts| + opts.banner = 'Tools for building PyTorch iOS framework on MacOS' + opts.on('-i', '--install_path ', 'path to the cmake install folder') { |value| + options[:install] = value + } + opts.on('-x', '--xcodeproj_path ', 'path to the XCode project file') { |value| + options[:xcodeproj] = value + } + opts.on('-p', '--platform ', 'platform for the current build, OS or SIMULATOR') { |value| + options[:platform] = value + } +end.parse! +puts options.inspect + +install_path = File.expand_path(options[:install]) +if not Dir.exist? (install_path) + raise "path don't exist:#{install_path}!" +end +xcodeproj_path = File.expand_path(options[:xcodeproj]) +if not File.exist? (xcodeproj_path) + raise "path don't exist:#{xcodeproj_path}!" +end + +project = Xcodeproj::Project.open(xcodeproj_path) +target = project.targets.first #TestApp +header_search_path = ['$(inherited)', "#{install_path}/include"] +libraries_search_path = ['$(inherited)', "#{install_path}/lib"] +other_linker_flags = ['$(inherited)', "-all_load"] + +target.build_configurations.each do |config| + config.build_settings['HEADER_SEARCH_PATHS'] = header_search_path + config.build_settings['LIBRARY_SEARCH_PATHS'] = libraries_search_path + config.build_settings['OTHER_LDFLAGS'] = other_linker_flags + config.build_settings['ENABLE_BITCODE'] = 'No' +end + +# link static libraries +target.frameworks_build_phases.clear +libs = ['libc10.a', 'libclog.a', 'libpthreadpool.a', 'libXNNPACK.a', 'libmicrokernels-prod.a', 'libeigen_blas.a', 'libcpuinfo.a', 'libpytorch_qnnpack.a', 'libtorch_cpu.a', 'libtorch.a', 'libkineto.a'] +for lib in libs do + path = "#{install_path}/lib/#{lib}" + if File.exist?(path) + libref = project.frameworks_group.new_file(path) + target.frameworks_build_phases.add_file_reference(libref) + end +end +# link system frameworks +frameworks = ['CoreML', 'Metal', 'MetalPerformanceShaders', 'Accelerate', 'UIKit'] +if frameworks + frameworks.each do |framework| + path = "System/Library/Frameworks/#{framework}.framework" + framework_ref = project.frameworks_group.new_reference(path) + framework_ref.name = "#{framework}.framework" + framework_ref.source_tree = 'SDKROOT' + target.frameworks_build_phases.add_file_reference(framework_ref) + end +end +project.save + +sdk = nil +arch = nil +if options[:platform] == 'SIMULATOR' + sdk = 'iphonesimulator' + arch = 'arm64' +elsif options[:platform] == 'OS' + sdk = 'iphoneos' + arch = 'arm64' +else + raise "unsupported platform #{options[:platform]}" +end + +exec "xcodebuild clean build -project #{xcodeproj_path} -alltargets -sdk #{sdk} -configuration Release -arch #{arch}" From a00442421a14448f95fc28790325f941662d97f2 Mon Sep 17 00:00:00 2001 From: Catherine Lee Date: Thu, 17 Jul 2025 21:05:25 +0000 Subject: [PATCH 209/457] [CI][TD] Enable TD on all test configs (#158163) I think the main one that was missing is dynamo_wrapped There's also slow and inductor, but the filter later for workflows stops TD from running on those anyways dynamo_wrapped is the second longest jobs for pull right now image Pull Request resolved: https://github.com/pytorch/pytorch/pull/158163 Approved by: https://github.com/huydhn, https://github.com/ZainRizvi --- test/run_test.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/test/run_test.py b/test/run_test.py index 64d6067edc94a..c63fb64e8f05c 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -36,7 +36,6 @@ TEST_CUDA, TEST_SAVE_XML, TEST_WITH_ASAN, - TEST_WITH_CROSSREF, TEST_WITH_ROCM, TEST_WITH_SLOW_GRADCHECK, ) @@ -1410,11 +1409,6 @@ def parse_args(): action="store_true", help="Enables removing tests based on TD", default=IS_CI - and ( - TEST_WITH_CROSSREF - or TEST_CONFIG == "distributed" - or TEST_CONFIG == "default" - ) and get_pr_number() is not None and not strtobool(os.environ.get("NO_TD", "False")) and not IS_MACOS From af6624023e4a9347d68db8517fad684a68b391a2 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Thu, 17 Jul 2025 09:48:30 -0700 Subject: [PATCH 210/457] [dynamo] Skip training flag check id already guarding on nn modules (#158492) This might help some legacy models that still have inline_inbuilt_nn_modules False for some reason. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158492 Approved by: https://github.com/StrongerXi --- torch/_dynamo/guards.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 33ddb4c303bc5..983aa2133874c 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -700,6 +700,9 @@ def __init__( ] = {} self._cached_duplicate_input_guards: set[tuple[str, str]] = set() self.serialization_mode = serialization_mode + self.guard_nn_modules = config.guard_nn_modules and justknobs_check( + "pytorch/compiler:guard_nn_modules" + ) def guard_on_dict_keys_and_ignore_order(self, example_value, guard): dict_mgr = self.get_guard_manager(guard) @@ -1841,7 +1844,9 @@ def NN_MODULE(self, guard: Guard): val = self.get(guard.name) if hasattr(val, "training"): assert istype(val.training, bool) - self._guard_on_attribute(guard, "training", GuardBuilder.CONSTANT_MATCH) + if not self.guard_nn_modules: + # If guard_nn_modules is true, we will guard on the right set of guards + self._guard_on_attribute(guard, "training", GuardBuilder.CONSTANT_MATCH) else: exc.unimplemented_v2( gb_type="Attempted to guard on uninitialized nn.Module", From 41b2c4d1196311ac619d6a025f0181e3977bbe8c Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Thu, 17 Jul 2025 19:06:56 +0000 Subject: [PATCH 211/457] Reduce random reads for offset metadata when calling torch.load under FakeTensorMode (#157931) We already test the `_get_offset` functionality with that TORCH_SERIALIZATION_DEBUG flag that is set in CI, so I didn't add more testing specifically for FakeTensor Pull Request resolved: https://github.com/pytorch/pytorch/pull/157931 Approved by: https://github.com/albanD --- torch/serialization.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/torch/serialization.py b/torch/serialization.py index 9660b4ec3cbcf..61a4acf684152 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -1988,7 +1988,7 @@ def _get_offset(key, name, numel): # for a given key. offsets[name] = storage_offset - # Increment current_offset of offset where next zipfile header starts + # Increment current_offset to offset where next zipfile header starts current_offset = storage_offset + numel # add size of data descriptor after payload if numel > 0: @@ -2004,7 +2004,10 @@ def load_tensor(dtype, numel, key, location): if torch._guards.detect_fake_mode(None) is not None: nbytes = numel * torch._utils._element_size(dtype) storage = torch.UntypedStorage(nbytes, device="meta") - storage._checkpoint_offset = zip_file.get_record_offset(name) + if can_calculate_storage_offsets: + storage._checkpoint_offset = _get_offset(key, name, numel) + else: + storage._checkpoint_offset = zip_file.get_record_offset(name) elif _serialization_tls.skip_data: nbytes = numel * torch._utils._element_size(dtype) storage = torch.UntypedStorage(nbytes) From 1b91954b9ffc2416532fe4d41ed8a97fd974a253 Mon Sep 17 00:00:00 2001 From: cyy Date: Thu, 17 Jul 2025 22:21:00 +0000 Subject: [PATCH 212/457] Suppress volatile type error (#158435) Fixes ``` /var/lib/jenkins/workspace/torch/csrc/dynamo/guards.cpp:5320:10: error: compound assignment to object of volatile-qualified type 'volatile char' is deprecated [-Werror,-Wdeprecated-volatile] ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/158435 Approved by: https://github.com/janeyx99 --- torch/csrc/dynamo/guards.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp index 2b2d09d8b169b..e8a2ebfce6f77 100644 --- a/torch/csrc/dynamo/guards.cpp +++ b/torch/csrc/dynamo/guards.cpp @@ -5528,6 +5528,7 @@ void install_storage_overlapping_guard( /* overlapping= */ false); } +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wdeprecated-volatile") char flush_cache_by_eviction() { constexpr size_t evict_size = 32 * 1024 * 1024; std::vector buffer(evict_size, 1); @@ -5538,6 +5539,7 @@ char flush_cache_by_eviction() { } return sink; } +C10_DIAGNOSTIC_POP() double profile_guard_manager( RootGuardManager* root, From 74f4cf4bd5aaa0123e7b3d91cc0cbcbd69030015 Mon Sep 17 00:00:00 2001 From: Paul Ganssle Date: Thu, 17 Jul 2025 22:23:01 +0000 Subject: [PATCH 213/457] Add missing in c10/util/WaitCounter.h (#158354) It seems that `#include ` is being pulled in indirectly, but it is being used directly, so it is best to explicitly include it. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158354 Approved by: https://github.com/janeyx99 --- c10/util/WaitCounter.h | 1 + 1 file changed, 1 insertion(+) diff --git a/c10/util/WaitCounter.h b/c10/util/WaitCounter.h index 193740cb10dbf..c87c2e3293e5d 100644 --- a/c10/util/WaitCounter.h +++ b/c10/util/WaitCounter.h @@ -3,6 +3,7 @@ #include #include #include +#include #include #include From 0ecfb93a0bfad553b98047ed79fb2b9a54052bb8 Mon Sep 17 00:00:00 2001 From: Mwiza Kunda Date: Thu, 17 Jul 2025 22:31:56 +0000 Subject: [PATCH 214/457] Avoid globally modifying torch.testing._internal.common_methods_invocations.wrapper_set_seed (#158548) Test modules that depend on the original definition of `wrapper_set_seed` will inadvertently be affected if they import from test_torchinductor_opinfo.py. Additionally, using pytest `test_torchinductor_opinfo.py test_other_module.py` when run in the same process may affect the test behaviour of `test_other_module.py` if the tests depend on `wrapper_set_seed`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158548 Approved by: https://github.com/janeyx99 --- test/inductor/test_torchinductor_opinfo.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index 8abd17aab2f88..2b8ace9db4c6c 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -366,8 +366,9 @@ def wrapper_noop_set_seed(op, *args, **kwargs): return op(*args, **kwargs) -torch.testing._internal.common_methods_invocations.wrapper_set_seed = ( - wrapper_noop_set_seed +wrapper_noop_set_seed_decorator = patch( + "torch.testing._internal.common_methods_invocations.wrapper_set_seed", + wrapper_noop_set_seed, ) # key can be either op_name, or (op_name, dtype) @@ -980,6 +981,7 @@ def inner(self, device, dtype, op): return inner +@wrapper_noop_set_seed_decorator class TestInductorOpInfo(TestCase): def tearDown(self): torch._dynamo.reset() From 2df2e3bb511eb3d72742334b116e97656539570d Mon Sep 17 00:00:00 2001 From: Jeff Daily Date: Thu, 17 Jul 2025 22:52:16 +0000 Subject: [PATCH 215/457] [ROCm][CI] Last known good HIP patch (#158596) Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/158596 Approved by: https://github.com/jeffdaily Co-authored-by: Jeff Daily --- .ci/docker/common/install_rocm.sh | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/.ci/docker/common/install_rocm.sh b/.ci/docker/common/install_rocm.sh index 2b2bb47ea0946..d2d56ecec91df 100644 --- a/.ci/docker/common/install_rocm.sh +++ b/.ci/docker/common/install_rocm.sh @@ -82,30 +82,30 @@ EOF done # ROCm 6.3 had a regression where initializing static code objects had significant overhead + # CI no longer builds for ROCm 6.3, but # ROCm 6.4 did not yet fix the regression, also HIP branch names are different - if [[ $(ver $ROCM_VERSION) -ge $(ver 6.3) ]] && [[ $(ver $ROCM_VERSION) -lt $(ver 7.0) ]]; then + if [[ $(ver $ROCM_VERSION) -ge $(ver 6.4) ]] && [[ $(ver $ROCM_VERSION) -lt $(ver 7.0) ]]; then if [[ $(ver $ROCM_VERSION) -eq $(ver 6.4.1) ]]; then HIP_BRANCH=release/rocm-rel-6.4 - VER_STR=6.4 - VER_PATCH=.1 + CLR_HASH=ca18eb3f77fa09292fcda62bc60c3e565d752ada # branch release/rocm-rel-6.4.1-statco-hotfix elif [[ $(ver $ROCM_VERSION) -eq $(ver 6.4) ]]; then HIP_BRANCH=release/rocm-rel-6.4 - VER_STR=6.4 - elif [[ $(ver $ROCM_VERSION) -eq $(ver 6.3) ]]; then - HIP_BRANCH=rocm-6.3.x - VER_STR=6.3 + CLR_HASH=600f5b0d2baed94d5121e2174a9de0851b040b0c # branch release/rocm-rel-6.4-statco-hotfix fi # clr build needs CppHeaderParser but can only find it using conda's python python -m pip install CppHeaderParser git clone https://github.com/ROCm/HIP -b $HIP_BRANCH HIP_COMMON_DIR=$(readlink -f HIP) - git clone https://github.com/jeffdaily/clr -b release/rocm-rel-${VER_STR}${VER_PATCH}-statco-hotfix + git clone https://github.com/jeffdaily/clr + pushd clr + git checkout $CLR_HASH + popd mkdir -p clr/build pushd clr/build # Need to point CMake to the correct python installation to find CppHeaderParser cmake .. -DPython3_EXECUTABLE=/opt/conda/envs/py_${ANACONDA_PYTHON_VERSION}/bin/python3 -DCLR_BUILD_HIP=ON -DHIP_COMMON_DIR=$HIP_COMMON_DIR make -j - cp hipamd/lib/libamdhip64.so.${VER_STR}.* /opt/rocm/lib/libamdhip64.so.${VER_STR}.* + cp hipamd/lib/libamdhip64.so.6.4.* /opt/rocm/lib/libamdhip64.so.6.4.* popd rm -rf HIP clr fi From f63988ae00f856d8a3a6f748310962ba55361f0b Mon Sep 17 00:00:00 2001 From: Yiming Zhou Date: Thu, 17 Jul 2025 23:24:50 +0000 Subject: [PATCH 216/457] [BE]Clean up old APIs in AOTI c shim (#158400) Summary: The shims for aten ops are now generated by torchgen. But there are some still old APIs in `aoti_torch/c/shim.h` This diff moves the old to-be-deprecated APIs for aten ops to a separate header file `shim_deprecated.h` The to-be-deprecated APIs are determined by comparing APIs in `shim.h` and ops in `fallback_ops.py` Test Plan: CI Rollback Plan: Differential Revision: D78378373 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158400 Approved by: https://github.com/jingsh, https://github.com/desertfire --- torch/csrc/inductor/aoti_torch/c/macros.h | 63 ++++++ torch/csrc/inductor/aoti_torch/c/shim.h | 212 +----------------- .../inductor/aoti_torch/c/shim_deprecated.h | 199 ++++++++++++++++ 3 files changed, 264 insertions(+), 210 deletions(-) create mode 100644 torch/csrc/inductor/aoti_torch/c/macros.h create mode 100644 torch/csrc/inductor/aoti_torch/c/shim_deprecated.h diff --git a/torch/csrc/inductor/aoti_torch/c/macros.h b/torch/csrc/inductor/aoti_torch/c/macros.h new file mode 100644 index 0000000000000..6f1346cdcf86a --- /dev/null +++ b/torch/csrc/inductor/aoti_torch/c/macros.h @@ -0,0 +1,63 @@ +#ifndef AOTI_TORCH_MACRO_H +#define AOTI_TORCH_MACRO_H + +#include +#include +#ifdef __GNUC__ +#define AOTI_TORCH_EXPORT __attribute__((__visibility__("default"))) +#else // !__GNUC__ +#ifdef _WIN32 +// PyTorch2 doesn't currently work on Windows. Exporting these APIs can lead +// to symbol clashes at link time if libtorch is included in a DLL and binary +// that depends on the DLL. As a short term fix, we don't export the symbols. +// In the long term, this will need to be addressed when Windows is supported. +#ifdef OVRSOURCE +// Do not export AOTI on Windows for internal builds +#define AOTI_TORCH_EXPORT +#else /* OVRSOURCE */ +#ifdef EXPORT_AOTI_FUNCTIONS +#define AOTI_TORCH_EXPORT __declspec(dllexport) +#else +#define AOTI_TORCH_EXPORT __declspec(dllimport) +#endif +#endif /* OVRSOURCE */ +#else // !_WIN32 +#define AOTI_TORCH_EXPORT +#endif // _WIN32 +#endif // __GNUC__ + +#ifdef __cplusplus +extern "C" { +#endif +// AtenTensorHandle represents an abstract notion of Tensor that can be passed +// between model.so and libtorch.so. The contents of the structure itself +// are private; model.so is not allowed to access any fields directly, it must +// go through functions defined in this ABI. Under the hood, this is +// represented as at::Tensor*, but we reserve the right to change this (and in +// fact, we probably should change it to at::TensorImpl* at least). +// +// An AtenTensorHandle can be owning (please check the API reference for exact +// ownership/borrow semantics). If you have an owning AtenTensorHandle +// in model.so, you are obligated to aoti_torch_delete_tensor_object when you +// are done. You can use the helper C++ class RAIIAtenTensorHandle +// (see aot_runtime/model.h) to ensure the deallocator is called in RAII style +// (note that RAIIAtenTensorHandle is private to model.so, and never crosses +// the ABI boundary.) +struct AtenTensorOpaque; +using AtenTensorHandle = AtenTensorOpaque*; + +struct AtenGeneratorOpaque; +using AtenGeneratorHandle = AtenGeneratorOpaque*; + +struct AOTIProxyExecutorOpaque; +using AOTIProxyExecutorHandle = AOTIProxyExecutorOpaque*; + +using AOTITorchError = int32_t; +#define AOTI_TORCH_SUCCESS 0 +#define AOTI_TORCH_FAILURE 1 + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // AOTI_TORCH_MACRO_H diff --git a/torch/csrc/inductor/aoti_torch/c/shim.h b/torch/csrc/inductor/aoti_torch/c/shim.h index 6a23c9d465c7f..a155f6bb621f1 100644 --- a/torch/csrc/inductor/aoti_torch/c/shim.h +++ b/torch/csrc/inductor/aoti_torch/c/shim.h @@ -1,8 +1,8 @@ #ifndef AOTI_TORCH_SHIM #define AOTI_TORCH_SHIM -#include -#include +#include +#include // This header defines a stable C API for certain ATen functionality in // libtorch. The AOTInductor compiled model.so will only refer to this header @@ -36,29 +36,6 @@ // maintain the old and new versions of the APIs until all old model.so // go out of use. -#ifdef __GNUC__ -#define AOTI_TORCH_EXPORT __attribute__((__visibility__("default"))) -#else // !__GNUC__ -#ifdef _WIN32 -// PyTorch2 doesn't currently work on Windows. Exporting these APIs can lead -// to symbol clashes at link time if libtorch is included in a DLL and binary -// that depends on the DLL. As a short term fix, we don't export the symbols. -// In the long term, this will need to be addressed when Windows is supported. -#ifdef OVRSOURCE -// Do not export AOTI on Windows for internal builds -#define AOTI_TORCH_EXPORT -#else /* OVRSOURCE */ -#ifdef EXPORT_AOTI_FUNCTIONS -#define AOTI_TORCH_EXPORT __declspec(dllexport) -#else -#define AOTI_TORCH_EXPORT __declspec(dllimport) -#endif -#endif /* OVRSOURCE */ -#else // !_WIN32 -#define AOTI_TORCH_EXPORT -#endif // _WIN32 -#endif // __GNUC__ - // The following files are implemented in a header-only way and are guarded by // test/cpp/aoti_abi_check #include @@ -69,33 +46,6 @@ extern "C" { #endif -// AtenTensorHandle represents an abstract notion of Tensor that can be passed -// between model.so and libtorch.so. The contents of the structure itself -// are private; model.so is not allowed to access any fields directly, it must -// go through functions defined in this ABI. Under the hood, this is -// represented as at::Tensor*, but we reserve the right to change this (and in -// fact, we probably should change it to at::TensorImpl* at least). -// -// An AtenTensorHandle can be owning (please check the API reference for exact -// ownership/borrow semantics). If you have an owning AtenTensorHandle -// in model.so, you are obligated to aoti_torch_delete_tensor_object when you -// are done. You can use the helper C++ class RAIIAtenTensorHandle -// (see aot_runtime/model.h) to ensure the deallocator is called in RAII style -// (note that RAIIAtenTensorHandle is private to model.so, and never crosses -// the ABI boundary.) -struct AtenTensorOpaque; -using AtenTensorHandle = AtenTensorOpaque*; - -struct AtenGeneratorOpaque; -using AtenGeneratorHandle = AtenGeneratorOpaque*; - -struct AOTIProxyExecutorOpaque; -using AOTIProxyExecutorHandle = AOTIProxyExecutorOpaque*; - -using AOTITorchError = int32_t; -#define AOTI_TORCH_SUCCESS 0 -#define AOTI_TORCH_FAILURE 1 - // Getter functions for retrieving various constants from the runtime, that // can subsequently be passed to other aoti_* functions. By hiding these // behind functions, the precise value of device/dtype is NOT part of the @@ -349,127 +299,6 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob_v2( const uint8_t* opaque_metadata, int64_t opaque_metadata_size); -AOTI_TORCH_EXPORT AOTITorchError aoti_torch__embedding_bag( - AtenTensorHandle weight, - AtenTensorHandle indices, - AtenTensorHandle offsets, - int32_t scale_grad_by_freq, - int32_t mode, - int32_t sparse, - AtenTensorHandle per_sample_weights, // optional argument - int32_t include_last_offset, - int32_t padding_idx, - AtenTensorHandle* ret0, // returns new reference - AtenTensorHandle* ret1, // returns new reference - AtenTensorHandle* ret2, // returns new reference - AtenTensorHandle* ret3 // returns new reference -); - -AOTI_TORCH_EXPORT AOTITorchError aoti_torch__fft_c2c( - AtenTensorHandle self, - const int64_t* dim_ptr, - int64_t dim_size, - int64_t normalization, - int32_t forward, - AtenTensorHandle* ret // returns new reference -); - -// This version is deprecated. We will remove it later -AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_dot_product_flash_attention( - AtenTensorHandle query, - AtenTensorHandle key, - AtenTensorHandle value, - double dropout_p, - bool is_causal, - bool return_debug_mask, - double scale, - AtenTensorHandle* ret0, // returns new reference - AtenTensorHandle* ret1, // returns new reference - AtenTensorHandle* ret2, // returns new reference - AtenTensorHandle* ret3, // returns new reference - int64_t* ret4, - int64_t* ret5, - AtenTensorHandle* ret6, // returns new reference - AtenTensorHandle* ret7, // returns new reference - AtenTensorHandle* ret8 // returns new reference -); - -AOTI_TORCH_EXPORT AOTITorchError -aoti_torch__scaled_dot_product_flash_attention_v2( - AtenTensorHandle query, - AtenTensorHandle key, - AtenTensorHandle value, - double dropout_p, - int is_causal, - int return_debug_mask, - double* scale, // optional argument - AtenTensorHandle* ret0, // returns new reference - AtenTensorHandle* ret1, // returns new reference - AtenTensorHandle* ret2, // returns new reference - AtenTensorHandle* ret3, // returns new reference - int64_t* ret4, - int64_t* ret5, - AtenTensorHandle* ret6, // returns new reference - AtenTensorHandle* ret7, // returns new reference - AtenTensorHandle* ret8 // returns new reference -); - -AOTI_TORCH_EXPORT AOTITorchError -aoti_torch__scaled_dot_product_efficient_attention( - AtenTensorHandle query, - AtenTensorHandle key, - AtenTensorHandle value, - AtenTensorHandle attn_bias, // optional argument - int compute_log_sumexp, - double dropout_p, - int is_causal, - double* scale, // optional argument - AtenTensorHandle* ret0, // returns new reference - AtenTensorHandle* ret1, // returns new reference - AtenTensorHandle* ret2, // returns new reference - AtenTensorHandle* ret3 // returns new reference -); - -AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_mm( - AtenTensorHandle self, - AtenTensorHandle mat2, - AtenTensorHandle bias, - int32_t* out_dtype, - AtenTensorHandle scale_a, - AtenTensorHandle scale_b, - AtenTensorHandle scale_result, - int8_t use_fast_accum, - AtenTensorHandle* ret0, - AtenTensorHandle* ret1); - -AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_mm_v2( - AtenTensorHandle self, - AtenTensorHandle mat2, - AtenTensorHandle scale_a, - AtenTensorHandle scale_b, - AtenTensorHandle bias, - AtenTensorHandle scale_result, - int32_t* out_dtype, - int8_t use_fast_accum, - AtenTensorHandle* ret0); - -AOTI_TORCH_EXPORT AOTITorchError aoti_torch_convolution( - AtenTensorHandle input, - AtenTensorHandle weight, - AtenTensorHandle bias, // optional argument - const int64_t* stride_ptr, - int64_t stride_size, - const int64_t* padding_ptr, - int64_t padding_size, - const int64_t* dilation_ptr, - int64_t dilation_size, - int transposed, - const int64_t* output_padding_ptr, - int64_t output_padding_size, - int64_t groups, - AtenTensorHandle* ret // returns new reference -); - // This function will create a new uninitialized tensor object // and its pointer is returned through *ret. AOTI_TORCH_EXPORT AOTITorchError @@ -502,29 +331,11 @@ aoti_torch_clone(AtenTensorHandle self, AtenTensorHandle* ret); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_clone_preserve_strides(AtenTensorHandle self, AtenTensorHandle* ret); -AOTI_TORCH_EXPORT AOTITorchError aoti_torch_addmm_out( - AtenTensorHandle out, - AtenTensorHandle self, - AtenTensorHandle mat1, - AtenTensorHandle mat2, - float beta, - float alpha); - -AOTI_TORCH_EXPORT AOTITorchError aoti_torch_bmm_out( - AtenTensorHandle out, - AtenTensorHandle self, - AtenTensorHandle mat2); - AOTI_TORCH_EXPORT AOTITorchError aoti_torch_copy_( AtenTensorHandle self, AtenTensorHandle src, int32_t non_blocking); -AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mm_out( - AtenTensorHandle out, - AtenTensorHandle self, - AtenTensorHandle mat2); - AOTI_TORCH_EXPORT AOTITorchError aoti_torch__mm_plus_mm_out( AtenTensorHandle out, AtenTensorHandle a, @@ -571,16 +382,8 @@ aoti_torch_cpu__wrapped_quantized_linear_prepacked( int64_t out_channel, AtenTensorHandle* out); -AOTI_TORCH_EXPORT AOTITorchError -aoti_torch_nonzero(AtenTensorHandle self, AtenTensorHandle* out); - AOTI_TORCH_EXPORT AOTITorchError aoti_torch_zero_(AtenTensorHandle self); -AOTI_TORCH_EXPORT AOTITorchError aoti_torch_repeat_interleave_Tensor( - AtenTensorHandle repeats, - int64_t* output_size, - AtenTensorHandle* out); - AOTI_TORCH_EXPORT AOTITorchError aoti_torch_check_inf_and_nan(const char* tensor_name, AtenTensorHandle tensor); @@ -608,17 +411,6 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_index_put_out( const AtenTensorHandle values, bool accumulate); -AOTI_TORCH_EXPORT AOTITorchError aoti_torch_view_as_real( - AtenTensorHandle self, - AtenTensorHandle* ret // returns new reference -); - -AOTI_TORCH_EXPORT AOTITorchError aoti_torch_view_dtype( - AtenTensorHandle self, - int32_t dtype, - AtenTensorHandle* ret // returns new reference -); - AOTI_TORCH_EXPORT void aoti_torch_print_tensor_handle( AtenTensorHandle self, const char* msg); diff --git a/torch/csrc/inductor/aoti_torch/c/shim_deprecated.h b/torch/csrc/inductor/aoti_torch/c/shim_deprecated.h new file mode 100644 index 0000000000000..964db6b0076c9 --- /dev/null +++ b/torch/csrc/inductor/aoti_torch/c/shim_deprecated.h @@ -0,0 +1,199 @@ +#ifndef AOTI_TORCH_SHIM_DEPRECATED +#define AOTI_TORCH_SHIM_DEPRECATED + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +[[deprecated( + "aoti_torch__embedding_bag is deprecated and will be removed in future versions.")]] +AOTI_TORCH_EXPORT AOTITorchError aoti_torch__embedding_bag( + AtenTensorHandle weight, + AtenTensorHandle indices, + AtenTensorHandle offsets, + int32_t scale_grad_by_freq, + int32_t mode, + int32_t sparse, + AtenTensorHandle per_sample_weights, // optional argument + int32_t include_last_offset, + int32_t padding_idx, + AtenTensorHandle* ret0, // returns new reference + AtenTensorHandle* ret1, // returns new reference + AtenTensorHandle* ret2, // returns new reference + AtenTensorHandle* ret3 // returns new reference +); + +[[deprecated( + "aoti_torch__fft_c2c is deprecated and will be removed in future versions.")]] +AOTI_TORCH_EXPORT AOTITorchError aoti_torch__fft_c2c( + AtenTensorHandle self, + const int64_t* dim_ptr, + int64_t dim_size, + int64_t normalization, + int32_t forward, + AtenTensorHandle* ret // returns new reference +); + +[[deprecated( + "aoti_torch__scaled_mm is deprecated and will be removed in future versions.")]] +AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_mm( + AtenTensorHandle self, + AtenTensorHandle mat2, + AtenTensorHandle bias, + int32_t* out_dtype, + AtenTensorHandle scale_a, + AtenTensorHandle scale_b, + AtenTensorHandle scale_result, + int8_t use_fast_accum, + AtenTensorHandle* ret0, + AtenTensorHandle* ret1); + +[[deprecated( + "aoti_torch__scaled_mm_v2 is deprecated and will be removed in future versions.")]] +AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_mm_v2( + AtenTensorHandle self, + AtenTensorHandle mat2, + AtenTensorHandle scale_a, + AtenTensorHandle scale_b, + AtenTensorHandle bias, + AtenTensorHandle scale_result, + int32_t* out_dtype, + int8_t use_fast_accum, + AtenTensorHandle* ret0); + +[[deprecated( + "aoti_torch_addmm_out is deprecated and will be removed in future versions.")]] +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_addmm_out( + AtenTensorHandle out, + AtenTensorHandle self, + AtenTensorHandle mat1, + AtenTensorHandle mat2, + float beta, + float alpha); + +[[deprecated( + "aoti_torch_bmm is deprecated and will be removed in future versions.")]] +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_bmm_out( + AtenTensorHandle out, + AtenTensorHandle self, + AtenTensorHandle mat2); + +[[deprecated( + "aoti_torch_convolution is deprecated and will be removed in future versions.")]] +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_convolution( + AtenTensorHandle input, + AtenTensorHandle weight, + AtenTensorHandle bias, // optional argument + const int64_t* stride_ptr, + int64_t stride_size, + const int64_t* padding_ptr, + int64_t padding_size, + const int64_t* dilation_ptr, + int64_t dilation_size, + int transposed, + const int64_t* output_padding_ptr, + int64_t output_padding_size, + int64_t groups, + AtenTensorHandle* ret // returns new reference +); + +[[deprecated( + "aoti_torch_mm_out is deprecated and will be removed in future versions.")]] +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mm_out( + AtenTensorHandle out, + AtenTensorHandle self, + AtenTensorHandle mat2); + +[[deprecated( + "aoti_torch_nonzero is deprecated and will be removed in future versions.")]] +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_nonzero(AtenTensorHandle self, AtenTensorHandle* out); + +[[deprecated( + "aoti_torch_repeat_interleave_Tensor is deprecated and will be removed in future versions.")]] +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_repeat_interleave_Tensor( + AtenTensorHandle repeats, + int64_t* output_size, + AtenTensorHandle* out); + +[[deprecated( + "aoti_torch_view_as_real is deprecated and will be removed in future versions.")]] +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_view_as_real( + AtenTensorHandle self, + AtenTensorHandle* ret // returns new reference +); + +[[deprecated( + "aoti_torch_view_dtype is deprecated and will be removed in future versions.")]] +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_view_dtype( + AtenTensorHandle self, + int32_t dtype, + AtenTensorHandle* ret // returns new reference +); + +[[deprecated( + "aoti_torch__scaled_dot_product_flash_attention is deprecated and will be removed in future versions.")]] +AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_dot_product_flash_attention( + AtenTensorHandle query, + AtenTensorHandle key, + AtenTensorHandle value, + double dropout_p, + bool is_causal, + bool return_debug_mask, + double scale, + AtenTensorHandle* ret0, // returns new reference + AtenTensorHandle* ret1, // returns new reference + AtenTensorHandle* ret2, // returns new reference + AtenTensorHandle* ret3, // returns new reference + int64_t* ret4, + int64_t* ret5, + AtenTensorHandle* ret6, // returns new reference + AtenTensorHandle* ret7, // returns new reference + AtenTensorHandle* ret8 // returns new reference +); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch__scaled_dot_product_flash_attention_v2( + AtenTensorHandle query, + AtenTensorHandle key, + AtenTensorHandle value, + double dropout_p, + int is_causal, + int return_debug_mask, + double* scale, // optional argument + AtenTensorHandle* ret0, // returns new reference + AtenTensorHandle* ret1, // returns new reference + AtenTensorHandle* ret2, // returns new reference + AtenTensorHandle* ret3, // returns new reference + int64_t* ret4, + int64_t* ret5, + AtenTensorHandle* ret6, // returns new reference + AtenTensorHandle* ret7, // returns new reference + AtenTensorHandle* ret8 // returns new reference +); + +[[deprecated( + "aoti_torch__scaled_dot_product_efficient_attention is deprecated and will be removed in future versions.")]] +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch__scaled_dot_product_efficient_attention( + AtenTensorHandle query, + AtenTensorHandle key, + AtenTensorHandle value, + AtenTensorHandle attn_bias, // optional argument + int compute_log_sumexp, + double dropout_p, + int is_causal, + double* scale, // optional argument + AtenTensorHandle* ret0, // returns new reference + AtenTensorHandle* ret1, // returns new reference + AtenTensorHandle* ret2, // returns new reference + AtenTensorHandle* ret3 // returns new reference +); + +#ifdef __cplusplus +} // extern "C" + +#endif +#endif // AOTI_TORCH_SHIM_DEPRECATED From b0e325c2c85c5d056a394aa9201f246ee25f8d26 Mon Sep 17 00:00:00 2001 From: Lucas Kabela Date: Thu, 17 Jul 2025 23:31:23 +0000 Subject: [PATCH 217/457] [Dynamo][Better Engineering] Add type coverage to decorators (#158509) As part of better engineering week, we would like to improve out type support to improve dev experience in dynamo This PR adds strict typing support to an important file in dynamo, `decorators.py` NOTE: Untyped fns are because there is a conflict with `__init__.py` in compiler so we can't type these at this time Running ``` mypy torch/_dynamo/decorators.py --linecount-report /tmp/coverage_log ``` | -------- | Lines Unannotated | Lines Total | % lines covered | Funcs Unannotated | Funcs Total | % funcs covered | | -------- | ------- | -------- | ------- | ------- | ------- | ------- | | Main | 209 | 908 | 23.02% | 9 | 39 | 23.08% | | This PR | 870 | 943 | 100.00% | 36 | 39 | 100.00% | | Delta | +661 | +35 | +76.98% | +27 | 0 | +76.92% | Pull Request resolved: https://github.com/pytorch/pytorch/pull/158509 Approved by: https://github.com/williamwen42 --- torch/_dynamo/decorators.py | 125 +++++++++++++++++++++++------------- torch/compiler/__init__.py | 9 ++- 2 files changed, 87 insertions(+), 47 deletions(-) diff --git a/torch/_dynamo/decorators.py b/torch/_dynamo/decorators.py index 5e2e2cb4106c3..13b61b7fa3e3d 100644 --- a/torch/_dynamo/decorators.py +++ b/torch/_dynamo/decorators.py @@ -1,6 +1,3 @@ -# mypy: allow-untyped-defs -# ruff: noqa: TCH004 - """ This module provides decorators and utilities for controlling TorchDynamo's behavior during compilation. """ @@ -9,10 +6,12 @@ import inspect import weakref from dataclasses import dataclass +from types import TracebackType from typing import Any, Callable, Optional, overload, TYPE_CHECKING, TypeVar, Union from typing_extensions import ParamSpec import torch +from torch.compiler import is_compiling from torch.utils._contextlib import _DecoratorContextManager from torch.utils._python_dispatch import is_traceable_wrapper_subclass @@ -29,7 +28,6 @@ from .exc import IncorrectUsage from .external_utils import ( get_nonrecursive_disable_wrapper, - is_compiling, wrap_dunder_call_ctx_manager, ) from .utils import _get_error_on_graph_break, _set_error_on_graph_break, is_function @@ -56,9 +54,11 @@ _P = ParamSpec("_P") _R = TypeVar("_R") +FuncType = Callable[..., Any] +F = TypeVar("F", bound=FuncType) -def run(fn=None): +def run(fn: Optional[Callable[_P, _R]] = None) -> Any: """Don't do any dynamic compiles, just use prior optimizations""" if fn is not None: fn = innermost_fn(fn) @@ -67,7 +67,7 @@ def run(fn=None): return RunOnlyContext() -def disable(fn=None, recursive=True, *, reason=None, wrapping=True): +def disable(fn=None, recursive=True, *, reason=None, wrapping=True): # type: ignore[no-untyped-def] """ Decorator to disable TorchDynamo @@ -87,7 +87,7 @@ def disable(fn=None, recursive=True, *, reason=None, wrapping=True): return DisableContext(msg=reason, wrapping=wrapping) else: - def wrap(fn): + def wrap(fn: Callable[_P, _R]) -> Callable[_P, _R]: fn = innermost_fn(fn) assert callable(fn) @@ -106,7 +106,7 @@ def wrap(fn): skip_code(_nonrecursive_disable_wrapper_code) -def skip(fn=None): +def skip(fn: Optional[Callable[_P, _R]] = None) -> Callable[..., Any]: """ Skip frames associated with the function code, but still process recursively invoked frames @@ -134,7 +134,7 @@ def __init__( stance: str = "default", *, skip_guard_eval_unsafe: bool = False, - force_backend=None, + force_backend: Union[str, Callable[..., Any], None] = None, ) -> None: if force_backend is not None and stance != "default": raise RuntimeError("non-default stance cannot have force_backend set") @@ -142,29 +142,34 @@ def __init__( self.stance = DynamoStance(stance, skip_guard_eval_unsafe, force_backend) self.prev = _set_stance(self.stance) - def __call__(self, fn): + def __call__(self, fn: F) -> F: _set_stance(self.prev) wrapper = super().__call__(fn) # forbid wrapper in graph wrapper._dynamo_forbidden = True # type: ignore[attr-defined] return wrapper - def __enter__(self): + def __enter__(self) -> None: _set_stance(self.stance) - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: _set_stance(self.prev) - def clone(self): + def clone(self) -> "set_stance": return self.__class__(self.stance.stance, force_backend=self.stance.backend) -def assume_constant_result(fn): - fn._dynamo_marked_constant = True +def assume_constant_result(fn): # type: ignore[no-untyped-def] + fn._dynamo_marked_constant = True # type: ignore[attr-defined] return fn -def allow_in_graph(fn): +def allow_in_graph(fn): # type: ignore[no-untyped-def] """ Tells the compiler frontend (Dynamo) to skip symbolic introspection of the function and instead directly write it to the graph when encountered. @@ -182,14 +187,14 @@ def allow_in_graph(fn): trace_rules._allowed_callable_ids.add(fn_id) # Avoid id reuse which creates subtle bugs. - def deregister(): + def deregister() -> None: trace_rules._allowed_callable_ids.remove(fn_id) weakref.finalize(fn, deregister) return fn -def nonstrict_trace(traceable_fn): +def nonstrict_trace(traceable_fn: Callable[_P, _R]) -> Callable[_P, _R]: # Like `allow_in_graph`, but with the following enhancements/differences: # # 1. Supports user-defined class as inputs, as long as the class has been @@ -210,7 +215,7 @@ def nonstrict_trace(traceable_fn): assert callable(traceable_fn), "nonstrict_trace expects a callable" @functools.wraps(traceable_fn) - def wrapped(*args, **kwargs): + def wrapped(*args: _P.args, **kwargs: _P.kwargs) -> _R: return traceable_fn(*args, **kwargs) wrapped_id = id(wrapped) @@ -222,7 +227,7 @@ def wrapped(*args, **kwargs): trace_rules._nonstrict_trace_callable_ids.add(wrapped_id) # Avoid id reuse which creates subtle bugs. - def deregister(): + def deregister() -> None: trace_rules._allowed_callable_ids.remove(wrapped_id) trace_rules._nonstrict_trace_callable_ids.remove(wrapped_id) @@ -231,8 +236,8 @@ def deregister(): return wrapped -def _disallow_in_graph_helper(throw_if_not_allowed): - def inner(fn): +def _disallow_in_graph_helper(throw_if_not_allowed: bool) -> Callable[..., Any]: + def inner(fn: Any) -> Any: if isinstance(fn, (list, tuple)): return [disallow_in_graph(x) for x in fn] assert callable(fn), "disallow_in_graph expects a callable" @@ -254,7 +259,7 @@ def inner(fn): return inner -def disallow_in_graph(fn): +def disallow_in_graph(fn: Callable[..., Any]) -> Any: """ Customize which functions TorchDynamo will exclude in the generated graph and force a graph break on. @@ -280,17 +285,17 @@ def fn(a): @_disallow_in_graph_helper(throw_if_not_allowed=False) -def graph_break(msg=""): +def graph_break(msg: str = "") -> None: """Force a graph break""" # NOTE: primarily used for internal debugging purposes! @_disallow_in_graph_helper(throw_if_not_allowed=False) -def skip_frame(msg=""): +def skip_frame(msg: str = "") -> None: """Force a skipped frame""" -def forbid_in_graph(fn): +def forbid_in_graph(fn: Any) -> Any: """ Customize which functions TorchDynamo will assert are not present while tracing. @@ -392,7 +397,9 @@ def wrapper(traceable_fn: Callable[_P, _R]) -> Callable[_P, _R]: else: traceable_sig = inspect.signature(traceable_fn) - def sig_ident(sig): + def sig_ident( + sig: inspect.Signature, + ) -> tuple[tuple[str, ...], set[str], dict[str, Any]]: # Ignore annotations for parameters and return type return ( tuple( @@ -472,7 +479,9 @@ def sig_ident(sig): def wrapped(*args: _P.args, **kwargs: _P.kwargs) -> _R: return original_fn(*args, **kwargs) - def dispatch_fn(self, value: Callable[_P, _R]) -> PolyfilledFunctionVariable: + def dispatch_fn( + self: VariableBuilder, value: Callable[_P, _R] + ) -> PolyfilledFunctionVariable: return PolyfilledFunctionVariable( value, source=self.source, @@ -497,7 +506,9 @@ def dispatch_fn(self, value: Callable[_P, _R]) -> PolyfilledFunctionVariable: # Helper function to flatten a tensor subclass and apply a function to # all inner tensors that match the outer dim. Used to reduce duplication # across the various marking APIs. -def _apply_func_to_inner_tensors_of_same_dim(func, t, *args, **kwargs): +def _apply_func_to_inner_tensors_of_same_dim( + func: Callable[..., Any], t: object, *args: Any, **kwargs: Any +) -> None: assert is_traceable_wrapper_subclass(t) attrs, _ctx = t.__tensor_flatten__() @@ -522,7 +533,12 @@ class directly; instead, use :func:`mark_dynamic`. @forbid_in_graph -def mark_unbacked(t, index, strict=False, specialize_on=None): +def mark_unbacked( + t: Any, + index: Union[int, list[Any], tuple[Any]], + strict: bool = False, + specialize_on: Optional[list[Any]] = None, +) -> None: """ Mark a tensor as having an unbacked dim. This changes the semantics of operations, we will always report the size does not equal zero/one, we will turn asserts @@ -565,7 +581,14 @@ def mark_unbacked(t, index, strict=False, specialize_on=None): @forbid_in_graph -def mark_dynamic(t, index, *, min=None, max=None, specialize_on=None): +def mark_dynamic( + t: Any, + index: Union[int, list[Any], tuple[Any]], + *, + min: Optional[int] = None, + max: Optional[int] = None, + specialize_on: Optional[list[Any]] = None, +) -> None: """ Mark a tensor as having a dynamic dim and set corresponding min and max range for the dim. @@ -620,7 +643,7 @@ def mark_dynamic(t, index, *, min=None, max=None, specialize_on=None): # TODO(voz): Should we bounds check? t._dynamo_dynamic_indices.add(index) - t._dynamo_dynamic_range.add(_DimRange(index, min, max)) + t._dynamo_dynamic_range.add(_DimRange(index, min, max)) # type: ignore[arg-type] # FX tracers don't respect @forbid_in_graph and choke on the following error since it passes in proxies: # TypeError: 'Attribute' object does not support item assignment @@ -636,7 +659,7 @@ def mark_dynamic(t, index, *, min=None, max=None, specialize_on=None): @forbid_in_graph -def maybe_mark_dynamic(t, index): +def maybe_mark_dynamic(t: Any, index: Union[int, list[Any], tuple[Any]]) -> None: """ Mark a tensor as having a dynamic dim, but don't enforce it (i.e., if this dimension ends up getting specialized, don't error). @@ -658,7 +681,9 @@ def maybe_mark_dynamic(t, index): maybe_mark_dynamic(t, i) -def mark_static(t, index=None): +def mark_static( + t: Any, index: Optional[Union[int, list[Any], tuple[Any]]] = None +) -> None: """ Mark a tensor as having a static dim or mark a nn module class as static. @@ -723,7 +748,7 @@ def mark_static(t, index=None): @forbid_in_graph -def mark_static_address(t, guard=True): +def mark_static_address(t: Any, guard: bool = True) -> None: """ Marks an input tensor whose data_ptr will not change across multiple calls to a dynamo-compiled function. This indicates to cudagraphs that an extra allocation @@ -742,7 +767,7 @@ def mark_static_address(t, guard=True): # One day, Dynamo will support tracing into einops directly (no allow_in_graph needed) # Note that PyTorch supports multiple versions of einops, so when that day comes, # we still need to be really careful about version matches. -def _allow_in_graph_einops(): +def _allow_in_graph_einops() -> None: import einops try: @@ -773,21 +798,26 @@ def _allow_in_graph_einops(): # Proxy class for torch._dynamo.config patching - so dynamo can identify context managers/decorators # created by patch_dynamo_config, compared to ones created by a raw torch._dynamo.config.patch. class DynamoConfigPatchProxy: - def __init__(self, config_patch): + def __init__(self, config_patch: Any) -> None: self.config_patch = config_patch @property - def changes(self): + def changes(self) -> dict[str, Any]: return self.config_patch.changes # Decorator implementation that simply sets up `self` as a context manager. # Placed in external_utils so that we can trace through it. __call__ = wrap_dunder_call_ctx_manager - def __enter__(self): + def __enter__(self) -> None: return self.config_patch.__enter__() - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: return self.config_patch.__exit__(exc_type, exc_val, exc_tb) @@ -819,7 +849,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): del config -def _patch_dynamo_config_check(changes: dict[str, Any]): +def _patch_dynamo_config_check(changes: dict[str, Any]) -> None: for k, v in changes.items(): if k not in _allowed_config_patches: raise ValueError( @@ -871,7 +901,7 @@ def dont_skip_tracing(fn: None = None) -> DynamoConfigPatchProxy: ... def dont_skip_tracing(fn: Callable[_P, _R]) -> Callable[_P, _R]: ... -def dont_skip_tracing(fn=None): +def dont_skip_tracing(fn: Optional[Any] = None) -> Any: """ Context manager/decorator to trace into functions intentionally marked by developers to be skipped when tracing. @@ -885,16 +915,21 @@ def dont_skip_tracing(fn=None): class SetFullgraphDecoratorContextManager: - def __init__(self, fullgraph): + def __init__(self, fullgraph: bool) -> None: self.fullgraph = fullgraph __call__ = wrap_dunder_call_ctx_manager - def __enter__(self): + def __enter__(self) -> None: self.prev_fullgraph = _get_error_on_graph_break() _set_error_on_graph_break(self.fullgraph) - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: _set_error_on_graph_break(self.prev_fullgraph) diff --git a/torch/compiler/__init__.py b/torch/compiler/__init__.py index e92100e87f384..163c25f12dbc8 100644 --- a/torch/compiler/__init__.py +++ b/torch/compiler/__init__.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar +from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union from typing_extensions import ParamSpec import torch @@ -39,6 +39,8 @@ _P = ParamSpec("_P") _R = TypeVar("_R") +FuncType = Callable[..., Any] +F = TypeVar("F", bound=FuncType) def compile(*args, **kwargs): @@ -252,7 +254,10 @@ def disable(fn=None, recursive=True, *, reason=None): def set_stance( - stance: str = "default", *, skip_guard_eval_unsafe=False, force_backend=None + stance: str = "default", + *, + skip_guard_eval_unsafe: bool = False, + force_backend: Union[str, Callable[..., Any], None] = None, ): """ Set the current stance of the compiler. From 33c9b414aaa59ab03b7969599afc0de915353519 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Thu, 17 Jul 2025 13:43:10 -0700 Subject: [PATCH 218/457] [CI][MPS] Enable test_indexing on MPS (#158582) - Skip `test_index_put_accumulate_large_tensor_mps` as it crashes with ``` /com.apple.xbs/Sources/MetalPerformanceShaders/MPSCore/Types/MPSNDArray.mm:829: failed assertion `[MPSNDArray initWithDevice:descriptor:isTextureBacked:] Error: NDArray dimension length > INT_MAX' ``` while running `torch.ones([2**31+5], dtype=torch.int8, device='mps')` - Adjust types for `test_index_put_src_datatype` as index_put on MPS is not implemented for complex (yet) - Adjust `test_index` to avoid using DoubleTensors for MPS Pull Request resolved: https://github.com/pytorch/pytorch/pull/158582 Approved by: https://github.com/dcci, https://github.com/Skylion007, https://github.com/manuelcandales --- test/test_indexing.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/test/test_indexing.py b/test/test_indexing.py index 58854a995db6f..37a12f00ab272 100644 --- a/test/test_indexing.py +++ b/test/test_indexing.py @@ -15,6 +15,7 @@ dtypes, dtypesIfCPU, dtypesIfCUDA, + dtypesIfMPS, instantiate_device_type_tests, onlyCUDA, onlyNativeDeviceTypes, @@ -140,7 +141,10 @@ def consec(size, start=1): ) lst = [list(range(i, i + 10)) for i in range(0, 100, 10)] - tensor = torch.DoubleTensor(lst).to(device) + _make_tensor = ( + torch.DoubleTensor if not device.startswith("mps") else torch.FloatTensor + ) + tensor = _make_tensor(lst).to(device) for _i in range(100): idx1_start = random.randrange(10) idx1_end = idx1_start + random.randrange(1, 10 - idx1_start + 1) @@ -156,7 +160,7 @@ def consec(size, start=1): else: lst_indexed = lst[idx1] tensor_indexed = tensor[idx1] - self.assertEqual(torch.DoubleTensor(lst_indexed), tensor_indexed) + self.assertEqual(_make_tensor(lst_indexed), tensor_indexed) self.assertRaises(ValueError, lambda: reference[1:9:0]) self.assertRaises(ValueError, lambda: reference[1:9:-1]) @@ -994,6 +998,8 @@ def test_byte_mask_accumulate(self, device): ) @serialTest(TEST_CUDA) def test_index_put_accumulate_large_tensor(self, device): + if device.startswith("mps"): + raise unittest.SkipTest("Crash with max number of dimentions") # This test is for tensors with number of elements >= INT_MAX (2^31 - 1). N = (1 << 31) + 5 dt = torch.int8 @@ -1303,6 +1309,7 @@ def test_int_indices(self, device): torch.float8_e5m2, torch.float8_e4m3fn, ) + @dtypesIfMPS(torch.float, torch.float16, torch.long, torch.bool) def test_index_put_src_datatype(self, device, dtype): src = torch.ones(3, 2, 4, device=device, dtype=dtype) vals = torch.ones(3, 2, 4, device=device, dtype=dtype) @@ -2049,7 +2056,9 @@ def test_truncate_leading_1s(self, device): self.assertEqual(kernel, kernel2) -instantiate_device_type_tests(TestIndexing, globals(), except_for="meta") +instantiate_device_type_tests( + TestIndexing, globals(), except_for="meta", allow_mps=True +) instantiate_device_type_tests(NumpyTests, globals(), except_for="meta") if __name__ == "__main__": From 7b72e5b3ad989d02802909b64f322b2b7b69913b Mon Sep 17 00:00:00 2001 From: Gabriel Ferns Date: Fri, 18 Jul 2025 00:14:12 +0000 Subject: [PATCH 219/457] Fix Pandas version mismatch upon reinstalling numpy (#158584) If you reinstall numpy after having installed pandas, it will error out sometimes if the versions are different enough (see below snippet). This change forces pandas to be reinstalled when installing numpy. It doesn't work in a separate pip call, because then pip takes the version of numpy requested by pandas as the one to install, undoing the command in the first place. ``` (numpy_pandas) [gabeferns@devvm2497.eag0 ~/pt-envs/at (exclamaforte/just-gemm-model)]$ pip list Package Version ------------------ ----------- attrs 25.3.0 build 1.2.2.post1 certifi 2025.7.14 charset-normalizer 3.4.2 cmake 4.0.3 exceptiongroup 1.3.0 expecttest 0.3.0 filelock 3.18.0 fsspec 2025.5.1 hypothesis 6.135.32 idna 3.10 importlib_metadata 8.7.0 Jinja2 3.1.6 lintrunner 0.12.7 MarkupSafe 2.1.5 mpmath 1.3.0 networkx 3.2.1 ninja [1.11.1.4](https://www.internalfb.com/phabricator/paste/view/1.11.1.4) opt-einsum 3.3.0 optree 0.16.0 packaging 25.0 pip 25.1 psutil 7.0.0 pyproject_hooks 1.2.0 python-dateutil 2.9.0.post0 pytz 2025.2 PyYAML 6.0.2 requests 2.32.4 setuptools 78.1.1 six 1.17.0 sortedcontainers 2.4.0 sympy 1.14.0 tomli 2.2.1 typing_extensions 4.14.0 tzdata 2025.2 urllib3 2.5.0 uv 0.7.21 wheel 0.45.1 zipp 3.23.0 (numpy_pandas) [gabeferns@devvm2497.eag0 ~/pt-envs/at (exclamaforte/just-gemm-model)]$ pip install numpy==1.22.4 Collecting numpy==1.22.4 Using cached numpy-1.22.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.0 kB) Using cached numpy-1.22.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (16.8 MB) Installing collected packages: numpy Successfully installed numpy-1.22.4 (numpy_pandas) [gabeferns@devvm2497.eag0 ~/pt-envs/at (exclamaforte/just-gemm-model)]$ pip install pandas==2.0.3 Collecting pandas==2.0.3 Using cached pandas-2.0.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (18 kB) Requirement already satisfied: python-dateutil>=2.8.2 in /home/gabeferns/.conda/envs/numpy_pandas/lib/python3.9/site-packages (from pandas==2.0.3) (2.9.0.post0) Requirement already satisfied: pytz>=2020.1 in /home/gabeferns/.conda/envs/numpy_pandas/lib/python3.9/site-packages (from pandas==2.0.3) (2025.2) Requirement already satisfied: tzdata>=2022.1 in /home/gabeferns/.conda/envs/numpy_pandas/lib/python3.9/site-packages (from pandas==2.0.3) (2025.2) Requirement already satisfied: numpy>=1.20.3 in /home/gabeferns/.conda/envs/numpy_pandas/lib/python3.9/site-packages (from pandas==2.0.3) (1.22.4) Requirement already satisfied: six>=1.5 in /home/gabeferns/.conda/envs/numpy_pandas/lib/python3.9/site-packages (from python-dateutil>=2.8.2->pandas==2.0.3) (1.17.0) Using cached pandas-2.0.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (12.4 MB) Installing collected packages: pandas Successfully installed pandas-2.0.3 (numpy_pandas) [gabeferns@devvm2497.eag0 ~/pt-envs/at (exclamaforte/just-gemm-model)]$ pip install --pre numpy==2.0.2 Collecting numpy==2.0.2 Using cached numpy-2.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (60 kB) Using cached numpy-2.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (19.5 MB) Installing collected packages: numpy Attempting uninstall: numpy Found existing installation: numpy 1.22.4 Uninstalling numpy-1.22.4: Successfully uninstalled numpy-1.22.4 Successfully installed numpy-2.0.2 (numpy_pandas) [gabeferns@devvm2497.eag0 ~/pt-envs/at (exclamaforte/just-gemm-model)]$ python Python 3.9.23 (main, Jun 5 2025, 13:40:20) [GCC 11.2.0] :: Anaconda, Inc. on linux Type "help", "copyright", "credits" or "license" for more information. >>> import pandas Traceback (most recent call last): File "", line 1, in File "/home/gabeferns/.conda/envs/numpy_pandas/lib/python3.9/site-packages/pandas/__init__.py", line 22, in from pandas.compat import is_numpy_dev as _is_numpy_dev # pyright: ignore # noqa:F401 File "/home/gabeferns/.conda/envs/numpy_pandas/lib/python3.9/site-packages/pandas/compat/__init__.py", line 25, in from pandas.compat.numpy import ( File "/home/gabeferns/.conda/envs/numpy_pandas/lib/python3.9/site-packages/pandas/compat/numpy/__init__.py", line 4, in from pandas.util.version import Version File "/home/gabeferns/.conda/envs/numpy_pandas/lib/python3.9/site-packages/pandas/util/__init__.py", line 2, in from pandas.util._decorators import ( # noqa:F401 File "/home/gabeferns/.conda/envs/numpy_pandas/lib/python3.9/site-packages/pandas/util/_decorators.py", line 14, in from pandas._libs.properties import cache_readonly File "/home/gabeferns/.conda/envs/numpy_pandas/lib/python3.9/site-packages/pandas/_libs/__init__.py", line 13, in from pandas._libs.interval import Interval File "pandas/_libs/interval.pyx", line 1, in init pandas._libs.interval ValueError: numpy.dtype size changed, may indicate binary incompatibility. Expected 96 from C header, got 88 from PyObject ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/158584 Approved by: https://github.com/huydhn --- .ci/pytorch/test.sh | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index a51a7e472c974..d7d5947d2ce2c 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -1606,7 +1606,13 @@ if ! [[ "${BUILD_ENVIRONMENT}" == *libtorch* || "${BUILD_ENVIRONMENT}" == *-baze fi if [[ "${TEST_CONFIG}" == *numpy_2* ]]; then # Install numpy-2.0.2 and compatible scipy & numba versions - python -mpip install --pre numpy==2.0.2 scipy==1.13.1 numba==0.60.0 + # Force re-install of pandas to avoid error where pandas checks numpy version from initial install and fails upon import + TMP_PANDAS_VERSION=$(python -c "import pandas; print(pandas.__version__)" 2>/dev/null) + if [ -n "$TMP_PANDAS_VERSION" ]; then + python -m pip install --pre numpy==2.0.2 scipy==1.13.1 numba==0.60.0 pandas=="$TMP_PANDAS_VERSION" --force-reinstall + else + python -m pip install --pre numpy==2.0.2 scipy==1.13.1 numba==0.60.0 + fi python test/run_test.py --include dynamo/test_functions.py dynamo/test_unspec.py test_binary_ufuncs.py test_fake_tensor.py test_linalg.py test_numpy_interop.py test_tensor_creation_ops.py test_torch.py torch_np/test_basic.py elif [[ "${BUILD_ENVIRONMENT}" == *aarch64* && "${TEST_CONFIG}" != *perf_cpu_aarch64* ]]; then test_linux_aarch64 From 6673ac746c5fade3eaf0aa37547a2e0e76d81860 Mon Sep 17 00:00:00 2001 From: CaoE Date: Fri, 18 Jul 2025 01:08:33 +0000 Subject: [PATCH 220/457] Fix test linalg for MKL upgrading (#158312) Fixes #158054 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158312 Approved by: https://github.com/albanD --- test/torch_np/numpy_tests/linalg/test_linalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/torch_np/numpy_tests/linalg/test_linalg.py b/test/torch_np/numpy_tests/linalg/test_linalg.py index afda92e5b6b9e..869c4af753915 100644 --- a/test/torch_np/numpy_tests/linalg/test_linalg.py +++ b/test/torch_np/numpy_tests/linalg/test_linalg.py @@ -489,7 +489,7 @@ class SolveCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase): # kept apart from TestSolve for use for testing with matrices. def do(self, a, b, tags): x = linalg.solve(a, b) - assert_almost_equal(b, dot_generalized(a, x)) + assert_almost_equal(b, dot_generalized(a, x), single_decimal=5) assert_(consistent_subclass(x, b)) From ef38edb2847c87702db7c3a7c71413eb59f40b2b Mon Sep 17 00:00:00 2001 From: CaoE Date: Fri, 18 Jul 2025 01:10:55 +0000 Subject: [PATCH 221/457] Add stride check for attn_mask on non-cpu device (#158424) Fixes #158374 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158424 Approved by: https://github.com/Valentine233, https://github.com/drisspg, https://github.com/atalman --- .../ATen/native/transformers/sdp_utils_cpp.h | 31 +++++++++++++------ test/inductor/test_fused_attention.py | 12 ++++++- test/test_transformers.py | 28 +++++++++++++++++ 3 files changed, 60 insertions(+), 11 deletions(-) diff --git a/aten/src/ATen/native/transformers/sdp_utils_cpp.h b/aten/src/ATen/native/transformers/sdp_utils_cpp.h index aa5c2b6cdd641..c63ca928613e6 100644 --- a/aten/src/ATen/native/transformers/sdp_utils_cpp.h +++ b/aten/src/ATen/native/transformers/sdp_utils_cpp.h @@ -6,6 +6,7 @@ #include #include #include +#include #include #include @@ -503,17 +504,27 @@ inline bool check_last_dim_stride_equals_1_dense(sdp_params const& params, bool if (ignore_singleton_dim){ qkv_strides_equal_1 = qkv_strides_equal_1 || params.query.sym_size(-1) == 1; } - if (!qkv_strides_equal_1) { + bool is_cpu = params.query.device().type() == c10::DeviceType::CPU; + bool mask_stride_equal_1 = params.attn_mask.has_value() + ? params.attn_mask.value().sym_stride(-1) == 1 + : true; + bool mask_stride_valid = is_cpu ? true : mask_stride_equal_1; + if (!(qkv_strides_equal_1 && mask_stride_valid)) { if (debug) { - TORCH_WARN( - "All fused kernels require the last dimension of the input to have stride 1. ", - "Got Query.stride(-1): ", - params.query.sym_stride(-1), - ", Key.stride(-1): ", - params.key.sym_stride(-1), - ", Value.stride(-1): ", - params.value.sym_stride(-1), - " instead."); + std::ostringstream message; + message + << "All fused kernels require the last dimension of the input to have stride 1. "; + message << "Got Query.stride(-1): " << params.query.sym_stride(-1) + << ", Key.stride(-1): " << params.key.sym_stride(-1) + << ", Value.stride(-1): " << params.value.sym_stride(-1); + + if (params.attn_mask.has_value()) { + message + << ", Attn_mask.stride(-1): " + << params.attn_mask.value().sym_stride(-1) + << " (GPU backends require attn_mask's last dimension to have stride 1 while the CPU does not)."; + } + TORCH_WARN(message.str()); } return false; diff --git a/test/inductor/test_fused_attention.py b/test/inductor/test_fused_attention.py index 9015332f4e15d..a0e1b47032b86 100644 --- a/test/inductor/test_fused_attention.py +++ b/test/inductor/test_fused_attention.py @@ -1023,7 +1023,7 @@ def dot_prod_attention( return attn_weights.matmul(value), key, value tensor_shape = (4, 2, 16, 32) - attn_mask = torch.randn((1, 1, 1, 2), dtype=torch.float, device=self.device) + attn_mask = torch.randn((1, 1, 2, 2), dtype=torch.float, device=self.device) args = [ torch.randn(tensor_shape, device=self.device), torch.randn(tensor_shape, device=self.device), @@ -1036,6 +1036,16 @@ def dot_prod_attention( has_dropout=False, check_train=False, ) + # test attn_mask with stride of last dim != 1 + attn_mask_ = attn_mask.transpose(2, 3) + args[3] = attn_mask_ + self._check_common( + dot_prod_attention, + args1=args, + has_dropout=False, + check_train=False, + contains=self.device == "cpu", + ) def _test_sdpa_rewriter_23(self): def dot_prod_attention( diff --git a/test/test_transformers.py b/test/test_transformers.py index 7c11cb2833d74..89db8d798c266 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -1618,6 +1618,34 @@ def test_invalid_last_dim_stride(self, device, kernel: SDPBackend): self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention( q, k, v, None, 0.0, False)) + @onlyCUDA + @unittest.skipIf( + not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION + or not PLATFORM_SUPPORTS_CUDNN_ATTENTION, + "Efficient or cuDNN Attention was not built for this system", + ) + @parametrize("kernel", [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.CUDNN_ATTENTION]) + def test_mask_invalid_last_dim_stride(self, device, kernel): + with sdpa_kernel(backends=[kernel]): + dtype = torch.float16 + make_tensor = partial(torch.rand, device=device, dtype=dtype) + size = SdpaShape(2, 2, 8, 8) + q, k, v = make_tensor(size), make_tensor(size), make_tensor(size) + attn_mask = make_tensor((2, 2, 8, 8)) + # Passing in a attn_mask with last dim stride not equal to 1 will error + attn_mask.as_strided_(size, [2, 2, 2, 2]) + + with self.assertWarnsRegex( + UserWarning, + "GPU backends require attn_mask's last dimension to have stride 1 while the CPU does not", + ): + self.assertRaises( + RuntimeError, + lambda: torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask, 0.0, False + ), + ) + @onlyCUDA @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support SDPA or pre-SM80 hardware") @parametrize("fused_kernel", [SDPBackend.EFFICIENT_ATTENTION]) From ddbecdfb663172512875db4a873d8a4913a9ac83 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Thu, 17 Jul 2025 13:30:36 -0700 Subject: [PATCH 222/457] [DTensor] Document redistribute_costs (#158495) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158495 Approved by: https://github.com/zpcore, https://github.com/XilunWu --- torch/distributed/tensor/_op_schema.py | 31 +++++++++++++++++++++++--- torch/distributed/tensor/_ops/utils.py | 5 +++++ 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/torch/distributed/tensor/_op_schema.py b/torch/distributed/tensor/_op_schema.py index ccc006e63a83a..0adaa2e4ad082 100644 --- a/torch/distributed/tensor/_op_schema.py +++ b/torch/distributed/tensor/_op_schema.py @@ -73,12 +73,37 @@ class OpSpec: invariant: the DeviceMesh on all DTensorSpec must be the same """ + # output_specs and input_specs are related: for this op, given these input_specs, + # this is the way the output would look output_specs: Union[DTensorSpec, tuple[Optional[DTensorSpec], ...]] input_specs: Optional[Sequence[DTensorSpec]] = None - # redistribute costs to redistribute the operator input shardings to this OpSpec. - # Note that We need a nested list to record the cost for each operand of this - # operator, and for each operand of this operator it might have multiple OpSpecs. + """ + redistribute_cost tells how expensive it is to redistribute a given input into the + placement specified in this OpSpec. + + outer list: one entry (list) per (tensor) input in the op's arg schema + inner list: one entry (cost value) per possible sharding spec for that input + + Example: + ------- + another_op() -> tensor_a # another_op produces the output that becomes our first input + my_op(tensor_a) + + Let's assume this OpSpec's input_specs are [Replicate()], + but another_op() supports 2 strategies (OpSpecs) which produce outputs of + Replicate() + Shard(0) + + In this example, redistribute_costs would look like this + [ + # one row representing "my_op's first input" (tensor_a) + [ + # two entries, one for each strategies supported by another_op + 0.0, # cost of redistributing tensor_a from 'Replicate()' + K, # cost of redistributing tensor_a from 'Shard(0)' + ], + """ redistribute_cost: Optional[list[list[float]]] = None @cached_property diff --git a/torch/distributed/tensor/_ops/utils.py b/torch/distributed/tensor/_ops/utils.py index f120b6c39b022..d1c604d2976dd 100644 --- a/torch/distributed/tensor/_ops/utils.py +++ b/torch/distributed/tensor/_ops/utils.py @@ -226,6 +226,11 @@ def map_placements_after_broadcast( def generate_redistribute_costs( src_strategy: OpStrategy, dst_spec: DTensorSpec ) -> list[float]: + """Generates one row in the 'redistribute_costs' matrix in an OpSpec + The length of the returned list will match the number of strategies in 'src_strategy'. + + Each value in the row is the cost of redistributing from a particular src_strategy to dst_spec. + """ redistribute_costs: list[float] = [ redistribute_cost(strat.output_spec, dst_spec) for strat in src_strategy.strategies From 6fd6fc418d9846e5e7e73513b9bcea7bf7feb4b4 Mon Sep 17 00:00:00 2001 From: eqy Date: Fri, 18 Jul 2025 02:03:44 +0000 Subject: [PATCH 223/457] [B200] Fix flex-attention heuristic for `test_tma_with_customer_kernel_options_cuda` (#158494) Otherwise fails with ``` torch._inductor.exc.InductorError: RuntimeError: No valid triton configs. OutOfMemoryError: out of resource: triton_tem_fused__to_copy_ones_sort_sum_zeros_2 Required: 264224 Hardware limit: 232448 Reducing block sizes or `num_stages` may help. ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/158494 Approved by: https://github.com/drisspg --- torch/_inductor/template_heuristics.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/torch/_inductor/template_heuristics.py b/torch/_inductor/template_heuristics.py index 40a9645186792..65a6851192a0b 100644 --- a/torch/_inductor/template_heuristics.py +++ b/torch/_inductor/template_heuristics.py @@ -707,6 +707,18 @@ class CUDAConfigHeuristic(BaseConfigHeuristic): def __init__(self) -> None: super().__init__() + self.b200_default_flex_config = { + (torch.float32, 64): FlexConfig(128, 32, 3, 4), + (torch.float32, 128): FlexConfig(32, 64, 3, 4), + (torch.float32, 256): FlexConfig(32, 32, 3, 4), + (torch.bfloat16, 64): FlexConfig(128, 128, 3, 4), + (torch.bfloat16, 128): FlexConfig(128, 64, 2, 8), + (torch.bfloat16, 256): FlexConfig(64, 32, 3, 4), + (torch.float16, 64): FlexConfig(128, 128, 3, 4), + (torch.float16, 128): FlexConfig(128, 128, 3, 8), + (torch.float16, 256): FlexConfig(64, 32, 3, 4), + } + self.h100_default_flex_config = { (torch.float32, 64): FlexConfig(128, 32, 3, 4), (torch.float32, 128): FlexConfig(32, 64, 3, 4), @@ -745,7 +757,11 @@ def get_flex_attn_fwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfi default_config = FlexConfig(64, 64, 3, 4) else: default_config = FlexConfig(128, 64, 3, 4) - if capability >= (9, 0): + if capability >= (10, 0): + default_config = self.b200_default_flex_config.get( + (dtype, head_dim), default_config + ) + elif capability >= (9, 0): default_config = self.h100_default_flex_config.get( (dtype, head_dim), default_config ) From 583138d170ff8b60a321211cce8d2e5be6c9ae0b Mon Sep 17 00:00:00 2001 From: Lucas Kabela Date: Fri, 18 Jul 2025 02:11:52 +0000 Subject: [PATCH 224/457] [Dynamo][Better Engineering] Add typing for comptime, cache, and convert_frame (#158379) As part of better engineering week, we would like to improve out type support to improve dev experience in dynamo This PR adds strict typing support to a critical tracing point for dynamo, primarily for`comptime.py` but also `cache_size.py` and `convert_frame.py`. Running ``` mypy torch/_dynamo/comptime.py torch/_dynamo/cache_size.py torch/_dynamo/convert_frame.py --linecount-report /tmp/coverage_log ``` | -------- | Lines Unannotated | Lines Total | % lines covered | Funcs Unannotated | Funcs Total | % funcs covered | | -------- | ------- | -------- | ------- | ------- | ------- | ------- | | Main | 1837 | 2215 | 82.93% | 45 | 82 | 54.88% | | This PR | 2230 | 2230 | 100.00% | 82 | 82 | 100.00% | | Delta | +393 | +15 | +17.07% | +37 | 0 | +45.12% | Pull Request resolved: https://github.com/pytorch/pytorch/pull/158379 Approved by: https://github.com/mlazos --- torch/_dynamo/cache_size.py | 12 ++-- torch/_dynamo/comptime.py | 109 +++++++++++++++++++-------------- torch/_dynamo/convert_frame.py | 2 - torch/_dynamo/eval_frame.py | 2 +- 4 files changed, 70 insertions(+), 55 deletions(-) diff --git a/torch/_dynamo/cache_size.py b/torch/_dynamo/cache_size.py index cff7ea3fef334..d1a46742f37ac 100644 --- a/torch/_dynamo/cache_size.py +++ b/torch/_dynamo/cache_size.py @@ -1,7 +1,7 @@ -# mypy: allow-untyped-defs import logging import weakref from dataclasses import dataclass +from typing import Any, Optional from torch._guards import CompileId @@ -9,7 +9,7 @@ from .types import DynamoFrameType -log = logging.getLogger(__name__) +log: logging.Logger = logging.getLogger(__name__) """ [Note on cache size limit] @@ -99,7 +99,9 @@ def will_compilation_exceed_specific_limit(self, limit: int) -> bool: return self.num_cache_entries_with_same_id_matched_objs >= limit -def _get_weakref_from_f_locals(frame: DynamoFrameType, local_name: str): +def _get_weakref_from_f_locals( + frame: DynamoFrameType, local_name: str +) -> Optional[weakref.ref[Any]]: obj = frame.f_locals.get(local_name, None) weak_id = None try: @@ -109,7 +111,7 @@ def _get_weakref_from_f_locals(frame: DynamoFrameType, local_name: str): return weak_id -def _has_same_id_matched_objs(frame: DynamoFrameType, cache_entry) -> bool: +def _has_same_id_matched_objs(frame: DynamoFrameType, cache_entry: Any) -> bool: """ Checks if the ID_MATCH'd objects saved on cache_entry are same as the ones in frame.f_locals. @@ -131,7 +133,7 @@ def _has_same_id_matched_objs(frame: DynamoFrameType, cache_entry) -> bool: def compute_cache_size( - frame: DynamoFrameType, cache_entry + frame: DynamoFrameType, cache_entry: Any ) -> CacheSizeRelevantForFrame: # Walk the linked list to calculate the cache size num_cache_entries = 0 diff --git a/torch/_dynamo/comptime.py b/torch/_dynamo/comptime.py index e21855563efdb..2864168dfb82b 100644 --- a/torch/_dynamo/comptime.py +++ b/torch/_dynamo/comptime.py @@ -1,5 +1,3 @@ -# mypy: allow-untyped-defs - """ This module provides the public comptime interface to TorchDynamo, enabling users to execute arbitrary Python code during symbolic evaluation of their programs. @@ -40,9 +38,13 @@ def my_model(x): import dis import time import traceback -from typing import Optional, Union +from collections.abc import Sequence +from typing import Any, Callable, Optional, TextIO, Union import torch +from torch._dynamo.symbolic_convert import InstructionTranslatorBase +from torch._dynamo.variables.base import VariableTracker +from torch._subclasses.fake_tensor import FakeTensor from torch.fx.experimental.symbolic_shapes import free_symbols from .exc import unimplemented_v2 @@ -62,10 +64,10 @@ class ComptimeVar: actual data in the Tensor is.) """ - def __init__(self, v) -> None: + def __init__(self, v: VariableTracker) -> None: self.__variable = v - def as_proxy(self): + def as_proxy(self) -> Union[VariableTracker, Sequence[VariableTracker]]: """ Returns an fx.Proxy (or tuple/list of fx.Proxy) representing this variable in the FX graph we are assembling to pass @@ -79,13 +81,13 @@ def as_proxy(self): """ return self.__variable.as_proxy() - def is_proxy(self): + def is_proxy(self) -> bool: """ Returns True if as_proxy() would succeed. """ return self.__variable.is_proxy() - def as_fake(self): + def as_fake(self) -> Union[FakeTensor, torch.SymInt]: """ Returns a "fake" value (either a FakeTensor or a SymInt) representing the variable in question. This only works @@ -102,16 +104,16 @@ def size(self, dim: Optional[int] = None) -> Union[int, torch.SymInt]: Returns the size of the tensor (if dim is None) or the size at the dimension dim. The returned size may be a SymInt. """ - return self.as_fake().size(dim) + return self.as_fake().size(dim) # type: ignore[union-attr, return-value] - def python_type(self): + def python_type(self) -> type: """ Returns what type(v) would have returned for the variable at compile time. """ return self.__variable.python_type() - def as_python_constant(self): + def as_python_constant(self) -> Any: """ Returns the Python value this variable would have, but only if it is completely known at compile-time (e.g., it is constant). @@ -123,19 +125,19 @@ def as_python_constant(self): """ return self.__variable.as_python_constant() - def is_python_constant(self): + def is_python_constant(self) -> bool: """ Returns True if as_python_constant would succeed. """ return self.__variable.is_python_constant() - def is_dynamic(self): + def is_dynamic(self) -> bool: if isinstance(self.__variable, SymNodeVariable): fs = free_symbols(self.__variable.sym_num) return bool(fs) return False - def force_static(self): + def force_static(self) -> None: """ Forces that a value is static, inducing a guard on its specific value """ @@ -149,7 +151,7 @@ def force_static(self): f"cannot force {self.__variable} ({type(self.__variable)}) static" ) - def _i_will_not_complain_if_bc_breaks_VariableTracker(self): + def _i_will_not_complain_if_bc_breaks_VariableTracker(self) -> VariableTracker: """ Returns the internal data structure VariableTracker that Dynamo uses to represent variables at compile time. There are no BC guarantees on @@ -171,10 +173,10 @@ class ComptimeContext: file a feature request at https://github.com/pytorch/pytorch/ """ - def __init__(self, tx) -> None: + def __init__(self, tx: InstructionTranslatorBase) -> None: self.__tx = tx - def get_local(self, name: str, *, stacklevel=0) -> ComptimeVar: + def get_local(self, name: str, *, stacklevel: int = 0) -> ComptimeVar: """ Retrieve the compile-time known information about a local. """ @@ -187,7 +189,7 @@ def get_local(self, name: str, *, stacklevel=0) -> ComptimeVar: return ComptimeVar(var) - def graph_break(self, msg="ComptimeContext.graph_break"): + def graph_break(self, msg: str = "ComptimeContext.graph_break") -> None: """ Manually trigger a graph break """ @@ -198,14 +200,14 @@ def graph_break(self, msg="ComptimeContext.graph_break"): hints=[], ) - def graph(self): + def graph(self) -> torch.fx.Graph: """ Retrieve the partially constructed FX graph that would be passed to the user compiler after compilation. """ return self.__tx.output.graph - def assert_static(self, val): + def assert_static(self, val: ComptimeVar) -> None: """ Asserts that the int is static (and not dynamic, per dynamic shapes) """ @@ -213,7 +215,9 @@ def assert_static(self, val): "expected static but got dynamic (run with TORCH_LOGS=dynamic for more info)" ) - def print_graph(self, *, verbose=True, file=None): + def print_graph( + self, *, verbose: bool = True, file: Optional[TextIO] = None + ) -> None: """ Print the partially constructed FX graph that would be passed to the user compiler after compilation. @@ -222,19 +226,21 @@ def print_graph(self, *, verbose=True, file=None): self.__tx.output.graph.python_code("self", verbose=verbose).src, file=file ) - def parent(self): - return ComptimeContext(self.__tx.parent) + def parent(self) -> "ComptimeContext": + return ComptimeContext(self.__tx.parent) # type: ignore[arg-type] - def __get_tx(self, stacklevel): + def __get_tx(self, stacklevel: int) -> Any: tx = self.__tx for _ in range(stacklevel): - tx = tx.parent + tx = tx.parent # type: ignore[assignment] return tx - def print(self, val, *, file=None): + def print(self, val: Any, *, file: Optional[TextIO] = None) -> None: print(repr(val), file=file) - def print_disas(self, *, file=None, stacklevel=0): + def print_disas( + self, *, file: Optional[TextIO] = None, stacklevel: int = 0 + ) -> None: """ Print the current series of opcodes being executed (not including parent frames), including where you are in the particular opcode @@ -249,7 +255,9 @@ def print_disas(self, *, file=None, stacklevel=0): file=file, ) - def print_value_stack(self, *, file=None, stacklevel=0): + def print_value_stack( + self, *, file: Optional[TextIO] = None, stacklevel: int = 0 + ) -> None: """ Print the current Python value stack. Note that this is NOT the same as the traceback; use print_bt() to print that. Note that at @@ -264,7 +272,9 @@ def print_value_stack(self, *, file=None, stacklevel=0): for s in tx.stack: print(f"- {s.debug_repr()}", file=file) - def print_locals(self, *, file=None, stacklevel=0): + def print_locals( + self, *, file: Optional[TextIO] = None, stacklevel: int = 0 + ) -> None: """ Print all of the locals available in the current context. By default this view is very limited; you can get more information @@ -274,7 +284,7 @@ def print_locals(self, *, file=None, stacklevel=0): for k, v in tx.symbolic_locals.items(): print(f"{k} = {v.debug_repr()}", file=file) - def print_bt(self, *, file=None, stacklevel=0): + def print_bt(self, *, file: Optional[TextIO] = None, stacklevel: int = 0) -> None: """ Print the user code backtrace, starting at the beginning of the frame Dynamo started evaluating. Note that this MAY NOT go all @@ -293,7 +303,7 @@ def print_bt(self, *, file=None, stacklevel=0): file=file, ) - def print_guards(self, *, file=None): + def print_guards(self, *, file: Optional[TextIO] = None) -> None: """ Print the currently installed guards for the Dynamo context. This does NOT include guards associated with variables that @@ -307,7 +317,9 @@ def print_guards(self, *, file=None): file=file, ) - def _i_will_not_complain_if_bc_breaks_InstructionTranslator(self): + def _i_will_not_complain_if_bc_breaks_InstructionTranslator( + self, + ) -> InstructionTranslatorBase: """ Returns the internal data structure InstructionTranslator that Dynamo uses to track state of symbolic evaluation. There are no BC @@ -316,32 +328,35 @@ def _i_will_not_complain_if_bc_breaks_InstructionTranslator(self): """ return self.__tx - def sleep(self, sec): + def sleep(self, sec: Union[int, float]) -> None: time.sleep(sec) class _Comptime: @staticmethod - def __call__(fn, fallback_fn=lambda: None): + def __call__( + fn: Callable[[ComptimeContext], Any], + fallback_fn: Callable[[], Any] = lambda: None, + ) -> Any: """fn gets called at compile time in TorchDynamo, calls fallback_fn otherwise""" fallback_fn() # Convenience wrappers that are more compact to use @staticmethod - def graph_break(): + def graph_break() -> None: comptime(lambda ctx: ctx.graph_break()) @staticmethod - def print(e): + def print(e: Any) -> None: comptime(lambda ctx: ctx.print(ctx.get_local("e")), lambda: print(e)) @staticmethod - def print_graph(): + def print_graph() -> None: comptime(lambda ctx: ctx.print_graph()) @staticmethod - def print_disas(*, stacklevel=0): + def print_disas(*, stacklevel: int = 0) -> None: comptime( lambda ctx: ctx.print_disas( stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1 @@ -349,7 +364,7 @@ def print_disas(*, stacklevel=0): ) @staticmethod - def print_value_stack(*, stacklevel=0): + def print_value_stack(*, stacklevel: int = 0) -> None: comptime( lambda ctx: ctx.print_value_stack( stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1 @@ -360,7 +375,7 @@ def print_value_stack(*, stacklevel=0): # in an expression context; e.g., x + print_value_stack_and_return(y + z), # you will see x on the stack prior to the addition operation @staticmethod - def print_value_stack_and_return(e, *, stacklevel=0): + def print_value_stack_and_return(e: Any, *, stacklevel: int = 0) -> Any: comptime( lambda ctx: ctx.print_value_stack( stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1 @@ -369,7 +384,7 @@ def print_value_stack_and_return(e, *, stacklevel=0): return e @staticmethod - def print_locals(*, stacklevel=0): + def print_locals(*, stacklevel: int = 0) -> None: comptime( lambda ctx: ctx.print_locals( stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1 @@ -377,7 +392,7 @@ def print_locals(*, stacklevel=0): ) @staticmethod - def print_bt(*, stacklevel=0): + def print_bt(*, stacklevel: int = 0) -> None: comptime( lambda ctx: ctx.print_bt( stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1 @@ -385,19 +400,19 @@ def print_bt(*, stacklevel=0): ) @staticmethod - def print_guards(): + def print_guards() -> None: comptime(lambda ctx: ctx.print_guards()) @staticmethod - def assert_static(val): + def assert_static(val: Any) -> None: comptime(lambda ctx: ctx.assert_static(ctx.get_local("val"))) @staticmethod - def force_static(val): + def force_static(val: Any) -> None: comptime(lambda ctx: ctx.get_local("val").force_static()) @staticmethod - def breakpoint(): + def breakpoint() -> None: """ Like pdb breakpoint(), but drop into pdb whenever this line of code is compiled by dynamo. Use it by putting @@ -415,14 +430,14 @@ def breakpoint(): (Pdb) p ctx.get_local("attention").as_fake() """ - def inner(inner_ctx): + def inner(inner_ctx: ComptimeContext) -> None: ctx = inner_ctx.parent() # noqa: F841 builtins.breakpoint() comptime(inner) @staticmethod - def sleep(sec): + def sleep(sec: Union[int, float]) -> None: comptime(lambda ctx: ctx.sleep(ctx.get_local("sec").as_python_constant())) diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 8fe9e3aaf13a6..149a1c400d99a 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -1,5 +1,3 @@ -# mypy: allow-untyped-decorators - """ This module implements TorchDynamo's core frame conversion functionality, transforming Python frames into FX graphs. It handles: diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index e621d7082fe3f..2eaafdc436550 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -1705,7 +1705,7 @@ def export( _log_export_usage: bool = True, constraints: Optional[list[Constraint]] = None, **extra_kwargs: Any, -) -> Callable[[tuple[Any, Any]], ExportResult]: +) -> Callable[..., ExportResult]: """ Export an input function f to a format that can be executed outside of PyTorch using the FX graph. From ce4554352be22c7b5c5544330d903851db3120e1 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Thu, 17 Jul 2025 14:35:19 -0700 Subject: [PATCH 225/457] Shunt fx_interpreter graphmodule print on error into tlparse (#158469) Include both the error stacktrace and the graphmodule in a new structured trace artifact. Log the shortened version to the console, and also log a hint to look at the tlparse for more. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158469 Approved by: https://github.com/ezyang --- torch/fx/interpreter.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/torch/fx/interpreter.py b/torch/fx/interpreter.py index e2d2f9d7466dd..4e1ab646593a2 100644 --- a/torch/fx/interpreter.py +++ b/torch/fx/interpreter.py @@ -5,6 +5,7 @@ import torch import torch.fx.traceback as fx_traceback +from torch._logging import trace_structured from torch.hub import tqdm from . import config @@ -175,13 +176,26 @@ def run( if self.extra_traceback: msg = f"While executing {node.format_node()}" msg = f"{e.args[0]}\n\n{msg}" if e.args else str(msg) + msg += f"\nOriginal traceback:\n{node.stack_trace}" if ( isinstance(self.module, GraphModule) and self.module.graph is not None and isinstance(self.module.graph, torch.fx.Graph) ): - msg += f"\nGraphModule: {self.module.print_readable(print_output=False, include_stride=True)}\n" - msg += f"\nOriginal traceback:\n{node.stack_trace}" + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "fx_interpreter_error", + "encoding": "string", + }, + payload_fn=lambda: ( + f"{msg}\nGraphModule: " + f"{self.module.print_readable(print_output=False, include_stride=True)}" # type: ignore[operator] + ), + ) + + msg += "\nUse tlparse to see full graph. " + msg += "(https://github.com/pytorch/tlparse?tab=readme-ov-file#tlparse-parse-structured-pt2-logs)" e.args = (msg,) + e.args[1:] if isinstance(e, KeyError): raise RuntimeError(*e.args) from e From 89d842fec5229fff0df5342b2db121368d51e717 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Thu, 17 Jul 2025 15:32:22 -0700 Subject: [PATCH 226/457] Make torch.distributed.breakpoint() set a long timeout (#158481) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158481 Approved by: https://github.com/d4l3k ghstack dependencies: #158469 --- torch/distributed/__init__.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/torch/distributed/__init__.py b/torch/distributed/__init__.py index b6ba9919ee840..38e2fdbee803a 100644 --- a/torch/distributed/__init__.py +++ b/torch/distributed/__init__.py @@ -4,6 +4,7 @@ import sys import traceback import typing +from datetime import timedelta import torch @@ -82,7 +83,7 @@ def interaction(self, *args, **kwargs): _breakpoint_cache: dict[int, typing.Any] = {} - def breakpoint(rank: int = 0, skip: int = 0): + def breakpoint(rank: int = 0, skip: int = 0, timeout_s=3600): """ Set a breakpoint, but only on a single rank. All other ranks will wait for you to be done with the breakpoint before continuing. @@ -99,6 +100,13 @@ def breakpoint(rank: int = 0, skip: int = 0): log.warning("Skip the breakpoint, counter=%d", counter) return + # avoid having the default timeout (if short) interrupt your debug session + if timeout_s is not None: + for group in torch.distributed.distributed_c10d._pg_map: + torch.distributed.distributed_c10d._set_pg_timeout( + timedelta(seconds=timeout_s), group + ) + if get_rank() == rank: pdb = _DistributedPdb() pdb.message( From 86dbc0ef677ba9f9a4ca41fe168775e49cb0f1ba Mon Sep 17 00:00:00 2001 From: Sherlock Huang Date: Fri, 18 Jul 2025 03:20:40 +0000 Subject: [PATCH 227/457] [NativeRT] Remove makeProxyExecutor from ModelRunner interface (#158587) Summary: makeProxyExecutor shouldn't be exposed to ModelRunner Interface. Test Plan: CI Rollback Plan: Differential Revision: D78501011 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158587 Approved by: https://github.com/yiming0416, https://github.com/henryoier --- torch/nativert/executor/Executor.cpp | 11 ++--------- torch/nativert/executor/Executor.h | 5 +---- torch/nativert/kernels/KernelFactory.cpp | 3 +-- torch/nativert/kernels/KernelFactory.h | 5 ++--- 4 files changed, 6 insertions(+), 18 deletions(-) diff --git a/torch/nativert/executor/Executor.cpp b/torch/nativert/executor/Executor.cpp index 01ce24636fb77..3a3f3d335137d 100644 --- a/torch/nativert/executor/Executor.cpp +++ b/torch/nativert/executor/Executor.cpp @@ -22,8 +22,7 @@ Executor::Executor( const std::shared_ptr& weights, Placement placement, const std::shared_ptr& - pytorchStreamReader, - MakeProxyExecutorFn makeProxyExecutorFunc) + pytorchStreamReader) : executorConfig_(std::move(executorConfig)), graph_(std::move(graph)), placement_(std::move(placement)), @@ -31,7 +30,6 @@ Executor::Executor( executorConfig_.runConstFolding ? std::optional(*graph_) : std::nullopt), - makeProxyExecutorFunc_(std::move(makeProxyExecutorFunc)), executionFrames_(executorConfig_.maxNumConcurrentThreads), clearedExecutionFrames_(executorConfig_.maxNumConcurrentThreads), numExecutionFrames_(0), @@ -48,12 +46,7 @@ void Executor::initialize( auto start = std::chrono::high_resolution_clock::now(); auto executionKernels = KernelFactory().initializeNodeKernels( - *graph_, - weights, - executorConfig_, - placement_, - pytorchStreamReader, - makeProxyExecutorFunc_); + *graph_, weights, executorConfig_, placement_, pytorchStreamReader); if (constantFolder_.has_value()) { constantFolder_->unlinkConstants(executionKernels.nodeKernels); diff --git a/torch/nativert/executor/Executor.h b/torch/nativert/executor/Executor.h index cd15e846a3c96..57356c36d6c5d 100644 --- a/torch/nativert/executor/Executor.h +++ b/torch/nativert/executor/Executor.h @@ -82,8 +82,7 @@ class Executor { const std::shared_ptr& weights, Placement placement = Placement(), const std::shared_ptr& - pytorchStreamReader = nullptr, - MakeProxyExecutorFn makeProxyExecutorFunc = nullptr); + pytorchStreamReader = nullptr); std::shared_ptr getWeights() { std::shared_ptr ret; @@ -190,8 +189,6 @@ class Executor { std::optional constantFolder_; - MakeProxyExecutorFn makeProxyExecutorFunc_; - c10::Semaphore sem_; torch::nativert::detail::MPMCQueue> executionFrames_; diff --git a/torch/nativert/kernels/KernelFactory.cpp b/torch/nativert/kernels/KernelFactory.cpp index db055a6cf220c..da524c8e46b93 100644 --- a/torch/nativert/kernels/KernelFactory.cpp +++ b/torch/nativert/kernels/KernelFactory.cpp @@ -128,8 +128,7 @@ ExecutionKernels KernelFactory::initializeNodeKernels( const torch::nativert::ExecutorConfig& executorConfig, const Placement& placement, const std::shared_ptr& - pytorchStreamReader, - const MakeProxyExecutorFn& makeProxyExecutorFunc) { + pytorchStreamReader) { std::vector> nodeKernels; std::vector> delegateExecutors; std::vector constFoldingExecutions; diff --git a/torch/nativert/kernels/KernelFactory.h b/torch/nativert/kernels/KernelFactory.h index 8c5d5fc661d1d..3f341f1115d37 100644 --- a/torch/nativert/kernels/KernelFactory.h +++ b/torch/nativert/kernels/KernelFactory.h @@ -70,7 +70,7 @@ class KernelFactoryHandler { class KernelFactory { public: - explicit KernelFactory() {} + KernelFactory() = default; ExecutionKernels initializeNodeKernels( const Graph& graph, @@ -78,8 +78,7 @@ class KernelFactory { const torch::nativert::ExecutorConfig& executorConfig, const Placement& placement, const std::shared_ptr& - pytorchStreamReader = nullptr, - const MakeProxyExecutorFn& makeProxyExecutorFunc = nullptr); + pytorchStreamReader = nullptr); static void registerHandler( const std::string& name, From 1e86fa2e5bed964fcbc1d9d7c43279ce29eb4def Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Fri, 18 Jul 2025 04:05:17 +0000 Subject: [PATCH 228/457] Add stack trace to Inductor IR nodes if `inductor.config.trace.provenance_tracing=True` (#158576) Summary: - Split `create_mapping` to `create_mapping_pre_post_grad_nodes` and ` create_node_mapping_kernel_to_post_grad` - Store a mapping from pre_grad graph node names to stack traces in `_inductor_pre_grad_node_stack_trace` - Add `stack_traces` member to ir.Node and add it to the string representation of ir.Node - When we create an IR node, if `inductor.config.trace.provenance_tracing=True`, we populate `stack_traces` from `origins`. The nodes in `origins` are post_grad graph nodes. If a node has `node.stack_trace`, we store the stack_trace directly. This is particularly important for backward graph nodes because they don't have a mapping to pre-grad graph nodes. If a node doesn't have `.stack_trace ` (such as `linear`-> `addmm` nodes), we use the stack trace of the pre_grad graph nodes that it maps to. - A post grad graph node might not have stack trace if it correspond to multiple pre grad graph nodes, e.g. [GroupLinearFusion](https://github.com/pytorch/pytorch/blob/a00442421a14448f95fc28790325f941662d97f2/torch/_inductor/fx_passes/group_batch_fusion.py#L299) Example: ``` scheduling ExternKernelOut( python_kernel_name='extern_kernels.mm', name=buf0, layout=FixedLayout('cuda:0', torch.float32, size=[8, 16], stride=[16, 1]), inputs=[InputBuffer(name='arg2_1', layout=FixedLayout('cuda:0', torch.float32, size=[8, 10], stride=[10, 1])), ReinterpretView( StorageBox( ConstantBuffer(name='fc1_weight', layout=FixedLayout('cuda:0', torch.float32, size=[16, 10], stride=[10, 1])) ), FixedLayout('cuda:0', torch.float32, size=[10, 16], stride=[1, 10]), origins=OrderedSet([mm_default_1]), stack_traces = {, File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/7b4b7a52e15abb17/scripts/shangdiy/__aot__/aot#link-tree/scripts/shangdiy/aot.py", line 29, in forward, x = self.fc1(x), File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/7b4b7a52e15abb17/scripts/shangdiy/__aot__/aot#link-tree/torch/nn/modules/linear.py", line 125, in forward, return F.linear(input, self.weight, self.bias), } )], constant_args=(), kwargs={}, output_view=None, python_kernel_name=extern_kernels.mm, cpp_kernel_name=at::mm_out, ordered_kwargs_for_cpp_kernel=(), op_overload=None, arg_properties=[{}, {}], allarg_properties={}, kwarg_properties=None, unbacked_bindings={}, mutation_outputs=[], origin_node=mm_default_1, origins=OrderedSet([mm_default_1]), stack_traces = {, File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/7b4b7a52e15abb17/scripts/shangdiy/__aot__/aot#link-tree/scripts/shangdiy/aot.py", line 29, in forward, x = self.fc1(x), File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/7b4b7a52e15abb17/scripts/shangdiy/__aot__/aot#link-tree/torch/nn/modules/linear.py", line 125, in forward, return F.linear(input, self.weight, self.bias), } ) ``` Test Plan: ``` buck2 run mode/dev-nosan fbcode//caffe2/test/inductor:provenance_tracing ``` Rollback Plan: Differential Revision: D78365534 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158576 Approved by: https://github.com/angelayi --- test/inductor/test_provenance_tracing.py | 15 +++- torch/_inductor/compile_fx.py | 20 +++-- torch/_inductor/debug.py | 102 +++++++++++++++-------- torch/_inductor/ir.py | 52 +++++++++++- 4 files changed, 143 insertions(+), 46 deletions(-) diff --git a/test/inductor/test_provenance_tracing.py b/test/inductor/test_provenance_tracing.py index 1f7cd7a9f2c00..5dee7a4114049 100644 --- a/test/inductor/test_provenance_tracing.py +++ b/test/inductor/test_provenance_tracing.py @@ -10,7 +10,10 @@ import torch from torch._inductor import config -from torch._inductor.debug import create_node_mapping +from torch._inductor.debug import ( + create_mapping_pre_post_grad_nodes, + create_node_mapping_kernel_to_post_grad, +) from torch._inductor.test_case import run_tests, TestCase from torch.testing._internal.inductor_utils import HAS_GPU from torch.testing._internal.triton_utils import requires_cuda @@ -386,11 +389,17 @@ def test_create_node_mapping(self): "triton_poi_fused_addmm_relu_sigmoid_0": ["relu", "add_tensor"] } - result = create_node_mapping( + result = create_mapping_pre_post_grad_nodes( pre_grad_graph_id, post_to_pre_grad_nodes_json, - triton_kernel_to_post_grad_json, ) + result = { + **result, + **create_node_mapping_kernel_to_post_grad( + triton_kernel_to_post_grad_json, + ), + } + self.assertEqual( result, { diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index e20ae1d85ae3b..c14f3fd7d534f 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -1033,17 +1033,13 @@ def _compile_fx_inner( provenance_info = torch._inductor.debug.dump_inductor_provenance_info() # provenance_info might be None if trace.provenance_tracking is not set if provenance_info: - ( - _, - node_mappings, - ) = provenance_info trace_structured( "artifact", metadata_fn=lambda: { "name": "inductor_provenance_tracking_node_mappings", "encoding": "json", }, - payload_fn=lambda: json.dumps(node_mappings), + payload_fn=lambda: json.dumps(provenance_info), ) # This message is for printing overview information of inductor mm counts, shapes,etc after lowering @@ -1299,8 +1295,13 @@ def codegen_and_compile( }, payload_fn=lambda: json.dumps(provenance_tracking_json), ) + from torch._inductor.debug import create_mapping_pre_post_grad_nodes + torch._inductor.debug._inductor_post_to_pre_grad_nodes = ( - provenance_tracking_json + create_mapping_pre_post_grad_nodes( + torch._inductor.debug._pre_grad_graph_id, + provenance_tracking_json, + ) ) metrics_context = get_metrics_context() @@ -2174,6 +2175,13 @@ def compile_fx( ) torch._inductor.debug._pre_grad_graph_id = id(model_.graph) + if config.trace.provenance_tracking: + for node in model_.graph.nodes: + if node.stack_trace: + torch._inductor.debug._inductor_pre_grad_node_stack_trace[ + node.name + ] = node.stack_trace + model_ = _recursive_pre_grad_passes(model_, example_inputs_) trace_structured( "artifact", diff --git a/torch/_inductor/debug.py b/torch/_inductor/debug.py index f21e0be24d54d..23b26765df2b5 100644 --- a/torch/_inductor/debug.py +++ b/torch/_inductor/debug.py @@ -316,6 +316,7 @@ def enable_aot_logging() -> Iterator[None]: _inductor_post_to_pre_grad_nodes: dict[str, Any] = {} _inductor_triton_kernel_to_post_grad_node_info: dict[str, Any] = {} _pre_grad_graph_id: Optional[int] = None +_inductor_pre_grad_node_stack_trace: dict[str, str] = {} @contextlib.contextmanager @@ -701,23 +702,18 @@ class TensorMetadataHolder: save_args_cnt = itertools.count() -def create_node_mapping( - pre_grad_graph_id: int, +def create_mapping_pre_post_grad_nodes( + pre_grad_graph_id: Optional[int], post_to_pre_grad_nodes_json: dict[str, Any], - triton_kernel_to_post_grad_json: dict[str, Any], ) -> dict[str, dict[str, Any]]: - """Create bidirectional mappings between: - - - pre_grad graph nodes and post_grad graph code nodes, and vice versa - - triton kernel name and post_grad graph code nodes, and vice versa """ - + Create bidirectional mappings between pre_grad graph nodes + and post_grad graph code nodes, and vice versa. + """ # return a dummy dict if there's any error empty_return: dict[str, dict[str, Any]] = { "preToPost": {}, "postToPre": {}, - "cppCodeToPost": {}, - "postToCppCode": {}, } log.info("Creating node mappings for provenance tracking") @@ -726,12 +722,6 @@ def create_node_mapping( log.error("Provenance tacking error: post_to_pre_grad_nodes_json is not a dict") return empty_return - if not isinstance(triton_kernel_to_post_grad_json, dict): - log.error( - "Provenance tacking error: triton_kernel_to_post_grad_json is not a dict" - ) - return empty_return - if not isinstance(pre_grad_graph_id, int): log.error("Provenance tacking error: pre_grad_graph_id is not an int") return empty_return @@ -739,17 +729,7 @@ def create_node_mapping( pre_to_post: dict[str, Any] = collections.defaultdict(OrderedSet) post_to_pre: dict[str, Any] = collections.defaultdict(OrderedSet) - post_to_cpp_code: dict[str, Any] = collections.defaultdict(OrderedSet) - try: - for outer_key, node_array in triton_kernel_to_post_grad_json.items(): - if not isinstance(node_array, list): - log.error( - "Provenance tacking error: triton_kernel_to_post_grad_json value is not a list" - ) - return empty_return - for curr_node in node_array: - post_to_cpp_code[curr_node].add(outer_key) def check_format(node: dict[str, Any]) -> bool: if not isinstance(node, dict): @@ -799,10 +779,61 @@ def convert_sets_to_lists(d: dict[str, Any]) -> None: # convert to list because set is not JSON serializable convert_sets_to_lists(pre_to_post) convert_sets_to_lists(post_to_pre) - convert_sets_to_lists(post_to_cpp_code) return { "preToPost": pre_to_post, "postToPre": post_to_pre, + } + except Exception as e: + # Since this is just logging code, it should never interfere with regular + # program execution, so we use this try-except to guard against any error + log.error("Unexpected error in create_node_mapping: %s", e) + log.error("post_to_pre_grad_nodes_json: %s", post_to_pre_grad_nodes_json) + log.error("pre_grad_graph_id: %s", pre_grad_graph_id) + log.error(traceback.format_exc()) + return empty_return + + +def create_node_mapping_kernel_to_post_grad( + triton_kernel_to_post_grad_json: dict[str, Any], +) -> dict[str, dict[str, Any]]: + """Create bidirectional mappings between triton kernel name and post_grad + graph code nodes, and vice versa. + """ + + # return a dummy dict if there's any error + empty_return: dict[str, dict[str, Any]] = { + "cppCodeToPost": {}, + "postToCppCode": {}, + } + + log.info("Creating node mappings for provenance tracking") + + if not isinstance(triton_kernel_to_post_grad_json, dict): + log.error( + "Provenance tacking error: triton_kernel_to_post_grad_json is not a dict" + ) + return empty_return + + post_to_cpp_code: dict[str, Any] = collections.defaultdict(OrderedSet) + + try: + for outer_key, node_array in triton_kernel_to_post_grad_json.items(): + if not isinstance(node_array, list): + log.error( + "Provenance tacking error: triton_kernel_to_post_grad_json value is not a list" + ) + return empty_return + for curr_node in node_array: + post_to_cpp_code[curr_node].add(outer_key) + + def convert_sets_to_lists(d: dict[str, Any]) -> None: + for key in d: + d[key] = list(d[key]) + d = dict(d) + + # convert to list because set is not JSON serializable + convert_sets_to_lists(post_to_cpp_code) + return { "cppCodeToPost": triton_kernel_to_post_grad_json, "postToCppCode": post_to_cpp_code, } @@ -810,37 +841,38 @@ def convert_sets_to_lists(d: dict[str, Any]) -> None: # Since this is just logging code, it should never interfere with regular # program execution, so we use this try-except to guard against any error log.error("Unexpected error in create_node_mapping: %s", e) - log.error("post_to_pre_grad_nodes_json: %s", post_to_pre_grad_nodes_json) log.error( "triton_kernel_to_post_grad_json: %s", triton_kernel_to_post_grad_json ) - log.error("pre_grad_graph_id: %s", pre_grad_graph_id) log.error(traceback.format_exc()) return empty_return def dump_inductor_provenance_info( filename: str = "inductor_generated_kernel_to_post_grad_nodes.json", -) -> tuple[dict[str, list[str]], dict[str, Any]]: +) -> dict[str, Any]: global _pre_grad_graph_id global _inductor_post_to_pre_grad_nodes global _inductor_triton_kernel_to_post_grad_node_info - debug_info = _inductor_triton_kernel_to_post_grad_node_info.copy() if config.trace.enabled: with V.debug.fopen(filename, "w") as fd: log.info("Writing provenance tracing debugging info to %s", fd.name) - json.dump(debug_info, fd) + json.dump(_inductor_triton_kernel_to_post_grad_node_info, fd) node_mapping = {} if _pre_grad_graph_id: - node_mapping = create_node_mapping( - _pre_grad_graph_id, _inductor_post_to_pre_grad_nodes, debug_info + node_mapping_kernel = create_node_mapping_kernel_to_post_grad( + _inductor_triton_kernel_to_post_grad_node_info ) + node_mapping = { + **_inductor_post_to_pre_grad_nodes, + **node_mapping_kernel, + } if config.trace.enabled: with V.debug.fopen( "inductor_provenance_tracking_node_mappings.json", "w" ) as fd: json.dump(node_mapping, fd) - return debug_info, node_mapping + return node_mapping def set_kernel_post_grad_provenance_tracing( diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 1edbb214ae2ad..e0b3481473323 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -541,12 +541,23 @@ def get_symbolic_inputs(inputs: Sequence[IRNode]) -> list[Expr]: class IRNode: + """Base class for all intermediate representation (IR) nodes in TorchInductor. + + Note: + This is an abstract base class. Most methods raise NotImplementedError + and must be overridden by concrete subclasses. + """ + _current_origins: ClassVar[OrderedSet[Any]] = OrderedSet() # NB: These are kinda weird, origins: OrderedSet[Any] = dataclasses.field(init=False) + # traces back to where the IRNode is created in Inductor traceback: Optional[list[str]] = dataclasses.field(init=False) origin_node: Optional[torch.fx.Node] = dataclasses.field(init=False) + # trace backs to user model code + # a single IRNode could correspond to multiple lines of code + stack_traces: dict[str, str] = dataclasses.field(init=False) @staticmethod @contextlib.contextmanager @@ -578,12 +589,41 @@ def _post_init_setattr(self, attr: str, value: Any) -> None: object.__setattr__(self, attr, value) def __post_init__(self) -> None: - self._post_init_setattr("origins", OrderedSet(self._current_origins)) + origins = OrderedSet(self._current_origins) + self._post_init_setattr("origins", origins) self._post_init_setattr( "traceback", traceback.format_stack() if config.debug_ir_traceback else None ) self._post_init_setattr("origin_node", None) + # Group nodes by their stack traces to deduplicate + nodes_to_stack_trace = {} + if config.trace.provenance_tracking: + for node in origins: + if node.stack_trace: + # nodes in the backward graph don't have mapping to pre_grad_graph + nodes_to_stack_trace["post_grad+" + node.name] = node.stack_trace + else: + if ( + "postToPre" + not in torch._inductor.debug._inductor_post_to_pre_grad_nodes + ): + continue + node_names = torch._inductor.debug._inductor_post_to_pre_grad_nodes[ + "postToPre" + ].get(node.name, None) + if node_names: + for node_name in node_names: + stack_trace = torch._inductor.debug._inductor_pre_grad_node_stack_trace.get( + node_name, None + ) + if stack_trace: + nodes_to_stack_trace["pre_grad+" + node_name] = ( + stack_trace + ) + + self._post_init_setattr("stack_traces", nodes_to_stack_trace) + def get_read_names(self) -> OrderedSet[str]: return OrderedSet(dep.name for dep in self.get_reads()) @@ -601,7 +641,15 @@ def common_repr(self, shorten: bool = True) -> Sequence[str]: if shorten and len(origins) > 64: # this can get *very* long origins = f"{origins[:61]}..." - return [origins] + if not self.stack_traces: + return [origins] + + stack_trace_str = [] + for stack_trace in self.stack_traces.values(): + stack_trace_str.append("stack_traces = {{") + stack_trace_str += stack_trace.split("\n") + stack_trace_str.append("}") + return [origins] + stack_trace_str def str_helper( self, lines: Sequence[object], shorten: bool = True, multiline: bool = True From d8b084312b54e97bdbaf6a178fe2fc628a23243b Mon Sep 17 00:00:00 2001 From: Will Constable Date: Thu, 17 Jul 2025 13:30:36 -0700 Subject: [PATCH 229/457] [DTensor] Fix default_strategy and rename for clarity (#158490) Fixes several bugs in the original. - foremost, fixes a serious bug where we returned incorrect strategies by mixing input_specs that were frozen from select_strategy.strategies[0] with output_specs that varied across select_strategy.strategies[0..N] (e.g. we could create a nonsense strategy like input:Shard(0) output(Replicate) for an op like clone - fixes the redistribute costs: they should not actually be 0, they should be the cost of redistributing our single input from another strategy to the current strategy, in our list of output strategies - adds a note, wondering if we should have just literally returned the input strategy instead of creating this new object - Currently, using default_strategy is incorrect becuase it maps 'self' tensor's strategies directly onto 'src' tensor without accounting for the fact that copy_ supports broadcasting a smaller rank tensor into a larger one. Separates out copy_ op from default strategy, adds missing test case, but does not fix the underlying issue with copy_, leaves that for future PR Renames to `propagate_single_input_strategy` since that's more descriptive Pull Request resolved: https://github.com/pytorch/pytorch/pull/158490 Approved by: https://github.com/wanchaol, https://github.com/XilunWu ghstack dependencies: #158495 --- test/distributed/tensor/test_tensor_ops.py | 32 ++++++ torch/distributed/tensor/_ops/_tensor_ops.py | 109 +++++++++++++------ 2 files changed, 108 insertions(+), 33 deletions(-) diff --git a/test/distributed/tensor/test_tensor_ops.py b/test/distributed/tensor/test_tensor_ops.py index 9be582952f367..9140d2f5aae13 100644 --- a/test/distributed/tensor/test_tensor_ops.py +++ b/test/distributed/tensor/test_tensor_ops.py @@ -53,6 +53,38 @@ def test_clone(self): self.assertFalse(cloned_mat is mat) self.assertEqual(cloned_mat.to_local(), mat.to_local()) + @with_comms + def test_copy_(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + src_specs = [[Replicate()], [Shard(0)]] + src_tensor = torch.randn((12, 12)) + + dst_tensor = torch.zeros(12, 12) + dst_specs = [[Replicate()], [Shard(0)]] + for dst_spec, src_spec in zip(dst_specs, src_specs): + src_dtensor = distribute_tensor(src_tensor, device_mesh, dst_spec) + dst_dtensor = distribute_tensor(dst_tensor, device_mesh, src_spec) + dst_dtensor.copy_(src_dtensor) + dst_tensor.copy_(src_tensor) + self.assertEqual(dst_dtensor.full_tensor(), dst_tensor) + + # @pytest.mark.xfail + # @with_comms + # def test_copy_broadcast(self): + # device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + # src_specs = [[Replicate()], [Shard(0)]] + # src_tensor = torch.randn((12,)) + + # dst_tensor = torch.zeros(12, 12) + # dst_specs = [[Replicate()], [Shard(1)]] + # for dst_spec, src_spec in zip(dst_specs, src_specs): + # src_dtensor = distribute_tensor(src_tensor, device_mesh, dst_spec) + # dst_dtensor = distribute_tensor(dst_tensor, device_mesh, src_spec) + # # perform a broadcasted copy from Shard(0) to Shard(1) for the worst case + # dst_dtensor.copy_(src_dtensor) + # dst_tensor.copy_(src_tensor) + # self.assertEqual(dst_dtensor.full_tensor(), dst_tensor) + @with_comms def test_contiguous(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) diff --git a/torch/distributed/tensor/_ops/_tensor_ops.py b/torch/distributed/tensor/_ops/_tensor_ops.py index 9bdfc90d145d4..fd6621ab75124 100644 --- a/torch/distributed/tensor/_ops/_tensor_ops.py +++ b/torch/distributed/tensor/_ops/_tensor_ops.py @@ -39,55 +39,98 @@ aten = torch.ops.aten -def default_strategy(op_schema: OpSchema) -> StrategyType: - # Default strategy by default just propagate the first input strategy - select_strategy = op_schema.args_schema[0] - assert isinstance(select_strategy, OpStrategy) - # we create new DTensorSpecs even for default strategy to assure that - # the tensor metas are distinct between the arguments and outputs - input_specs = [] - redistribute_cost = [] - for i in op_schema.args_schema: - input_specs.append( - DTensorSpec( - mesh=select_strategy.mesh, - placements=select_strategy.strategies[0].output_spec.placements, - tensor_meta=select_strategy.strategies[0].output_spec.tensor_meta, +def propagate_single_input_strategy(op_schema: OpSchema) -> StrategyType: + # For ops with a single tensor input, we perform a 1:1 mapping such that + # for each strategy that the input supports, we create a corresponding strategy. + # Note: this may be a complete waste of work, becuase it should be equivalent to + # `return first_input_strategy` (unless creating a deep copy is important for some reason) + assert len([s for s in op_schema.args_schema if isinstance(s, OpStrategy)]) == 1, ( + "propagate_single_input_strategy only works for single-tensor-input ops" + ) + first_input_strategy = op_schema.args_schema[0] + assert isinstance(first_input_strategy, OpStrategy) + return OpStrategy( + [ + OpSpec( + output_specs=DTensorSpec( + mesh=first_input_strategy.mesh, + placements=strategy.output_spec.placements, + tensor_meta=strategy.output_spec.tensor_meta, + ), + input_specs=[ + DTensorSpec( + mesh=first_input_strategy.mesh, + placements=strategy.output_spec.placements, + tensor_meta=strategy.output_spec.tensor_meta, + ) + ], + redistribute_cost=[ + generate_redistribute_costs( + first_input_strategy, strategy.output_spec + ) + ], ) - ) - redistribute_cost.append([0.0] * len(select_strategy.strategies)) - - default_strategy = [ - OpSpec( - output_specs=DTensorSpec( - mesh=select_strategy.mesh, - placements=strategy.output_spec.placements, - tensor_meta=strategy.output_spec.tensor_meta, - ), - input_specs=input_specs, - redistribute_cost=redistribute_cost, - ) - for strategy in select_strategy.strategies - ] - return OpStrategy(default_strategy) + for strategy in first_input_strategy.strategies + ] + ) register_op_strategy( [ aten.clone.default, aten.contiguous.default, - aten.copy_.default, aten.detach.default, aten.fill_.Scalar, aten.view.dtype, aten.zero_.default, ] -)(default_strategy) +)(propagate_single_input_strategy) register_op_strategy( aten._to_copy.default, schema_info=RuntimeSchemaInfo(static_kwargkey=["dtype"]) -)(default_strategy) +)(propagate_single_input_strategy) + + +@register_op_strategy(aten.copy_.default) +def copy_strategy(op_schema: OpSchema) -> StrategyType: + # TODO: this strategy is incorrect for copy_ in the case that src tensor + # is smaller rank than self tensor. It is possible to select a strategy from self tensor + # that is invalid for dst tensor. + # It is also problematic to assume that shard(0) on src maps to shard(0) on self, since we + # may broadcast a new dim to the left or right of 0 when copying. + # + # For now, I just keep copy working essentially the way it was before this PR, + # but split it out so it can be handled separately in the future. + num_tensor_args = 2 + first_input_strategy = op_schema.args_schema[0] + assert isinstance(first_input_strategy, OpStrategy) + return OpStrategy( + [ + OpSpec( + output_specs=DTensorSpec( + mesh=first_input_strategy.mesh, + placements=strategy.output_spec.placements, + tensor_meta=strategy.output_spec.tensor_meta, + ), + input_specs=[ + DTensorSpec( + mesh=first_input_strategy.mesh, + placements=strategy.output_spec.placements, + tensor_meta=strategy.output_spec.tensor_meta, + ) + for _ in range(num_tensor_args) + ], + redistribute_cost=[ + generate_redistribute_costs( + first_input_strategy, strategy.output_spec + ) + for _ in range(num_tensor_args) + ], + ) + for strategy in first_input_strategy.strategies + ] + ) @register_op_strategy( From 9a7c2f1f64b1dba1df9ca12249ef659394ffe13d Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 18 Jul 2025 04:58:24 +0000 Subject: [PATCH 230/457] Revert "Add torch compile force disable caches alias (#158072)" This reverts commit 2ecf083b7247f265a03ec296ba9d7b795f035118. Reverted https://github.com/pytorch/pytorch/pull/158072 on behalf of https://github.com/jeffdaily due to fails on rocm, signal ignored while rocm was unstable ([comment](https://github.com/pytorch/pytorch/pull/158072#issuecomment-3086740829)) --- docs/source/torch.compiler_troubleshooting_old.md | 2 +- torch/_dynamo/pgo.py | 6 +++--- torch/_functorch/_aot_autograd/autograd_cache.py | 4 ++-- torch/_inductor/config.py | 8 ++++++-- torch/compiler/config.py | 12 ------------ 5 files changed, 12 insertions(+), 20 deletions(-) diff --git a/docs/source/torch.compiler_troubleshooting_old.md b/docs/source/torch.compiler_troubleshooting_old.md index ef13fc1772374..03555d74e817c 100644 --- a/docs/source/torch.compiler_troubleshooting_old.md +++ b/docs/source/torch.compiler_troubleshooting_old.md @@ -717,5 +717,5 @@ backtrace is slow and very spammy so it is not included by default with extended In order to measure the cold start compilation time or debug a cache corruption, it is possible pass `TORCHINDUCTOR_FORCE_DISABLE_CACHES=1` or set -`torch.compiler.config.force_disable_caches = True` which will override any +`torch._inductor.config.force_disable_caches = True` which will override any other caching config option and disable all compile time caching. diff --git a/torch/_dynamo/pgo.py b/torch/_dynamo/pgo.py index 403187bc6bde8..9bdec2df05c26 100644 --- a/torch/_dynamo/pgo.py +++ b/torch/_dynamo/pgo.py @@ -521,9 +521,9 @@ def process_automatic_dynamic( def get_cache_key() -> Optional[str]: # TODO: info versions of these logs that log only once - if torch.compiler.config.force_disable_caches: + if torch._inductor.config.force_disable_caches: warn_once( - "dynamo_pgo force disabled by torch.compiler.config.force_disable_caches" + "dynamo_pgo force disabled by torch._inductor.config.force_disable_caches" ) return None @@ -566,7 +566,7 @@ def code_state_path(cache_key: str) -> Optional[str]: def should_use_remote_dynamo_pgo_cache() -> bool: - if torch.compiler.config.force_disable_caches: + if torch._inductor.config.force_disable_caches: return False if (r := torch._dynamo.config.automatic_dynamic_remote_pgo) is not None: diff --git a/torch/_functorch/_aot_autograd/autograd_cache.py b/torch/_functorch/_aot_autograd/autograd_cache.py index c6a4e11ce81d3..e66ffefe0a00c 100644 --- a/torch/_functorch/_aot_autograd/autograd_cache.py +++ b/torch/_functorch/_aot_autograd/autograd_cache.py @@ -95,7 +95,7 @@ class FXGraphCacheMiss(BypassAOTAutogradCache): def should_use_remote_autograd_cache(): - if torch.compiler.config.force_disable_caches: + if torch._inductor.config.force_disable_caches: return False if config.enable_remote_autograd_cache is not None: return config.enable_remote_autograd_cache @@ -116,7 +116,7 @@ def should_use_remote_autograd_cache(): def should_use_local_autograd_cache(): - if torch.compiler.config.force_disable_caches: + if torch._inductor.config.force_disable_caches: return False return config.enable_autograd_cache diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 5eb2b57a225a4..f1edeb21b4062 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -138,8 +138,12 @@ def prologue_fusion_enabled() -> bool: # None: Not set -- Off for OSS, JustKnobs based for internal bundled_autotune_remote_cache: Optional[bool] = bundled_autotune_remote_cache_default() -# See torch.compiler.force_disable_caches -force_disable_caches: bool = Config(alias="torch.compiler.config.force_disable_caches") +# Force disabled all inductor level caching -- This will override any other caching flag +force_disable_caches: bool = Config( + justknob="pytorch/remote_cache:force_disable_caches", + env_name_force="TORCHINDUCTOR_FORCE_DISABLE_CACHES", + default=False, +) # Unsafe way to skip dynamic shape guards to get faster cache load unsafe_skip_cache_dynamic_shape_guards: bool = False diff --git a/torch/compiler/config.py b/torch/compiler/config.py index 4009f04e4a0ae..f9ec226c25489 100644 --- a/torch/compiler/config.py +++ b/torch/compiler/config.py @@ -66,18 +66,6 @@ A common use case for such a tag is to break caches. """ -force_disable_caches: bool = Config( - justknob="pytorch/remote_cache:force_disable_caches", - env_name_force=[ - "TORCHINDUCTOR_FORCE_DISABLE_CACHES", - "TORCH_COMPILE_FORCE_DISABLE_CACHES", - ], - default=False, -) -""" -Force disables all caching -- This will take precedence over and override any other caching flag -""" - dynamic_sources: str = Config( env_name_default="TORCH_COMPILE_DYNAMIC_SOURCES", default="" ) From 9308261a2afb69d807ea06508bb8582b066d9ccd Mon Sep 17 00:00:00 2001 From: Jeff Daily Date: Fri, 18 Jul 2025 05:02:31 +0000 Subject: [PATCH 231/457] [ROCm][CI] update fbgemm_gpu hash used by inductor tests (#158602) fbgemm_gpu build started failing with asmjit errors. Moving to latest tip of fbgemm for inductor tests resolves the build failures. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158602 Approved by: https://github.com/jeffdaily Co-authored-by: Jeff Daily --- .ci/pytorch/common_utils.sh | 27 +++++++++++++++++++++++++- .github/ci_commit_pins/fbgemm_rocm.txt | 2 +- 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/.ci/pytorch/common_utils.sh b/.ci/pytorch/common_utils.sh index 3dbc2ece9e70b..3de68991bafce 100644 --- a/.ci/pytorch/common_utils.sh +++ b/.ci/pytorch/common_utils.sh @@ -176,18 +176,43 @@ function install_torchrec_and_fbgemm() { pip_install --no-use-pep517 "git+https://github.com/pytorch/torchrec.git@${torchrec_commit}" pip_uninstall fbgemm-gpu-nightly + # Set ROCM_HOME isn't available, use ROCM_PATH if set or /opt/rocm + ROCM_HOME="${ROCM_HOME:-${ROCM_PATH:-/opt/rocm}}" + + # Find rocm_version.h header file for ROCm version extract + rocm_version_h="${ROCM_HOME}/include/rocm-core/rocm_version.h" + if [ ! -f "$rocm_version_h" ]; then + rocm_version_h="${ROCM_HOME}/include/rocm_version.h" + fi + + # Error out if rocm_version.h not found + if [ ! -f "$rocm_version_h" ]; then + echo "Error: rocm_version.h not found in expected locations." >&2 + exit 1 + fi + + # Extract major, minor and patch ROCm version numbers + MAJOR_VERSION=$(grep 'ROCM_VERSION_MAJOR' "$rocm_version_h" | awk '{print $3}') + MINOR_VERSION=$(grep 'ROCM_VERSION_MINOR' "$rocm_version_h" | awk '{print $3}') + PATCH_VERSION=$(grep 'ROCM_VERSION_PATCH' "$rocm_version_h" | awk '{print $3}') + ROCM_INT=$(($MAJOR_VERSION * 10000 + $MINOR_VERSION * 100 + $PATCH_VERSION)) + echo "ROCm version: $ROCM_INT" + export BUILD_ROCM_VERSION="$MAJOR_VERSION.$MINOR_VERSION" + pip_install tabulate # needed for newer fbgemm pip_install patchelf # needed for rocm fbgemm + pushd /tmp git clone --recursive https://github.com/pytorch/fbgemm pushd fbgemm/fbgemm_gpu git checkout "${fbgemm_commit}" python setup.py install \ - --package_variant=rocm \ + --build-variant=rocm \ -DHIP_ROOT_DIR="${ROCM_PATH}" \ -DCMAKE_C_FLAGS="-DTORCH_USE_HIP_DSA" \ -DCMAKE_CXX_FLAGS="-DTORCH_USE_HIP_DSA" popd rm -rf fbgemm + popd else # See https://github.com/pytorch/pytorch/issues/106971 CUDA_PATH=/usr/local/cuda-12.1 pip_install --no-use-pep517 "git+https://github.com/pytorch/FBGEMM.git@${fbgemm_commit}#egg=fbgemm-gpu&subdirectory=fbgemm_gpu" diff --git a/.github/ci_commit_pins/fbgemm_rocm.txt b/.github/ci_commit_pins/fbgemm_rocm.txt index fa11e10ca6b8e..db140a31f3fa4 100644 --- a/.github/ci_commit_pins/fbgemm_rocm.txt +++ b/.github/ci_commit_pins/fbgemm_rocm.txt @@ -1 +1 @@ -5fb5024118e9bb9decf96c2b0b1a8f0010bf56be +7f1de94a4c2d14f59ad4ca84538c36084ea6b2c8 From eb7365072315be2bc4259114e25e269801441748 Mon Sep 17 00:00:00 2001 From: PaliC Date: Thu, 17 Jul 2025 17:01:48 -0700 Subject: [PATCH 232/457] [BE] Make PyObjectSlot use a global PyInterpreter and remove (#158427) This PR is a bit more involved but effectively works to drastically simplify PyObjectSlot and PyInterpreter. 1) For PyObjectSlot we now use a global pyinterpreter since there only is one. From here we change all of the call sites to rely on this assumption. 2) We also remove the "tags" of the PyInterpreter by deprecating `PyInterpreterStatus`. For the reviewer, sadly it seems like `functorch/csrc/dim/dim.cpp` needed to get linted, so there is an unreadable amount of changes there. Fortunately, the only actual change in the file is as follows which just removes `getPyInterpreter()` from the `check_pyobj` call. ``` mpy::handle handle_from_tensor(Arena& A, TensorRef t) { - // fast case: tensor is live in python - std::optional mb_obj = - t->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(getPyInterpreter(), /*ignore_hermetic_tls=*/false); - if (mb_obj.has_value() && !t->unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()) { - return *mb_obj; - } - return A.autorelease(mpy::object::checked_steal(THPVariable_Wrap(*t))); -} -} + // fast case: tensor is live in python + std::optional mb_obj = + t->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj( + /*ignore_hermetic_tls=*/false); + if (mb_obj.has_value() && + !t->unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()) { + return *mb_obj; + } + return A.autorelease(mpy::object::checked_steal(THPVariable_Wrap(*t))); +} ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/158427 Approved by: https://github.com/albanD --- build_variables.bzl | 1 + c10/core/impl/PyInterpreter.h | 20 - c10/core/impl/PyInterpreterHooks.cpp | 32 + c10/core/impl/PyInterpreterHooks.h | 39 + c10/core/impl/PyObjectSlot.cpp | 5 - c10/core/impl/PyObjectSlot.h | 20 +- functorch/csrc/dim/dim.cpp | 6136 ++++++++++++----------- torch/_dynamo/trace_rules.py | 1 - torch/csrc/Module.cpp | 10 +- torch/csrc/PyInterpreter.cpp | 6 +- torch/csrc/PyInterpreter.h | 2 +- torch/csrc/PyInterpreterHooks.cpp | 20 + torch/csrc/PyInterpreterHooks.h | 15 + torch/csrc/Storage.cpp | 47 +- torch/csrc/Storage.h | 1 - torch/csrc/StorageMethods.cpp | 5 +- torch/csrc/StorageSharing.cpp | 18 +- torch/csrc/autograd/python_variable.cpp | 62 +- torch/csrc/utils/python_dispatch.cpp | 18 +- 19 files changed, 3428 insertions(+), 3030 deletions(-) create mode 100644 c10/core/impl/PyInterpreterHooks.cpp create mode 100644 c10/core/impl/PyInterpreterHooks.h create mode 100644 torch/csrc/PyInterpreterHooks.cpp create mode 100644 torch/csrc/PyInterpreterHooks.h diff --git a/build_variables.bzl b/build_variables.bzl index d90f3cfafa3e6..776b1f433fbd0 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -863,6 +863,7 @@ libtorch_python_core_sources = [ "torch/csrc/QScheme.cpp", "torch/csrc/Module.cpp", "torch/csrc/PyInterpreter.cpp", + "torch/csrc/PyInterpreterHooks.cpp", "torch/csrc/python_dimname.cpp", "torch/csrc/Size.cpp", "torch/csrc/Storage.cpp", diff --git a/c10/core/impl/PyInterpreter.h b/c10/core/impl/PyInterpreter.h index 43492443c530c..09d4801f7d83d 100644 --- a/c10/core/impl/PyInterpreter.h +++ b/c10/core/impl/PyInterpreter.h @@ -240,24 +240,4 @@ struct C10_API PyInterpreter { void disarm() noexcept; }; -// PyInterpreterStatus describes what the state of its interpreter tag -// is, relative to the thread currently holding the GIL. -enum class PyInterpreterStatus { - // We just allocated the Tensor, it hasn't escaped to other threads, - // we know that it definitely hasn't been tagged to be associated - // with an interpreter. - DEFINITELY_UNINITIALIZED, - // We queried the interpreter field and it looked uninitialized. But - // another thread may have raced with us to tag it with some other - // interpreter id. So we will have to do a CEX to make sure we can - // actually nab it. - MAYBE_UNINITIALIZED, - // We queried the interpreter field and it was tagged to belong to us. - // This means we have sole write access (as we hold the GIL for this - // interpreter) - TAGGED_BY_US, - // Someone else tagged this. We can't use this TensorImpl from Python. - TAGGED_BY_OTHER, -}; - } // namespace c10::impl diff --git a/c10/core/impl/PyInterpreterHooks.cpp b/c10/core/impl/PyInterpreterHooks.cpp new file mode 100644 index 0000000000000..bd5325cf49c20 --- /dev/null +++ b/c10/core/impl/PyInterpreterHooks.cpp @@ -0,0 +1,32 @@ +#include + +namespace c10::impl { + +// Define the registry +C10_DEFINE_REGISTRY( + PyInterpreterHooksRegistry, + PyInterpreterHooksInterface, + PyInterpreterHooksArgs) + +const PyInterpreterHooksInterface& getPyInterpreterHooks() { + auto create_impl = [] { +#if !defined C10_MOBILE + auto hooks = PyInterpreterHooksRegistry()->Create( + "PyInterpreterHooks", PyInterpreterHooksArgs{}); + if (hooks) { + return hooks; + } +#endif + // Return stub implementation that will throw errors when methods are called + return std::make_unique(); + }; + static auto hooks = create_impl(); + return *hooks; +} + +// Main function to get global PyInterpreter +PyInterpreter* getGlobalPyInterpreter() { + return getPyInterpreterHooks().getPyInterpreter(); +} + +} // namespace c10::impl diff --git a/c10/core/impl/PyInterpreterHooks.h b/c10/core/impl/PyInterpreterHooks.h new file mode 100644 index 0000000000000..32a17ad9a8a0c --- /dev/null +++ b/c10/core/impl/PyInterpreterHooks.h @@ -0,0 +1,39 @@ +#pragma once + +#include +#include +#include +#include + +namespace c10::impl { + +// Minimal interface for PyInterpreter hooks +struct C10_API PyInterpreterHooksInterface { + virtual ~PyInterpreterHooksInterface() = default; + + // Get the PyInterpreter instance + // Stub implementation throws error when Python is not available + virtual PyInterpreter* getPyInterpreter() const { + TORCH_CHECK( + false, + "PyTorch was compiled without Python support. " + "Cannot access Python interpreter from C++."); + } +}; + +struct C10_API PyInterpreterHooksArgs{}; + +C10_DECLARE_REGISTRY( + PyInterpreterHooksRegistry, + PyInterpreterHooksInterface, + PyInterpreterHooksArgs); + +#define REGISTER_PYTHON_HOOKS(clsname) \ + C10_REGISTER_CLASS(PyInterpreterHooksRegistry, clsname, clsname) + +// Get the global PyInterpreter hooks instance +C10_API const PyInterpreterHooksInterface& getPyInterpreterHooks(); + +C10_API PyInterpreter* getGlobalPyInterpreter(); + +} // namespace c10::impl diff --git a/c10/core/impl/PyObjectSlot.cpp b/c10/core/impl/PyObjectSlot.cpp index 62af2eae8e37a..0f1bfb2110747 100644 --- a/c10/core/impl/PyObjectSlot.cpp +++ b/c10/core/impl/PyObjectSlot.cpp @@ -34,11 +34,6 @@ PyObject* PyObjectSlot::_unchecked_untagged_pyobj() const { reinterpret_cast(pyobj_) & ~0x1ULL); } -void PyObjectSlot::unchecked_clear_pyobj(PyInterpreter* interpreter) { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(interpreter == pyobj_interpreter_.load()); - pyobj_ = nullptr; -} - PyInterpreter& PyObjectSlot::load_pyobj_interpreter() const { auto interpreter = pyobj_interpreter_.load(std::memory_order_acquire); if (interpreter) { diff --git a/c10/core/impl/PyObjectSlot.h b/c10/core/impl/PyObjectSlot.h index af8b9fa4d0ec7..58b2490eba001 100644 --- a/c10/core/impl/PyObjectSlot.h +++ b/c10/core/impl/PyObjectSlot.h @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -24,11 +25,9 @@ struct C10_API PyObjectSlot { // // NB: THIS FUNCTION CAN RAISE AN EXCEPTION. Make sure to clean up after // PyObject if necessary! - void init_pyobj( - PyInterpreter* self_interpreter, - PyObject* pyobj, - PyInterpreterStatus status) { - pyobj_interpreter_.store(self_interpreter, std::memory_order_relaxed); + void init_pyobj(PyObject* pyobj) { + pyobj_interpreter_.store( + getGlobalPyInterpreter(), std::memory_order_relaxed); pyobj_ = pyobj; } @@ -53,9 +52,10 @@ struct C10_API PyObjectSlot { // // NB: this lives in header so that we can avoid actually creating the // std::optional - std::optional check_pyobj( - PyInterpreter* self_interpreter, - bool ignore_hermetic_tls = false) const { + + // @todo alban: I'm not too sure what's going on here, we can probably delete + // it but it's worthwhile making sure + std::optional check_pyobj(bool ignore_hermetic_tls = false) const { impl::PyInterpreter* interpreter = pyobj_interpreter_.load(std::memory_order_acquire); if (interpreter == nullptr) { @@ -69,10 +69,6 @@ struct C10_API PyObjectSlot { } } - // Clear the PyObject field for an interpreter, in situations where we - // statically know the tensor is tagged with our interpreter. - void unchecked_clear_pyobj(PyInterpreter* interpreter); - PyInterpreter& load_pyobj_interpreter() const; bool owns_pyobj(); diff --git a/functorch/csrc/dim/dim.cpp b/functorch/csrc/dim/dim.cpp index 19270d2f9225d..8f1e561e2051b 100644 --- a/functorch/csrc/dim/dim.cpp +++ b/functorch/csrc/dim/dim.cpp @@ -6,7 +6,6 @@ #include - // Many APIs have changed/don't exist anymore #if IS_PYTHON_3_12_PLUS @@ -14,24 +13,25 @@ // Re-enable this some day PyObject* Dim_init() { - PyErr_SetString(PyExc_RuntimeError, "First class dim doesn't work with python 3.12"); - return nullptr; + PyErr_SetString( + PyExc_RuntimeError, "First class dim doesn't work with python 3.12"); + return nullptr; } #else -#include "minpybind.h" #include #include -#include -#include #include +#include +#include #include -//#include -#include +#include "minpybind.h" +// #include +#include #include #include -#include +#include #include #include "arena.h" #include "dim.h" @@ -71,3115 +71,3498 @@ PyTypeObject* DimType = nullptr; PyObject* Tensor_getitem(PyObject* self, PyObject* index); int Tensor_setitem(PyObject* self, PyObject* index, PyObject* value); -namespace{ +namespace { void maybeInitializeGlobals() { - // globals that depend on the python dim library, - // which we can't lookup until we finish initializing the _C module - if (_Tensor.ptr()) { - return; - } - auto dim = mpy::import("functorch.dim"); - _Tensor = dim.attr("_Tensor"); - pointwise = dim.attr("pointwise"); - _Tensor_sum = _Tensor.attr("sum"); - DimType = (PyTypeObject*) mpy::import("functorch.dim").attr("Dim").ptr(); + // globals that depend on the python dim library, + // which we can't lookup until we finish initializing the _C module + if (_Tensor.ptr()) { + return; + } + auto dim = mpy::import("functorch.dim"); + _Tensor = dim.attr("_Tensor"); + pointwise = dim.attr("pointwise"); + _Tensor_sum = _Tensor.attr("sum"); + DimType = (PyTypeObject*)mpy::import("functorch.dim").attr("Dim").ptr(); } void replaceMappingIfMatches(mpy::handle tp) { - auto T = (PyTypeObject*) tp.ptr(); - bool recurse = false; - if (T->tp_as_mapping->mp_subscript == THPVariable_getitem) { - T->tp_as_mapping->mp_subscript = Tensor_getitem; - recurse = true; - } - if (T->tp_as_mapping->mp_ass_subscript == THPVariable_setitem) { - T->tp_as_mapping->mp_ass_subscript = Tensor_setitem; - recurse = true; - } - if (recurse) { - auto result = tp.attr("__subclasses__").call(); - mpy::list_view lv(result); - for (auto i : lv.enumerate()) { - replaceMappingIfMatches(lv[i]); - } - } -} - -void initializeGlobals(Arena & A) { - auto torch = mpy::import("torch"); - torch_Tensor = (PyTypeObject*) torch.attr("Tensor").ptr(); - torch_Tensor___mul__ = torch.attr("Tensor").attr("__mul__"); - - torch_Tensor_expand = torch.attr("_C").attr("TensorBase").attr("expand"); - torch_Tensor_split = torch.attr("_C").attr("TensorBase").attr("split"); - torch_Tensor_copy_ = torch.attr("Tensor").attr("copy_"); - auto py_TensorBase = torch.attr("_C").attr("TensorBase"); - auto TensorBase = (PyTypeObject*) py_TensorBase.ptr(); - THPVariable_getitem = TensorBase->tp_as_mapping->mp_subscript; - THPVariable_setitem = TensorBase->tp_as_mapping->mp_ass_subscript; - NamedTuple = mpy::import("typing").attr("NamedTuple"); - no_slice = PySlice_New(NULL, NULL, NULL); - + auto T = (PyTypeObject*)tp.ptr(); + bool recurse = false; + if (T->tp_as_mapping->mp_subscript == THPVariable_getitem) { + T->tp_as_mapping->mp_subscript = Tensor_getitem; + recurse = true; + } + if (T->tp_as_mapping->mp_ass_subscript == THPVariable_setitem) { + T->tp_as_mapping->mp_ass_subscript = Tensor_setitem; + recurse = true; + } + if (recurse) { + auto result = tp.attr("__subclasses__").call(); + mpy::list_view lv(result); + for (auto i : lv.enumerate()) { + replaceMappingIfMatches(lv[i]); + } + } +} + +void initializeGlobals(Arena& A) { + auto torch = mpy::import("torch"); + torch_Tensor = (PyTypeObject*)torch.attr("Tensor").ptr(); + torch_Tensor___mul__ = torch.attr("Tensor").attr("__mul__"); + + torch_Tensor_expand = torch.attr("_C").attr("TensorBase").attr("expand"); + torch_Tensor_split = torch.attr("_C").attr("TensorBase").attr("split"); + torch_Tensor_copy_ = torch.attr("Tensor").attr("copy_"); + auto py_TensorBase = torch.attr("_C").attr("TensorBase"); + auto TensorBase = (PyTypeObject*)py_TensorBase.ptr(); + THPVariable_getitem = TensorBase->tp_as_mapping->mp_subscript; + THPVariable_setitem = TensorBase->tp_as_mapping->mp_ass_subscript; + NamedTuple = mpy::import("typing").attr("NamedTuple"); + no_slice = PySlice_New(NULL, NULL, NULL); } mpy::handle DimensionBindError_; mpy::handle DimensionBindError() { - if(!DimensionBindError_.ptr()) { - DimensionBindError_ = mpy::import("functorch.dim").attr("DimensionBindError"); - } - return DimensionBindError_; + if (!DimensionBindError_.ptr()) { + DimensionBindError_ = + mpy::import("functorch.dim").attr("DimensionBindError"); + } + return DimensionBindError_; } static int64_t n_dims_created = 65; struct Dim : public mpy::base { - int64_t level_; // for stable comparisons in prototype - mpy::object name_; - Dim() - : level_(n_dims_created++) {} - void init(mpy::object name, int64_t s = -1) { - name_ = std::move(name); - size_ = s; - } - - static bool check_exact(mpy::handle v) { - return Py_TYPE(v.ptr()) == DimType; - } - - int64_t size() const { - if (size_ == -1) { - mpy::raise_error(PyExc_ValueError, "dimension %S is unbound", name_.ptr()); - } - return size_; - } - void set_size(int64_t v) { - if (size_ == -1) { - size_ = v; - } else if(size_ != v) { - mpy::raise_error(DimensionBindError(), "Dim '%R' previously bound to a dimension of size %lld cannot bind to a dimension of size %lld", this, this->size_, v); - } - } - bool is_bound() const { - return size_ != -1; - } - static mpy::obj create(mpy::object name, int64_t s = -1) { - if (!DimType) { - maybeInitializeGlobals(); - } - auto r = Dim::alloc(DimType); - r->init(std::move(name), s); - return r; - } - static PyTypeObject Type; - const at::Tensor& range() { - if (!range_.defined()) { - range_ = at::arange(size()); - } - return range_; - } - const at::Tensor& batchtensor() { - if (!batchtensor_.defined()) { - batchtensor_ = at::functorch::addBatchDim(range(), 0, level_); - } - return batchtensor_; - } -private: - int64_t size_{-1}; - at::Tensor range_; - at::Tensor batchtensor_; + int64_t level_; // for stable comparisons in prototype + mpy::object name_; + Dim() : level_(n_dims_created++) {} + void init(mpy::object name, int64_t s = -1) { + name_ = std::move(name); + size_ = s; + } + + static bool check_exact(mpy::handle v) { + return Py_TYPE(v.ptr()) == DimType; + } + + int64_t size() const { + if (size_ == -1) { + mpy::raise_error( + PyExc_ValueError, "dimension %S is unbound", name_.ptr()); + } + return size_; + } + void set_size(int64_t v) { + if (size_ == -1) { + size_ = v; + } else if (size_ != v) { + mpy::raise_error( + DimensionBindError(), + "Dim '%R' previously bound to a dimension of size %lld cannot bind to a dimension of size %lld", + this, + this->size_, + v); + } + } + bool is_bound() const { + return size_ != -1; + } + static mpy::obj create(mpy::object name, int64_t s = -1) { + if (!DimType) { + maybeInitializeGlobals(); + } + auto r = Dim::alloc(DimType); + r->init(std::move(name), s); + return r; + } + static PyTypeObject Type; + const at::Tensor& range() { + if (!range_.defined()) { + range_ = at::arange(size()); + } + return range_; + } + const at::Tensor& batchtensor() { + if (!batchtensor_.defined()) { + batchtensor_ = at::functorch::addBatchDim(range(), 0, level_); + } + return batchtensor_; + } + + private: + int64_t size_{-1}; + at::Tensor range_; + at::Tensor batchtensor_; }; - struct DimEntry { - // union of either a negative number indicating which dimension this is from the rhs, - // or a pointer to a first-class dimension. - // pointers do not have their highest bit set, so checking the number is negative tells us - // that it is not a dim. - bool is_positional() const { - return data_ < 0; - } - bool is_none() const { - return data_ == 0; - } - int64_t position() const { - return data_; - } - mpy::hdl dim() const { - Dim* result; - std::memcpy(&result, &data_, sizeof(Dim*)); - return mpy::hdl(result); - } - - DimEntry() - : data_(0) {} - - DimEntry(int64_t pos) - : data_(pos) { - AT_ASSERT(pos < 0); - } - DimEntry(mpy::hdl d) { - std::memcpy(&data_, &d, sizeof(int64_t)); - } - bool operator==(const DimEntry& rhs) const { - return data_ == rhs.data_; - } -private: - int64_t data_; + // union of either a negative number indicating which dimension this is from + // the rhs, or a pointer to a first-class dimension. pointers do not have + // their highest bit set, so checking the number is negative tells us that it + // is not a dim. + bool is_positional() const { + return data_ < 0; + } + bool is_none() const { + return data_ == 0; + } + int64_t position() const { + return data_; + } + mpy::hdl dim() const { + Dim* result; + std::memcpy(&result, &data_, sizeof(Dim*)); + return mpy::hdl(result); + } + + DimEntry() : data_(0) {} + + DimEntry(int64_t pos) : data_(pos) { + AT_ASSERT(pos < 0); + } + DimEntry(mpy::hdl d) { + std::memcpy(&data_, &d, sizeof(int64_t)); + } + bool operator==(const DimEntry& rhs) const { + return data_ == rhs.data_; + } + + private: + int64_t data_; }; // Dim wrapper methods DimEntry _wrap_dim(mpy::handle d, size_t N, bool keepdim) { - if (Dim::check(d)) { - if (keepdim) { - mpy::raise_error(PyExc_ValueError, "cannot preserve first-class dimensions with keepdim=True"); - } - return Dim::unchecked_wrap(d); - } else if (mpy::is_int(d)) { - auto i = mpy::to_int(d); - while (i >= 0) { - i -= N; - } - return i; - } else { - return DimEntry(); - } -} - - -int Dim_init(mpy::hdl self, PyObject *args, PyObject *kwds) { - PY_BEGIN - static constexpr const char* kwlist[] = {"name", "size", nullptr}; - mpy::handle name; - mpy::handle size = nullptr; - if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O", const_cast(kwlist), &name, &size)) { - return -1; - } - self->init(mpy::object::borrow(name), (size.ptr() && !mpy::is_none(size)) ? mpy::to_int(size) : -1); - return 0; - PY_END(-1) + if (Dim::check(d)) { + if (keepdim) { + mpy::raise_error( + PyExc_ValueError, + "cannot preserve first-class dimensions with keepdim=True"); + } + return Dim::unchecked_wrap(d); + } else if (mpy::is_int(d)) { + auto i = mpy::to_int(d); + while (i >= 0) { + i -= N; + } + return i; + } else { + return DimEntry(); + } +} + +int Dim_init(mpy::hdl self, PyObject* args, PyObject* kwds) { + PY_BEGIN + static constexpr const char* kwlist[] = {"name", "size", nullptr}; + mpy::handle name; + mpy::handle size = nullptr; + if (!PyArg_ParseTupleAndKeywords( + args, kwds, "O|O", const_cast(kwlist), &name, &size)) { + return -1; + } + self->init( + mpy::object::borrow(name), + (size.ptr() && !mpy::is_none(size)) ? mpy::to_int(size) : -1); + return 0; + PY_END(-1) } PyObject* Dim_repr(Dim* self) { - PY_BEGIN - mpy::object name = (self->name_.ptr()) ? self->name_ : mpy::unicode_from_string(""); - return name.release(); - PY_END(nullptr) + PY_BEGIN + mpy::object name = (self->name_.ptr()) + ? self->name_ + : mpy::unicode_from_string(""); + return name.release(); + PY_END(nullptr) } - PyObject* Dim_getsize(Dim* self, void*) { - PY_BEGIN - return mpy::from_int(self->size()).release(); - PY_END(nullptr) + PY_BEGIN + return mpy::from_int(self->size()).release(); + PY_END(nullptr) } int Dim_setsize(Dim* self, PyObject* size, void*) { - PY_BEGIN - self->set_size(mpy::to_int(size)); - return 0; - PY_END(-1) + PY_BEGIN + self->set_size(mpy::to_int(size)); + return 0; + PY_END(-1) } PyObject* Dim_getis_bound(Dim* self, void*) { - return PyBool_FromLong(self->is_bound()); + return PyBool_FromLong(self->is_bound()); } PyObject* Dim_getlevel(Dim* self, void*) { - return PyLong_FromLong(self->level_); + return PyLong_FromLong(self->level_); } PyObject* Dim_get_levels(Dim* self, void*) { - mpy::tuple t(1); - t.set(0, mpy::object::borrow(self->ptr())); - return t.release(); + mpy::tuple t(1); + t.set(0, mpy::object::borrow(self->ptr())); + return t.release(); } PyObject* Dim_get_has_device(Dim* self, void*) { - Py_RETURN_FALSE; + Py_RETURN_FALSE; } PyObject* Dim_get_tensor(Dim* self, void*) { - return THPVariable_Wrap(self->range()); + return THPVariable_Wrap(self->range()); } PyObject* Dim_get_batchtensor(Dim* self, void*) { - return THPVariable_Wrap(self->batchtensor()); + return THPVariable_Wrap(self->batchtensor()); } - PyGetSetDef Dim_getsetters[] = { - {"size", (getter) Dim_getsize, (setter) Dim_setsize, - "Dimension size", NULL}, - {"is_bound", (getter) Dim_getis_bound, NULL, "is_bound", NULL}, - {"_level", (getter) Dim_getlevel, NULL, "_level", NULL}, - {"_levels", (getter) Dim_get_levels, NULL, "_levels", NULL}, - {"_has_device", (getter) Dim_get_has_device, NULL, "_has_device", NULL}, - {"_tensor", (getter) Dim_get_tensor, NULL, "_tensor", NULL}, - {"_batchtensor", (getter) Dim_get_batchtensor, NULL, "_batchtensor", NULL}, - {"ndim", (getter) [](PyObject* self, void*) -> PyObject* { return mpy::from_int(1).release(); }, NULL, "ndim", NULL}, - {NULL} /* Sentinel */ -}; + {"size", (getter)Dim_getsize, (setter)Dim_setsize, "Dimension size", NULL}, + {"is_bound", (getter)Dim_getis_bound, NULL, "is_bound", NULL}, + {"_level", (getter)Dim_getlevel, NULL, "_level", NULL}, + {"_levels", (getter)Dim_get_levels, NULL, "_levels", NULL}, + {"_has_device", (getter)Dim_get_has_device, NULL, "_has_device", NULL}, + {"_tensor", (getter)Dim_get_tensor, NULL, "_tensor", NULL}, + {"_batchtensor", (getter)Dim_get_batchtensor, NULL, "_batchtensor", NULL}, + {"ndim", + (getter)[](PyObject* self, void*) + ->PyObject* {return mpy::from_int(1).release(); +} // namespace +, NULL, "ndim", NULL +} +, { + NULL +} /* Sentinel */ +} +; } PyTypeObject Dim::Type = { PyVarObject_HEAD_INIT(NULL, 0) - "_C.Dim", /* tp_name */ - sizeof(Dim), /* tp_basicsize */ - 0, /* tp_itemsize */ - Dim::dealloc_stub, /* tp_dealloc */ - 0, /* tp_vectorcall_offset */ - 0, /* tp_getattr */ - 0, /* tp_setattr */ - 0, /* tp_as_async */ - (reprfunc)Dim_repr, /* tp_repr */ - 0, /* tp_as_number */ - 0, /* tp_as_sequence */ - 0, /* tp_as_mapping */ - 0, /* tp_hash */ - 0, /* tp_call */ - 0, /* tp_str */ - 0, /* tp_getattro */ - 0, /* tp_setattro */ - 0, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ - "Dim Object", /* tp_doc */ - 0, /* tp_traverse */ - 0, /* tp_clear */ - 0, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - 0, /* tp_iter */ - 0, /* tp_iternext */ - 0, /* tp_methods */ - 0, /* tp_members */ - Dim_getsetters, /* tp_getset */ - 0, /* tp_base */ - 0, /* tp_dict */ - 0, /* tp_descr_get */ - 0, /* tp_descr_set */ - 0, /* tp_dictoffset */ - (initproc)(void*)static_cast,PyObject*,PyObject*)>(Dim_init), /* tp_init */ - 0, /* tp_alloc */ - Dim::new_stub, /* tp_new */ + "_C.Dim", /* tp_name */ + sizeof(Dim), /* tp_basicsize */ + 0, /* tp_itemsize */ + Dim::dealloc_stub, /* tp_dealloc */ + 0, /* tp_vectorcall_offset */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_as_async */ + (reprfunc)Dim_repr, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ + "Dim Object", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + 0, /* tp_methods */ + 0, /* tp_members */ + Dim_getsetters, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)(void*)static_cast, PyObject*, PyObject*)>( + Dim_init), /* tp_init */ + 0, /* tp_alloc */ + Dim::new_stub, /* tp_new */ }; // class DimList ------------ struct DimList : public mpy::base { - mpy::object name_; - std::vector> dims_; - static PyTypeObject Type; - void init(mpy::object name) { - name_ = std::move(name); - } - void set_dims(std::vector> dims) { - bound_ = true; - dims_ = std::move(dims); - } - bool is_bound() { - return bound_; - } - void bind_len(int64_t size) { - if (bound_) { - int64_t b_size = dims_.size(); - if (b_size != size) { - mpy::raise_error(DimensionBindError(), "Dimlist has size %lld but it is being bound to size %d", b_size, size); - } - } else { - bound_ = true; - dims_.resize(size); - for (Py_ssize_t i = 0; i < size; ++i) { - dims_[i] = Dim::create(mpy::unicode_from_format("%S%i", name_.ptr(), (int)i)); - } - } - } - int64_t size() const { - if (!bound_) { - mpy::raise_error(DimensionBindError(), "DimList not bound"); - } - return dims_.size(); - } - void set_bound(bool b) { - bound_ = b; - } -private: - bool bound_ = false; + mpy::object name_; + std::vector> dims_; + static PyTypeObject Type; + void init(mpy::object name) { + name_ = std::move(name); + } + void set_dims(std::vector> dims) { + bound_ = true; + dims_ = std::move(dims); + } + bool is_bound() { + return bound_; + } + void bind_len(int64_t size) { + if (bound_) { + int64_t b_size = dims_.size(); + if (b_size != size) { + mpy::raise_error( + DimensionBindError(), + "Dimlist has size %lld but it is being bound to size %d", + b_size, + size); + } + } else { + bound_ = true; + dims_.resize(size); + for (Py_ssize_t i = 0; i < size; ++i) { + dims_[i] = + Dim::create(mpy::unicode_from_format("%S%i", name_.ptr(), (int)i)); + } + } + } + int64_t size() const { + if (!bound_) { + mpy::raise_error(DimensionBindError(), "DimList not bound"); + } + return dims_.size(); + } + void set_bound(bool b) { + bound_ = b; + } + + private: + bool bound_ = false; }; - -static int DimList_init(DimList *self, PyObject *args, PyObject *kwds); +static int DimList_init(DimList* self, PyObject* args, PyObject* kwds); static PyObject* DimList_repr(DimList* self) { - PY_BEGIN - if (self->is_bound()) { - size_t size = self->dims_.size(); - mpy::tuple t(size); - for(size_t i = 0; i < size; ++i) { - t.set(i, self->dims_[i]); - } - return mpy::repr(t).release(); - } else if(!mpy::is_none(self->name_)) { - return mpy::unicode_from_format("*%S", self->name_.ptr()).release(); - } else { - return mpy::unicode_from_string("").release(); - } - PY_END(nullptr) -} - -static PyObject* DimList_bind(DimList *self, - PyObject *const *args, - Py_ssize_t nargs, - PyObject *kwnames) { - PY_BEGIN - mpy::handle sizes; - static const char * const _keywords[] = {"sizes", nullptr}; - static _PyArg_Parser parser = {"O", _keywords, 0}; - if (!_PyArg_ParseStackAndKeywords(args, nargs, kwnames, &parser, &sizes)) { - return nullptr; - } - if (!mpy::is_sequence(sizes)) { - mpy::raise_error(PyExc_ValueError, "expected a sequence"); - } - mpy::sequence_view seq = sizes; - auto size = seq.size(); - self->bind_len(size); - for (Py_ssize_t i = 0; i < size; ++i) { - self->dims_[i]->set_size(mpy::to_int(seq[i])); - } - Py_RETURN_NONE; - PY_END(nullptr) -} - -static PyObject* DimList_bind_len(DimList *self, - PyObject *const *args, - Py_ssize_t nargs, - PyObject *kwnames) { - PY_BEGIN - int size; - static const char * const _keywords[] = {"N", nullptr}; - static _PyArg_Parser parser = {"i", _keywords, 0}; - if (!_PyArg_ParseStackAndKeywords(args, nargs, kwnames, &parser, &size)) { - return nullptr; - } - self->bind_len(size); - Py_RETURN_NONE; - PY_END(nullptr) + PY_BEGIN + if (self->is_bound()) { + size_t size = self->dims_.size(); + mpy::tuple t(size); + for (size_t i = 0; i < size; ++i) { + t.set(i, self->dims_[i]); + } + return mpy::repr(t).release(); + } else if (!mpy::is_none(self->name_)) { + return mpy::unicode_from_format("*%S", self->name_.ptr()).release(); + } else { + return mpy::unicode_from_string("").release(); + } + PY_END(nullptr) +} + +static PyObject* DimList_bind( + DimList* self, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + PY_BEGIN + mpy::handle sizes; + static const char* const _keywords[] = {"sizes", nullptr}; + static _PyArg_Parser parser = {"O", _keywords, 0}; + if (!_PyArg_ParseStackAndKeywords(args, nargs, kwnames, &parser, &sizes)) { + return nullptr; + } + if (!mpy::is_sequence(sizes)) { + mpy::raise_error(PyExc_ValueError, "expected a sequence"); + } + mpy::sequence_view seq = sizes; + auto size = seq.size(); + self->bind_len(size); + for (Py_ssize_t i = 0; i < size; ++i) { + self->dims_[i]->set_size(mpy::to_int(seq[i])); + } + Py_RETURN_NONE; + PY_END(nullptr) +} + +static PyObject* DimList_bind_len( + DimList* self, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + PY_BEGIN + int size; + static const char* const _keywords[] = {"N", nullptr}; + static _PyArg_Parser parser = {"i", _keywords, 0}; + if (!_PyArg_ParseStackAndKeywords(args, nargs, kwnames, &parser, &size)) { + return nullptr; + } + self->bind_len(size); + Py_RETURN_NONE; + PY_END(nullptr) } static PyMethodDef DimList_methods[] = { - {"bind", (PyCFunction)(void*) DimList_bind, METH_FASTCALL | METH_KEYWORDS}, - {"bind_len", (PyCFunction)(void*) DimList_bind_len, METH_FASTCALL | METH_KEYWORDS}, - {NULL, NULL, 0, NULL} /* Sentinel */ + {"bind", (PyCFunction)(void*)DimList_bind, METH_FASTCALL | METH_KEYWORDS}, + {"bind_len", + (PyCFunction)(void*)DimList_bind_len, + METH_FASTCALL | METH_KEYWORDS}, + {NULL, NULL, 0, NULL} /* Sentinel */ }; - static Py_ssize_t DimList_len(DimList* self) { - PY_BEGIN - return self->size(); - PY_END(-1) -} - -static PyObject * DimList_item(DimList* self, Py_ssize_t idx) { - PY_BEGIN - if (!self->is_bound()) { - mpy::raise_error(DimensionBindError(), "DimList not bound"); - } - if (idx < 0 || (size_t) idx >= self->dims_.size()) { - mpy::raise_error(PyExc_IndexError, "index out of bounds"); - } - mpy::object r = self->dims_[idx]; - return r.release(); - PY_END(nullptr) -} - -PySequenceMethods DimList_seq { - (lenfunc) DimList_len, //lenfunc sq_length; - 0, //binaryfunc sq_concat; - 0, //ssizeargfunc sq_repeat; - (ssizeargfunc) DimList_item, //ssizeargfunc sq_item; - 0, //void *was_sq_slice; - 0, //ssizeobjargproc sq_ass_item; - 0, //void *was_sq_ass_slice; - 0, //objobjproc sq_contains; - - 0, //binaryfunc sq_inplace_concat; - 0, //ssizeargfunc sq_inplace_repeat; + PY_BEGIN + return self->size(); + PY_END(-1) +} + +static PyObject* DimList_item(DimList* self, Py_ssize_t idx) { + PY_BEGIN + if (!self->is_bound()) { + mpy::raise_error(DimensionBindError(), "DimList not bound"); + } + if (idx < 0 || (size_t)idx >= self->dims_.size()) { + mpy::raise_error(PyExc_IndexError, "index out of bounds"); + } + mpy::object r = self->dims_[idx]; + return r.release(); + PY_END(nullptr) +} + +PySequenceMethods DimList_seq{ + (lenfunc)DimList_len, // lenfunc sq_length; + 0, // binaryfunc sq_concat; + 0, // ssizeargfunc sq_repeat; + (ssizeargfunc)DimList_item, // ssizeargfunc sq_item; + 0, // void *was_sq_slice; + 0, // ssizeobjargproc sq_ass_item; + 0, // void *was_sq_ass_slice; + 0, // objobjproc sq_contains; + + 0, // binaryfunc sq_inplace_concat; + 0, // ssizeargfunc sq_inplace_repeat; }; static PyObject* DimList_getis_bound(DimList* self, void*) { - return PyBool_FromLong(self->is_bound()); + return PyBool_FromLong(self->is_bound()); } static PyGetSetDef DimList_getsetters[] = { - {"is_bound", (getter) DimList_getis_bound, NULL, "is_bound", NULL}, - {NULL} /* Sentinel */ + {"is_bound", (getter)DimList_getis_bound, NULL, "is_bound", NULL}, + {NULL} /* Sentinel */ }; - static PyObject* DimList_subscript(DimList* self, mpy::handle idx) { - PY_BEGIN - if (mpy::is_int(idx)) { - return DimList_item(self, mpy::to_int(idx)); - } else if (mpy::is_slice(idx)) { - if (!self->is_bound()) { - mpy::raise_error(DimensionBindError(), "DimList not bound"); - } - mpy::slice_view s(idx, self->dims_.size()); - mpy::tuple r(s.slicelength); - for (Py_ssize_t i = s.start, j = 0; i < s.stop; i += s.step) { - r.set(j++, self->dims_[i]); - } - return r.release(); - } else { - mpy::raise_error(PyExc_ValueError, "expected an int or a slice"); - return nullptr; + PY_BEGIN + if (mpy::is_int(idx)) { + return DimList_item(self, mpy::to_int(idx)); + } else if (mpy::is_slice(idx)) { + if (!self->is_bound()) { + mpy::raise_error(DimensionBindError(), "DimList not bound"); + } + mpy::slice_view s(idx, self->dims_.size()); + mpy::tuple r(s.slicelength); + for (Py_ssize_t i = s.start, j = 0; i < s.stop; i += s.step) { + r.set(j++, self->dims_[i]); } - PY_END(nullptr) + return r.release(); + } else { + mpy::raise_error(PyExc_ValueError, "expected an int or a slice"); + return nullptr; + } + PY_END(nullptr) } PyMappingMethods DimList_mapping = { - 0, //lenfunc mp_length; - (binaryfunc)(void*) DimList_subscript, //binaryfunc mp_subscript; - 0, //objobjargproc mp_ass_subscript; + 0, // lenfunc mp_length; + (binaryfunc)(void*)DimList_subscript, // binaryfunc mp_subscript; + 0, // objobjargproc mp_ass_subscript; }; - - PyTypeObject DimList::Type = { PyVarObject_HEAD_INIT(NULL, 0) - "_C.DimList", /* tp_name */ - sizeof(DimList), /* tp_basicsize */ - 0, /* tp_itemsize */ - DimList::dealloc_stub, /* tp_dealloc */ - 0, /* tp_vectorcall_offset */ - 0, /* tp_getattr */ - 0, /* tp_setattr */ - 0, /* tp_as_async */ - (reprfunc)DimList_repr, /* tp_repr */ - 0, /* tp_as_number */ - &DimList_seq, /* tp_as_sequence */ - &DimList_mapping, /* tp_as_mapping */ - 0, /* tp_hash */ - 0, /* tp_call */ - 0, /* tp_str */ - 0, /* tp_getattro */ - 0, /* tp_setattro */ - 0, /* tp_as_buffer */ - 0, /* tp_flags */ - "DimList Object", /* tp_doc */ - 0, /* tp_traverse */ - 0, /* tp_clear */ - 0, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - 0, /* tp_iter */ - 0, /* tp_iternext */ - DimList_methods, /* tp_methods */ - 0, /* tp_members */ - DimList_getsetters, /* tp_getset */ - 0, /* tp_base */ - 0, /* tp_dict */ - 0, /* tp_descr_get */ - 0, /* tp_descr_set */ - 0, /* tp_dictoffset */ - (initproc) DimList_init, /* tp_init */ - 0, /* tp_alloc */ - DimList::new_stub, /* tp_new */ + "_C.DimList", /* tp_name */ + sizeof(DimList), /* tp_basicsize */ + 0, /* tp_itemsize */ + DimList::dealloc_stub, /* tp_dealloc */ + 0, /* tp_vectorcall_offset */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_as_async */ + (reprfunc)DimList_repr, /* tp_repr */ + 0, /* tp_as_number */ + &DimList_seq, /* tp_as_sequence */ + &DimList_mapping, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + 0, /* tp_flags */ + "DimList Object", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + DimList_methods, /* tp_methods */ + 0, /* tp_members */ + DimList_getsetters, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)DimList_init, /* tp_init */ + 0, /* tp_alloc */ + DimList::new_stub, /* tp_new */ }; -static int DimList_init(DimList *self, PyObject *args, PyObject *kwds) { - PY_BEGIN - static constexpr const char* kwlist[] = {"len_or_dims", "name", nullptr}; - mpy::handle len_or_dims = nullptr; - PyObject* name = nullptr; - if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OO", const_cast(kwlist), &len_or_dims, &name)) { - return -1; - } - self->init(mpy::object::borrow(name ? name : Py_None)); - if (len_or_dims.ptr()) { - if(mpy::is_int(len_or_dims)) { - self->bind_len(mpy::to_int(len_or_dims)); - } else if (mpy::is_sequence(len_or_dims)) { - mpy::sequence_view s(len_or_dims); - std::vector> dims; - size_t size = s.size(); - dims.reserve(size); - for (size_t i = 0; i < size; ++i) { - auto r = s[i]; - if (mpy::is_int(r)) { - dims.emplace_back(Dim::create(mpy::unicode_from_format("%S%i", self->name_.ptr(), (int)i), mpy::to_int(r))); - } else { - dims.emplace_back(Dim::wrap(r)); - } - } - self->set_dims(std::move(dims)); +static int DimList_init(DimList* self, PyObject* args, PyObject* kwds) { + PY_BEGIN + static constexpr const char* kwlist[] = {"len_or_dims", "name", nullptr}; + mpy::handle len_or_dims = nullptr; + PyObject* name = nullptr; + if (!PyArg_ParseTupleAndKeywords( + args, kwds, "|OO", const_cast(kwlist), &len_or_dims, &name)) { + return -1; + } + self->init(mpy::object::borrow(name ? name : Py_None)); + if (len_or_dims.ptr()) { + if (mpy::is_int(len_or_dims)) { + self->bind_len(mpy::to_int(len_or_dims)); + } else if (mpy::is_sequence(len_or_dims)) { + mpy::sequence_view s(len_or_dims); + std::vector> dims; + size_t size = s.size(); + dims.reserve(size); + for (size_t i = 0; i < size; ++i) { + auto r = s[i]; + if (mpy::is_int(r)) { + dims.emplace_back(Dim::create( + mpy::unicode_from_format("%S%i", self->name_.ptr(), (int)i), + mpy::to_int(r))); } else { - PyErr_Format(PyExc_ValueError, "expected a length or a sequence of dimensions"); - return -1; + dims.emplace_back(Dim::wrap(r)); } - return 0; + } + self->set_dims(std::move(dims)); + } else { + PyErr_Format( + PyExc_ValueError, "expected a length or a sequence of dimensions"); + return -1; } return 0; - PY_END(-1); + } + return 0; + PY_END(-1); } // Tensor ----------------------------- PyTypeObject* TensorType = nullptr; // the python wrapper type. -mpy::object run_torch_function(Arena &A, mpy::handle orig, mpy::vector_args args, bool is_pointwise); +mpy::object run_torch_function( + Arena& A, + mpy::handle orig, + mpy::vector_args args, + bool is_pointwise); -namespace{ +namespace { at::Tensor _add_batch_dims(Arena& A, at::Tensor t, Slice levels_) { - auto levels = Slice(); - levels.extend(A, levels_); - while (true) { - int64_t min_real_index = -1; - int64_t min_index = -1; - int64_t min_value = INT_MAX; - int64_t i = 0; - int64_t r = 0; - for (auto l : levels) { - if (!l.is_none()) { - if (!l.is_positional() && l.dim()->level_ < min_value) { - min_value = l.dim()->level_; - min_index = i; - min_real_index = r; - } - ++i; - } - ++r; - } - if (min_index == -1) { - return t; + auto levels = Slice(); + levels.extend(A, levels_); + while (true) { + int64_t min_real_index = -1; + int64_t min_index = -1; + int64_t min_value = INT_MAX; + int64_t i = 0; + int64_t r = 0; + for (auto l : levels) { + if (!l.is_none()) { + if (!l.is_positional() && l.dim()->level_ < min_value) { + min_value = l.dim()->level_; + min_index = i; + min_real_index = r; } - auto t2 = at::functorch::addBatchDim(std::move(t), min_index, min_value); - t = std::move(t2); - levels[min_real_index] = DimEntry(); + ++i; + } + ++r; + } + if (min_index == -1) { + return t; } + auto t2 = at::functorch::addBatchDim(std::move(t), min_index, min_value); + t = std::move(t2); + levels[min_real_index] = DimEntry(); + } } - - struct DelayedOperator { - DelayedOperator(mpy::object o, mpy::vector_args a) - : orig(std::move(o)), args(a) { - auto all = a.size(); - // this will outlive the call so - // take ownership of temporaries - // in vector args - auto buf = new mpy::handle[all]; - memcpy(buf, args.args, sizeof(mpy::handle)*all); - args.args = buf; - for (auto i : args.enumerate_all()) { - Py_INCREF(args.args[i].ptr()); - } - Py_XINCREF(args.kwnames.ptr()); - } - ~DelayedOperator() { - for (auto i : args.enumerate_all()) { - Py_DECREF(args[i].ptr()); - } - if (args.has_keywords()) { - Py_XDECREF(args.kwnames.ptr()); - } - delete [] args.args; - } - mpy::object orig; - mpy::vector_args args; + DelayedOperator(mpy::object o, mpy::vector_args a) + : orig(std::move(o)), args(a) { + auto all = a.size(); + // this will outlive the call so + // take ownership of temporaries + // in vector args + auto buf = new mpy::handle[all]; + memcpy(buf, args.args, sizeof(mpy::handle) * all); + args.args = buf; + for (auto i : args.enumerate_all()) { + Py_INCREF(args.args[i].ptr()); + } + Py_XINCREF(args.kwnames.ptr()); + } + ~DelayedOperator() { + for (auto i : args.enumerate_all()) { + Py_DECREF(args[i].ptr()); + } + if (args.has_keywords()) { + Py_XDECREF(args.kwnames.ptr()); + } + delete[] args.args; + } + mpy::object orig; + mpy::vector_args args; }; void free_levels_dims(Slice levels) { - for(auto e : levels) { - if (!e.is_positional()) { - mpy::object::steal(e.dim()); - } + for (auto e : levels) { + if (!e.is_positional()) { + mpy::object::steal(e.dim()); } + } } -} +} // namespace struct Tensor : public mpy::base { -private: - at::Tensor tensor_; - at::Tensor batchtensor_; - OwnedSlice levels_; - bool has_device_; - std::unique_ptr delayed_; -public: - - at::Tensor& tensor(Arena& A) { - if (C10_UNLIKELY(!tensor_.defined())) { - AT_ASSERT(delayed_); - auto t = Tensor::wrap(run_torch_function(A, delayed_->orig, delayed_->args, true)); - tensor_ = t->tensor(A); - delayed_.reset(); - // don't force creation of batch tensor if it wasn't already provided. - batchtensor_ = t->batchtensor_; - AT_ASSERT(levels() == t->levels()); - } - return tensor_; - } - at::Tensor& batchtensor(Arena& A) { - if (C10_UNLIKELY(!batchtensor_.defined())) { - batchtensor_ = _add_batch_dims(A, tensor(A), levels_.slice()); - } - return batchtensor_; - } - Slice levels() { - return levels_.slice(); - } - bool has_device() { - return has_device_; - } - DelayedOperator* delayed() { - return delayed_.get(); - } - static PyTypeObject Type; - - static bool check_exact(mpy::handle v) { - return Py_TYPE(v.ptr()) == TensorType; - } - - - static mpy::obj create() { - if (!TensorType) { - TensorType = (PyTypeObject*) mpy::import("functorch.dim").attr("Tensor").release(); - } - return Tensor::alloc(TensorType); - } - void capture_levels(Slice levels) { - // grab ownership of the dims inside levels - for (auto l : levels) { - if (!l.is_positional()) { - mpy::object::borrow(l.dim()).release(); - } - } - levels_.set(levels, free_levels_dims); - } - static mpy::object from_positional(Arena & A, at::Tensor tensor, Slice levels, bool has_device); - static mpy::obj create_delayed(mpy::object op, mpy::vector_args args, Slice levels, bool has_device); - friend struct EnableAllLayers; + private: + at::Tensor tensor_; + at::Tensor batchtensor_; + OwnedSlice levels_; + bool has_device_; + std::unique_ptr delayed_; + + public: + at::Tensor& tensor(Arena& A) { + if (C10_UNLIKELY(!tensor_.defined())) { + AT_ASSERT(delayed_); + auto t = Tensor::wrap( + run_torch_function(A, delayed_->orig, delayed_->args, true)); + tensor_ = t->tensor(A); + delayed_.reset(); + // don't force creation of batch tensor if it wasn't already provided. + batchtensor_ = t->batchtensor_; + AT_ASSERT(levels() == t->levels()); + } + return tensor_; + } + at::Tensor& batchtensor(Arena& A) { + if (C10_UNLIKELY(!batchtensor_.defined())) { + batchtensor_ = _add_batch_dims(A, tensor(A), levels_.slice()); + } + return batchtensor_; + } + Slice levels() { + return levels_.slice(); + } + bool has_device() { + return has_device_; + } + DelayedOperator* delayed() { + return delayed_.get(); + } + static PyTypeObject Type; + + static bool check_exact(mpy::handle v) { + return Py_TYPE(v.ptr()) == TensorType; + } + + static mpy::obj create() { + if (!TensorType) { + TensorType = + (PyTypeObject*)mpy::import("functorch.dim").attr("Tensor").release(); + } + return Tensor::alloc(TensorType); + } + void capture_levels(Slice levels) { + // grab ownership of the dims inside levels + for (auto l : levels) { + if (!l.is_positional()) { + mpy::object::borrow(l.dim()).release(); + } + } + levels_.set(levels, free_levels_dims); + } + static mpy::object from_positional( + Arena& A, + at::Tensor tensor, + Slice levels, + bool has_device); + static mpy::obj create_delayed( + mpy::object op, + mpy::vector_args args, + Slice levels, + bool has_device); + friend struct EnableAllLayers; }; -namespace{ +namespace { // version in header does a unnecessary refcount +/- -at::functorch::BatchedTensorImpl* maybeGetBatchedImpl(const at::Tensor& tensor) { - if (at::functorch::isBatchedTensor(tensor)) { - return static_cast(tensor.unsafeGetTensorImpl()); - } - return nullptr; +at::functorch::BatchedTensorImpl* maybeGetBatchedImpl( + const at::Tensor& tensor) { + if (at::functorch::isBatchedTensor(tensor)) { + return static_cast( + tensor.unsafeGetTensorImpl()); + } + return nullptr; } TensorRef unchecked_tensor_from(mpy::handle p) { - auto v = (THPVariable*) p.ptr(); - return TensorRef(*v->cdata); + auto v = (THPVariable*)p.ptr(); + return TensorRef(*v->cdata); } static int64_t ndim_of_levels(Slice levels) { - int64_t r = 0; - for (auto l : levels) { - if (l.is_positional()) { - ++r; - } + int64_t r = 0; + for (auto l : levels) { + if (l.is_positional()) { + ++r; } - return r; + } + return r; } struct TensorInfo { - TensorRef tensor; - Slice levels; - bool has_device; - TensorRef batchedtensor; - int64_t ndim() const { - return ndim_of_levels(levels); - } - operator bool() const { - return tensor; - } - - static TensorInfo create(Arena& A, mpy::handle h, bool ensure_batched=true, bool ensure_present=true) { - if (Tensor::check_exact(h)) { - auto t = Tensor::unchecked_wrap(h); - return TensorInfo {t->tensor(A), t->levels(), t->has_device(), ensure_batched ? t->batchtensor(A) : TensorRef()}; - } else if (Dim::check_exact(h)) { - auto d = Dim::unchecked_wrap(h); - return TensorInfo {d->range(), Slice(A, DimEntry(d)), false, ensure_batched ? d->batchtensor() : TensorRef()}; - } else if (THPVariable_Check(h.ptr())) { - TensorRef t = unchecked_tensor_from(h); - Slice levels; - for (auto i : irange(-t->dim(), 0)) { - levels.append(A, i); - } - return TensorInfo {t, levels, true, t}; - } else { - if (ensure_present) { - mpy::raise_error(PyExc_ValueError, "expected a tensor object"); - } - return TensorInfo {}; - } + TensorRef tensor; + Slice levels; + bool has_device; + TensorRef batchedtensor; + int64_t ndim() const { + return ndim_of_levels(levels); + } + operator bool() const { + return tensor; + } + + static TensorInfo create( + Arena& A, + mpy::handle h, + bool ensure_batched = true, + bool ensure_present = true) { + if (Tensor::check_exact(h)) { + auto t = Tensor::unchecked_wrap(h); + return TensorInfo{ + t->tensor(A), + t->levels(), + t->has_device(), + ensure_batched ? t->batchtensor(A) : TensorRef()}; + } else if (Dim::check_exact(h)) { + auto d = Dim::unchecked_wrap(h); + return TensorInfo{ + d->range(), + Slice(A, DimEntry(d)), + false, + ensure_batched ? d->batchtensor() : TensorRef()}; + } else if (THPVariable_Check(h.ptr())) { + TensorRef t = unchecked_tensor_from(h); + Slice levels; + for (auto i : irange(-t->dim(), 0)) { + levels.append(A, i); + } + return TensorInfo{t, levels, true, t}; + } else { + if (ensure_present) { + mpy::raise_error(PyExc_ValueError, "expected a tensor object"); + } + return TensorInfo{}; } - - + } }; -static PyObject* py_Tensor_from_positional(PyObject *self, - PyObject *const *args, - Py_ssize_t nargs, - PyObject *kwnames) { - Arena A; - PY_BEGIN - #define ARGS(_) _(mpy::handle, tensor) _(mpy::handle, py_levels) _(int, has_device) - MPY_PARSE_ARGS_KWNAMES("OOp", ARGS) - #undef ARGS - - if (!THPVariable_Check(tensor.ptr())) { - mpy::raise_error(PyExc_ValueError, "_tensor is not a Tensor?"); - } - - Slice levels; - mpy::sequence_view sq(py_levels); - for (auto i : sq.enumerate()) { - mpy::object v = sq[i]; - if (mpy::is_int(v)) { - auto vi = mpy::to_int(v); - levels.append(A, vi); - } else { - auto dim = Dim::wrap(std::move(v)); - mpy::hdl hdim = dim; - levels.append(A, hdim); - } - } - return Tensor::from_positional(A, THPVariable_Unpack(tensor.ptr()), levels, has_device != 0).release(); - PY_END(nullptr) -} -} - -mpy::object Tensor::from_positional(Arena & A, at::Tensor tensor, Slice levels, bool has_device) { - size_t seen_dims = 0; - int last = 0; - //auto sz = tensor.sizes(); - for (auto i : levels.enumerate()) { - auto l = levels[i]; - if (l.is_positional()) { - AT_ASSERT(last == 0 || last + 1 == l.position()); - last = l.position(); - } else { - mpy::object::borrow(l.dim()).release(); - //AT_ASSERT(sz[i] == l.dim()->size()); - ++seen_dims; - } - } - AT_ASSERT(last == 0 || last == -1); - if (!seen_dims) { - return mpy::object::steal(THPVariable_Wrap(tensor)); - } - - mpy::obj self = Tensor::create(); - self->tensor_ = std::move(tensor); - AT_ASSERT(self->tensor_.dim() == levels.size()); - self->levels_.set(levels, free_levels_dims); - self->has_device_ = has_device; - mpy::object r = std::move(self); - return r; -} - - -mpy::obj Tensor::create_delayed(mpy::object op, mpy::vector_args args, Slice levels, bool has_device) { - mpy::obj self = Tensor::create(); - self->capture_levels(levels); - self->has_device_ = has_device; - self->delayed_ = std::make_unique(std::move(op), args); - return self; -} - -namespace{ +static PyObject* py_Tensor_from_positional( + PyObject* self, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + Arena A; + PY_BEGIN +#define ARGS(_) \ + _(mpy::handle, tensor) _(mpy::handle, py_levels) _(int, has_device) + MPY_PARSE_ARGS_KWNAMES("OOp", ARGS) +#undef ARGS + + if (!THPVariable_Check(tensor.ptr())) { + mpy::raise_error(PyExc_ValueError, "_tensor is not a Tensor?"); + } + + Slice levels; + mpy::sequence_view sq(py_levels); + for (auto i : sq.enumerate()) { + mpy::object v = sq[i]; + if (mpy::is_int(v)) { + auto vi = mpy::to_int(v); + levels.append(A, vi); + } else { + auto dim = Dim::wrap(std::move(v)); + mpy::hdl hdim = dim; + levels.append(A, hdim); + } + } + return Tensor::from_positional( + A, THPVariable_Unpack(tensor.ptr()), levels, has_device != 0) + .release(); + PY_END(nullptr) +} +} // namespace + +mpy::object Tensor::from_positional( + Arena& A, + at::Tensor tensor, + Slice levels, + bool has_device) { + size_t seen_dims = 0; + int last = 0; + // auto sz = tensor.sizes(); + for (auto i : levels.enumerate()) { + auto l = levels[i]; + if (l.is_positional()) { + AT_ASSERT(last == 0 || last + 1 == l.position()); + last = l.position(); + } else { + mpy::object::borrow(l.dim()).release(); + // AT_ASSERT(sz[i] == l.dim()->size()); + ++seen_dims; + } + } + AT_ASSERT(last == 0 || last == -1); + if (!seen_dims) { + return mpy::object::steal(THPVariable_Wrap(tensor)); + } + + mpy::obj self = Tensor::create(); + self->tensor_ = std::move(tensor); + AT_ASSERT(self->tensor_.dim() == levels.size()); + self->levels_.set(levels, free_levels_dims); + self->has_device_ = has_device; + mpy::object r = std::move(self); + return r; +} + +mpy::obj Tensor::create_delayed( + mpy::object op, + mpy::vector_args args, + Slice levels, + bool has_device) { + mpy::obj self = Tensor::create(); + self->capture_levels(levels); + self->has_device_ = has_device; + self->delayed_ = std::make_unique(std::move(op), args); + return self; +} + +namespace { mpy::list slice_to_list(Slice h) { - mpy::list lst(h.size()); - for (auto i : h.enumerate()) { - lst.set(i, mpy::object::borrow(h[i])); - } - return lst; + mpy::list lst(h.size()); + for (auto i : h.enumerate()) { + lst.set(i, mpy::object::borrow(h[i])); + } + return lst; } mpy::tuple slice_to_tuple(Slice h) { - mpy::tuple lst(h.size()); - for (auto i : h.enumerate()) { - lst.set(i, mpy::object::borrow(h[i])); - } - return lst; + mpy::tuple lst(h.size()); + for (auto i : h.enumerate()) { + lst.set(i, mpy::object::borrow(h[i])); + } + return lst; } enum UType { - U_ELEM, - U_TUPLE_LIKE, - U_DICT, + U_ELEM, + U_TUPLE_LIKE, + U_DICT, }; struct Unflatten { - mpy::object operator()(Slice& elements) { - mpy::object r; - switch (type) { - case U_ELEM: { - r = mpy::object::borrow(elements[0]); - elements = elements.slice(1); - } break; - case U_TUPLE_LIKE: { - mpy::tuple tup(children.size()); - for (auto i : children.enumerate()) { - tup.set(i, children[i](elements)); - } - r = obj.call(tup); - } break; - case U_DICT: { - r = mpy::object::checked_steal(PyDict_New()); - mpy::dict_view rv(r); - mpy::dict_view d(obj); - Py_ssize_t pos = 0; - mpy::handle k, v; - for (int i = 0; d.next(&pos, &k, &v); ++i) { - rv.set(k, children[i](elements)); - } - } break; + mpy::object operator()(Slice& elements) { + mpy::object r; + switch (type) { + case U_ELEM: { + r = mpy::object::borrow(elements[0]); + elements = elements.slice(1); + } break; + case U_TUPLE_LIKE: { + mpy::tuple tup(children.size()); + for (auto i : children.enumerate()) { + tup.set(i, children[i](elements)); + } + r = obj.call(tup); + } break; + case U_DICT: { + r = mpy::object::checked_steal(PyDict_New()); + mpy::dict_view rv(r); + mpy::dict_view d(obj); + Py_ssize_t pos = 0; + mpy::handle k, v; + for (int i = 0; d.next(&pos, &k, &v); ++i) { + rv.set(k, children[i](elements)); } - return r; + } break; } - UType type; - mpy::handle obj; - Slice children; + return r; + } + UType type; + mpy::handle obj; + Slice children; }; -Unflatten tree_flatten(Arena& A, mpy::handle agg, Slice& flat_elements) { - Slice c; - UType utype; - mpy::handle obj; - if (mpy::list_view::check(agg)) { - obj = agg.type(); - utype = U_TUPLE_LIKE; - mpy::list_view l(agg); - for (auto i : l.enumerate()) { - c.append(A, tree_flatten(A, l[i], flat_elements)); - } - } else if (mpy::tuple_view::check(agg)) { - obj = agg.type(); - utype = U_TUPLE_LIKE; - // includes named tuples - mpy::tuple_view l(agg); - for (auto i : l.enumerate()) { - c.append(A, tree_flatten(A, l[i], flat_elements)); - } - } else if (mpy::dict_view::check(agg)) { - utype = U_DICT; - mpy::dict_view d(agg); - obj = agg; - Py_ssize_t pos = 0; - mpy::handle k, v; - while (d.next(&pos, &k, &v)) { - c.append(A, tree_flatten(A, v, flat_elements)); - } - } else { - utype = U_ELEM; - flat_elements.append(A, agg); +Unflatten tree_flatten( + Arena& A, + mpy::handle agg, + Slice& flat_elements) { + Slice c; + UType utype; + mpy::handle obj; + if (mpy::list_view::check(agg)) { + obj = agg.type(); + utype = U_TUPLE_LIKE; + mpy::list_view l(agg); + for (auto i : l.enumerate()) { + c.append(A, tree_flatten(A, l[i], flat_elements)); + } + } else if (mpy::tuple_view::check(agg)) { + obj = agg.type(); + utype = U_TUPLE_LIKE; + // includes named tuples + mpy::tuple_view l(agg); + for (auto i : l.enumerate()) { + c.append(A, tree_flatten(A, l[i], flat_elements)); + } + } else if (mpy::dict_view::check(agg)) { + utype = U_DICT; + mpy::dict_view d(agg); + obj = agg; + Py_ssize_t pos = 0; + mpy::handle k, v; + while (d.next(&pos, &k, &v)) { + c.append(A, tree_flatten(A, v, flat_elements)); } - return Unflatten {utype, obj, c}; + } else { + utype = U_ELEM; + flat_elements.append(A, agg); + } + return Unflatten{utype, obj, c}; } struct UnflattenVectorArgs { - mpy::vector_args operator()(Arena& A, Slice& elements) { - if (!had_nested) { - auto args = elements.begin(); - elements = Slice(); - return mpy::vector_args(args, nargs, kwnames); - } - Slice args; - for (auto u : children) { - args.append(A, A.autorelease(u(elements))); - } - return mpy::vector_args(args.begin(), nargs, kwnames); - } - Slice children; - Py_ssize_t nargs; - mpy::handle kwnames; - bool had_nested; + mpy::vector_args operator()(Arena& A, Slice& elements) { + if (!had_nested) { + auto args = elements.begin(); + elements = Slice(); + return mpy::vector_args(args, nargs, kwnames); + } + Slice args; + for (auto u : children) { + args.append(A, A.autorelease(u(elements))); + } + return mpy::vector_args(args.begin(), nargs, kwnames); + } + Slice children; + Py_ssize_t nargs; + mpy::handle kwnames; + bool had_nested; }; -UnflattenVectorArgs tree_flatten(Arena& A, mpy::vector_args args, Slice& flat_elements) { - UnflattenVectorArgs r; - r.kwnames = args.kwnames; - r.nargs = args.nargs; - r.had_nested = false; - auto N = args.size(); - for(auto i : irange(N)) { - auto typ = Py_TYPE(args[i].ptr()); - // fast checks that this thing isn't something that is nested. - bool is_element = !typ->tp_as_sequence || typ == torch_Tensor || typ == TensorType || typ == DimType; - if (!is_element) { - flat_elements.extend(A, args.args, args.args + i); - for (auto j : irange(i)) { - (void)j; - r.children.append(A, Unflatten {U_ELEM}); - } - for (auto j : irange(i, N)) { - r.children.append(A, tree_flatten(A, args[j], flat_elements)); - if (r.children.back().type != U_ELEM) { - r.had_nested = true; - } - } - return r; - } - } - flat_elements.extend(A, args.args, args.args + N); - return r; +UnflattenVectorArgs tree_flatten( + Arena& A, + mpy::vector_args args, + Slice& flat_elements) { + UnflattenVectorArgs r; + r.kwnames = args.kwnames; + r.nargs = args.nargs; + r.had_nested = false; + auto N = args.size(); + for (auto i : irange(N)) { + auto typ = Py_TYPE(args[i].ptr()); + // fast checks that this thing isn't something that is nested. + bool is_element = !typ->tp_as_sequence || typ == torch_Tensor || + typ == TensorType || typ == DimType; + if (!is_element) { + flat_elements.extend(A, args.args, args.args + i); + for (auto j : irange(i)) { + (void)j; + r.children.append(A, Unflatten{U_ELEM}); + } + for (auto j : irange(i, N)) { + r.children.append(A, tree_flatten(A, args[j], flat_elements)); + if (r.children.back().type != U_ELEM) { + r.had_nested = true; + } + } + return r; + } + } + flat_elements.extend(A, args.args, args.args + N); + return r; } - struct UnflattenArena { - Arena A; - Unflatten unflatten; + Arena A; + Unflatten unflatten; }; -PyObject* py_unflatten(PyObject *self, - PyObject *const *args, - Py_ssize_t nargs, - PyObject *kwnames) { - PY_BEGIN - #define ARGS(_) _(mpy::handle, ns) - MPY_PARSE_ARGS_KWNAMES("O", ARGS) - #undef ARGS - mpy::sequence_view sv(ns); - // because we do not have a autorelase pool yet... - Arena A; - Slice slice; - mpy::handle Tuple = (PyObject*) &PyTuple_Type; - auto inputs = Tuple.call(ns); - mpy::tuple_view tv(inputs); - for (auto i : tv.enumerate()) { - slice.append(A, tv[i]); - } - auto AA = (UnflattenArena*) PyCapsule_GetPointer(self, "arena"); - auto r = AA->unflatten(slice).release(); - AT_ASSERT(r != nullptr); - return r; - PY_END(nullptr) -} - -PyMethodDef py_unflatten_def = {"unflatten", (PyCFunction)(void*) py_unflatten, METH_FASTCALL | METH_KEYWORDS}; - -void free_unflatten_arena(PyObject * pc) { - delete (UnflattenArena*) PyCapsule_GetPointer(pc, "arena"); -} - -PyObject* py_tree_flatten(PyObject *self, - PyObject *const *args, - Py_ssize_t nargs, - PyObject *kwnames) { - PY_BEGIN - #define ARGS(_) _(mpy::handle, tree) - MPY_PARSE_ARGS_KWNAMES("O", ARGS) - #undef ARGS - auto A = new UnflattenArena; - Slice elements; - A->unflatten = tree_flatten(A->A, tree, elements); - auto cap = mpy::object::checked_steal(PyCapsule_New(A, "arena", free_unflatten_arena)); - auto unflatten = mpy::object::checked_steal(PyCFunction_New(&py_unflatten_def, cap.release())); - mpy::tuple r(2); - r.set(0, slice_to_list(elements)); - r.set(1, std::move(unflatten)); - return r.release(); - PY_END(nullptr) -} - - - -mpy::object tree_map(Arena& A, const std::function& fn, mpy::handle agg) { - Slice elements; - auto unflatten = tree_flatten(A, agg, elements); - for (auto i : elements.enumerate()) { - elements[i] = fn(elements[i]); - } - return unflatten(elements); +PyObject* py_unflatten( + PyObject* self, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + PY_BEGIN +#define ARGS(_) _(mpy::handle, ns) + MPY_PARSE_ARGS_KWNAMES("O", ARGS) +#undef ARGS + mpy::sequence_view sv(ns); + // because we do not have a autorelase pool yet... + Arena A; + Slice slice; + mpy::handle Tuple = (PyObject*)&PyTuple_Type; + auto inputs = Tuple.call(ns); + mpy::tuple_view tv(inputs); + for (auto i : tv.enumerate()) { + slice.append(A, tv[i]); + } + auto AA = (UnflattenArena*)PyCapsule_GetPointer(self, "arena"); + auto r = AA->unflatten(slice).release(); + AT_ASSERT(r != nullptr); + return r; + PY_END(nullptr) +} + +PyMethodDef py_unflatten_def = { + "unflatten", + (PyCFunction)(void*)py_unflatten, + METH_FASTCALL | METH_KEYWORDS}; + +void free_unflatten_arena(PyObject* pc) { + delete (UnflattenArena*)PyCapsule_GetPointer(pc, "arena"); +} + +PyObject* py_tree_flatten( + PyObject* self, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + PY_BEGIN +#define ARGS(_) _(mpy::handle, tree) + MPY_PARSE_ARGS_KWNAMES("O", ARGS) +#undef ARGS + auto A = new UnflattenArena; + Slice elements; + A->unflatten = tree_flatten(A->A, tree, elements); + auto cap = mpy::object::checked_steal( + PyCapsule_New(A, "arena", free_unflatten_arena)); + auto unflatten = mpy::object::checked_steal( + PyCFunction_New(&py_unflatten_def, cap.release())); + mpy::tuple r(2); + r.set(0, slice_to_list(elements)); + r.set(1, std::move(unflatten)); + return r.release(); + PY_END(nullptr) +} + +mpy::object tree_map( + Arena& A, + const std::function& fn, + mpy::handle agg) { + Slice elements; + auto unflatten = tree_flatten(A, agg, elements); + for (auto i : elements.enumerate()) { + elements[i] = fn(elements[i]); + } + return unflatten(elements); } // prereq: isinstance(h, _Tensor) int64_t _Tensor_ndim(mpy::handle h) { - if (Tensor::check(h)) { - int64_t r = 0; - for (auto l : Tensor::unchecked_wrap(h)->levels()) { - if (l.is_positional()) { - ++r; - } - } - return r; + if (Tensor::check(h)) { + int64_t r = 0; + for (auto l : Tensor::unchecked_wrap(h)->levels()) { + if (l.is_positional()) { + ++r; + } } - // Dim or DelayedMulTensor - return 0; + return r; + } + // Dim or DelayedMulTensor + return 0; } mpy::handle handle_from_tensor(Arena& A, TensorRef t) { - // fast case: tensor is live in python - std::optional mb_obj = - t->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(getPyInterpreter(), /*ignore_hermetic_tls=*/false); - if (mb_obj.has_value() && !t->unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()) { - return *mb_obj; - } - return A.autorelease(mpy::object::checked_steal(THPVariable_Wrap(*t))); -} -} + // fast case: tensor is live in python + std::optional mb_obj = + t->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj( + /*ignore_hermetic_tls=*/false); + if (mb_obj.has_value() && + !t->unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()) { + return *mb_obj; + } + return A.autorelease(mpy::object::checked_steal(THPVariable_Wrap(*t))); +} +} // namespace struct EnableAllLayers { - EnableAllLayers(Arena& A, Slice levels) { - std::vector> layers; - layers.reserve(levels.size()); - for (auto l : levels) { - if (!l.is_positional()) { - auto d = l.dim(); - levels_to_dim_.append(A, d); - } - } - std::sort(levels_to_dim_.begin(), levels_to_dim_.end(), [](mpy::hdl lhs, mpy::hdl rhs) { return lhs->level_ < rhs->level_;}); - - for (auto i : levels_to_dim_.enumerate()) { - auto batch_size = levels_to_dim_[i]->size(); - auto level = at::functorch::initAndPushDynamicLayer(at::functorch::TransformType::Vmap, batch_size, at::functorch::RandomnessType::Different); - if (i == 0) { - levels_start_ = level; - } - } + EnableAllLayers(Arena& A, Slice levels) { + std::vector> layers; + layers.reserve(levels.size()); + for (auto l : levels) { + if (!l.is_positional()) { + auto d = l.dim(); + levels_to_dim_.append(A, d); + } + } + std::sort( + levels_to_dim_.begin(), + levels_to_dim_.end(), + [](mpy::hdl lhs, mpy::hdl rhs) { + return lhs->level_ < rhs->level_; + }); + + for (auto i : levels_to_dim_.enumerate()) { + auto batch_size = levels_to_dim_[i]->size(); + auto level = at::functorch::initAndPushDynamicLayer( + at::functorch::TransformType::Vmap, + batch_size, + at::functorch::RandomnessType::Different); + if (i == 0) { + levels_start_ = level; + } + } + } + + ~EnableAllLayers() { + auto to_remove = levels_start_ + levels_to_dim_.size() - 1; + for (auto i : levels_to_dim_.enumerate()) { + AT_ASSERT( + at::functorch::popDynamicLayerAndDeleteMetadata().layerId() == + to_remove - i); + } + } + + mpy::obj from_batched( + Arena& A, + at::Tensor batchedtensor, + bool has_device) { + Slice levels; + for (auto i : irange(-batchedtensor.dim(), 0)) { + levels.append(A, i); } - - ~EnableAllLayers() { - auto to_remove = levels_start_ + levels_to_dim_.size() - 1; - for (auto i : levels_to_dim_.enumerate()) { - AT_ASSERT(at::functorch::popDynamicLayerAndDeleteMetadata().layerId() == to_remove - i); - } + TensorRef tensor; + at::functorch::BatchedTensorImpl* impl = maybeGetBatchedImpl(batchedtensor); + while (true) { + auto level = impl->level(); + AT_ASSERT( + level >= levels_start_ && + level < levels_start_ + levels_to_dim_.size()); + mpy::hdl dim = levels_to_dim_[level - levels_start_].ptr(); + levels.insert(A, impl->bdim(), dim); + at::functorch::BatchedTensorImpl* nimpl = + maybeGetBatchedImpl(impl->value()); + if (!nimpl) { + tensor = impl->value(); + break; + } + impl = nimpl; } - mpy::obj from_batched(Arena& A, at::Tensor batchedtensor, bool has_device) { - Slice levels; - for (auto i : irange(-batchedtensor.dim(), 0)) { - levels.append(A, i); - } - TensorRef tensor; - at::functorch::BatchedTensorImpl * impl = maybeGetBatchedImpl(batchedtensor); - while(true) { - auto level = impl->level(); - AT_ASSERT(level >= levels_start_ && level < levels_start_ + levels_to_dim_.size()); - mpy::hdl dim = levels_to_dim_[level - levels_start_].ptr(); - levels.insert(A, impl->bdim(), dim); - at::functorch::BatchedTensorImpl * nimpl = maybeGetBatchedImpl(impl->value()); - if (!nimpl) { - tensor = impl->value(); - break; - } - impl = nimpl; - } - - mpy::obj self = Tensor::create(); - // grab ownership of the tensors - self->tensor_ = *tensor; - self->batchtensor_ = std::move(batchedtensor); - self->has_device_ = has_device; - self->capture_levels(levels); - return self; - } - void inplace_update_layers(TensorRef batchtensor, Slice levels) { - // XXX - requires a patch to functorch to att set_level - auto impl = maybeGetBatchedImpl(*batchtensor); - for (auto i : levels_to_dim_.reversed_enumerate()) { - if (!impl) { - break; - } - if (levels.contains(levels_to_dim_[i])) { - impl->_unsafe_set_level(levels_start_ + i); - impl = maybeGetBatchedImpl(impl->value()); - - } - } - } -private: - int64_t levels_start_{}; - Slice> levels_to_dim_; + mpy::obj self = Tensor::create(); + // grab ownership of the tensors + self->tensor_ = *tensor; + self->batchtensor_ = std::move(batchedtensor); + self->has_device_ = has_device; + self->capture_levels(levels); + return self; + } + void inplace_update_layers(TensorRef batchtensor, Slice levels) { + // XXX - requires a patch to functorch to att set_level + auto impl = maybeGetBatchedImpl(*batchtensor); + for (auto i : levels_to_dim_.reversed_enumerate()) { + if (!impl) { + break; + } + if (levels.contains(levels_to_dim_[i])) { + impl->_unsafe_set_level(levels_start_ + i); + impl = maybeGetBatchedImpl(impl->value()); + } + } + } + + private: + int64_t levels_start_{}; + Slice> levels_to_dim_; }; -namespace{ -TensorRef _match_levels(Arena& A, TensorRef v, Slice from_levels, Slice to_levels, bool drop_levels=false) { - if (from_levels == to_levels) { - return v; - } - // drop_levels -> if a dim appears in from_levels but not to_levels, it is assumed it has stride 0. - at::IntArrayRef sz = v->sizes(); - at::IntArrayRef sd = v->strides(); - AT_ASSERT(drop_levels || from_levels.size() <= to_levels.size()); - Slice nsz; - Slice nsd; - for (auto l : to_levels) { - auto oidx = from_levels.index(l); - if (!oidx) { - nsz.append(A, l.is_positional() ? 1 : l.dim()->size()); - nsd.append(A, 0); - } else { - auto idx = *oidx; - nsz.append(A, sz[idx]); - nsd.append(A, sd[idx]); - } - } - return A.autorelease(v->as_strided(at::IntArrayRef(nsz.begin(), nsz.end()), at::IntArrayRef(nsd.begin(), nsd.end()), v->storage_offset())); -} -} -mpy::object run_torch_function(Arena &A, mpy::handle orig, mpy::vector_args args, bool is_pointwise) { - if (!pointwise_optimize) { - is_pointwise = false; - } - // std::cout << "__torch_function__ " << ((is_pointwise) ? "pointwise" : "functorch") << " " << orig << "\n"; - - Slice> all_dims; - Slice flat_args; - auto unflatten_args = tree_flatten(A, args, flat_args); - TensorRef device_holding_tensor; - - Slice infos; - Slice result_levels; - for (auto f : flat_args) { - infos.append(A, TensorInfo::create(A, f, !is_pointwise, false)); - if (infos.back()) { - TensorInfo& info = infos.back(); - AT_ASSERT(is_pointwise || info.batchedtensor); - if (!device_holding_tensor && info.has_device) { - device_holding_tensor = infos.back().tensor; - } - for (auto l : info.levels) { - if (!result_levels.contains(l)) { - result_levels.append(A, l); - } - } - } - } - - if (is_pointwise) { - for (auto i : flat_args.enumerate()) { - if (infos[i]) { - TensorRef tensor = infos[i].tensor; - if (device_holding_tensor && !infos[i].has_device) { - tensor = A.autorelease(tensor->to(device_holding_tensor->device())); - } - auto ml = _match_levels(A, tensor, infos[i].levels, result_levels); - flat_args[i] = handle_from_tensor(A, std::move(ml)); - } - } - - Slice flat_it = flat_args; - mpy::vector_args uargs = unflatten_args(A, flat_it); - - mpy::object result = orig.call_vector(uargs); - - // fast wrap for normal case where operator just returns a tensor. - if (THPVariable_Check(result.ptr())) { - return Tensor::from_positional(A, THPVariable_Unpack(result.ptr()), result_levels, device_holding_tensor); - } - auto wrap = [&](mpy::handle h) { - if (THPVariable_Check(h.ptr())){ - return A.autorelease(Tensor::from_positional(A, THPVariable_Unpack(h.ptr()), result_levels, device_holding_tensor)); - } - return h; - }; - return tree_map(A, wrap, result); +namespace { +TensorRef _match_levels( + Arena& A, + TensorRef v, + Slice from_levels, + Slice to_levels, + bool drop_levels = false) { + if (from_levels == to_levels) { + return v; + } + // drop_levels -> if a dim appears in from_levels but not to_levels, it is + // assumed it has stride 0. + at::IntArrayRef sz = v->sizes(); + at::IntArrayRef sd = v->strides(); + AT_ASSERT(drop_levels || from_levels.size() <= to_levels.size()); + Slice nsz; + Slice nsd; + for (auto l : to_levels) { + auto oidx = from_levels.index(l); + if (!oidx) { + nsz.append(A, l.is_positional() ? 1 : l.dim()->size()); + nsd.append(A, 0); } else { - // std::cout << orig << " calling functorch...\n"; - // std::cout << "rl: " << result_levels << "\n"; - EnableAllLayers guard(A, result_levels); - for (auto i : flat_args.enumerate()) { - if (infos[i]) { - TensorRef batched = infos[i].batchedtensor; - if (device_holding_tensor && !infos[i].has_device) { - batched = A.autorelease(batched->to(device_holding_tensor->device())); - } - guard.inplace_update_layers(batched, infos[i].levels); - flat_args[i] = handle_from_tensor(A, batched); - } - } - Slice flat_it = flat_args; - mpy::vector_args uargs = unflatten_args(A, flat_it); - AT_ASSERT(flat_it.size() == 0); - mpy::object result = orig.call_vector(uargs); - auto wrap = [&](mpy::handle h) { - if (THPVariable_Check(h.ptr())) { - return A.autorelease(guard.from_batched(A, THPVariable_Unpack(h.ptr()), device_holding_tensor)); - } - return h; - }; - if (THPVariable_Check(result.ptr())) { - return guard.from_batched(A, THPVariable_Unpack(result.ptr()), device_holding_tensor); - } - return tree_map(A, wrap, result); - } -} - -namespace{ - -mpy::object __torch_function__(Arena &A, mpy::handle orig, mpy::vector_args args, bool is_pointwise) { - if (orig == torch_Tensor___mul__) { - AT_ASSERT(args.nargs == 2 && !args.has_keywords()); - auto lhs = args[0]; - auto rhs = args[1]; - if (mpy::isinstance(lhs, _Tensor) && mpy::isinstance(rhs, _Tensor) && _Tensor_ndim(lhs) == 0 && _Tensor_ndim(rhs) == 0) { - bool has_device = false; - Slice levels; - for (auto i : args.enumerate_positional()) { - auto t = TensorInfo::create(A, args[i], false); - // something like a mask * rhs, which matrix multiplies don't correctly promote - if (!t.tensor->is_floating_point()) { - return run_torch_function(A, orig, args, is_pointwise); - } - has_device = has_device || t.has_device; - for (auto l : t.levels) { - if (!levels.contains(l)) { - levels.append(A, l); - } - } - } - // std::cout << "__torch_function__ " << "delay" << " " << orig << "\n"; - return Tensor::create_delayed(mpy::object::borrow(orig), args, levels, has_device); - } - } - return run_torch_function(A, orig, args, is_pointwise); -} - -mpy::vector_args as_vector_args(Arena& A, mpy::handle args, mpy::handle kwargs) { - auto pos_args = (mpy::handle*) &PyTuple_GET_ITEM(args.ptr(), 0); - auto pos_n = PyTuple_GET_SIZE(args.ptr()); - if (!kwargs.ptr()) { - return mpy::vector_args(pos_args, pos_n, nullptr); + auto idx = *oidx; + nsz.append(A, sz[idx]); + nsd.append(A, sd[idx]); + } + } + return A.autorelease(v->as_strided( + at::IntArrayRef(nsz.begin(), nsz.end()), + at::IntArrayRef(nsd.begin(), nsd.end()), + v->storage_offset())); +} +} // namespace +mpy::object run_torch_function( + Arena& A, + mpy::handle orig, + mpy::vector_args args, + bool is_pointwise) { + if (!pointwise_optimize) { + is_pointwise = false; + } + // std::cout << "__torch_function__ " << ((is_pointwise) ? "pointwise" : + // "functorch") << " " << orig << "\n"; + + Slice> all_dims; + Slice flat_args; + auto unflatten_args = tree_flatten(A, args, flat_args); + TensorRef device_holding_tensor; + + Slice infos; + Slice result_levels; + for (auto f : flat_args) { + infos.append(A, TensorInfo::create(A, f, !is_pointwise, false)); + if (infos.back()) { + TensorInfo& info = infos.back(); + AT_ASSERT(is_pointwise || info.batchedtensor); + if (!device_holding_tensor && info.has_device) { + device_holding_tensor = infos.back().tensor; + } + for (auto l : info.levels) { + if (!result_levels.contains(l)) { + result_levels.append(A, l); + } + } + } + } + + if (is_pointwise) { + for (auto i : flat_args.enumerate()) { + if (infos[i]) { + TensorRef tensor = infos[i].tensor; + if (device_holding_tensor && !infos[i].has_device) { + tensor = A.autorelease(tensor->to(device_holding_tensor->device())); + } + auto ml = _match_levels(A, tensor, infos[i].levels, result_levels); + flat_args[i] = handle_from_tensor(A, std::move(ml)); + } + } + + Slice flat_it = flat_args; + mpy::vector_args uargs = unflatten_args(A, flat_it); + + mpy::object result = orig.call_vector(uargs); + + // fast wrap for normal case where operator just returns a tensor. + if (THPVariable_Check(result.ptr())) { + return Tensor::from_positional( + A, + THPVariable_Unpack(result.ptr()), + result_levels, + device_holding_tensor); } - Slice all_args; - Slice kwnames; - all_args.extend(A, pos_args, pos_args + pos_n); - mpy::dict_view dv(kwargs); - Py_ssize_t pos = 0; - mpy::handle key, value; - while (dv.next(&pos, &key, &value)) { - all_args.append(A, value); - kwnames.append(A, key); - } - return mpy::vector_args(all_args.begin(), pos_n, A.autorelease(slice_to_tuple(kwnames))); -} - -PyObject* py___torch_function__(PyObject *self, - PyObject *const *args, - Py_ssize_t nargs, - PyObject *kwnames) { - Arena A; - PY_BEGIN - maybeInitializeGlobals(); - AT_ASSERT(nargs == 4 || nargs == 5); - auto va = as_vector_args(A, args[3], nargs == 5 ? args[4] : nullptr); - bool is_pointwise = pointwise.contains(args[1]); - return __torch_function__(A, args[1], std::move(va), is_pointwise).release(); - PY_END(nullptr) + auto wrap = [&](mpy::handle h) { + if (THPVariable_Check(h.ptr())) { + return A.autorelease(Tensor::from_positional( + A, + THPVariable_Unpack(h.ptr()), + result_levels, + device_holding_tensor)); + } + return h; + }; + return tree_map(A, wrap, result); + } else { + // std::cout << orig << " calling functorch...\n"; + // std::cout << "rl: " << result_levels << "\n"; + EnableAllLayers guard(A, result_levels); + for (auto i : flat_args.enumerate()) { + if (infos[i]) { + TensorRef batched = infos[i].batchedtensor; + if (device_holding_tensor && !infos[i].has_device) { + batched = A.autorelease(batched->to(device_holding_tensor->device())); + } + guard.inplace_update_layers(batched, infos[i].levels); + flat_args[i] = handle_from_tensor(A, batched); + } + } + Slice flat_it = flat_args; + mpy::vector_args uargs = unflatten_args(A, flat_it); + AT_ASSERT(flat_it.size() == 0); + mpy::object result = orig.call_vector(uargs); + auto wrap = [&](mpy::handle h) { + if (THPVariable_Check(h.ptr())) { + return A.autorelease(guard.from_batched( + A, THPVariable_Unpack(h.ptr()), device_holding_tensor)); + } + return h; + }; + if (THPVariable_Check(result.ptr())) { + return guard.from_batched( + A, THPVariable_Unpack(result.ptr()), device_holding_tensor); + } + return tree_map(A, wrap, result); + } +} + +namespace { + +mpy::object __torch_function__( + Arena& A, + mpy::handle orig, + mpy::vector_args args, + bool is_pointwise) { + if (orig == torch_Tensor___mul__) { + AT_ASSERT(args.nargs == 2 && !args.has_keywords()); + auto lhs = args[0]; + auto rhs = args[1]; + if (mpy::isinstance(lhs, _Tensor) && mpy::isinstance(rhs, _Tensor) && + _Tensor_ndim(lhs) == 0 && _Tensor_ndim(rhs) == 0) { + bool has_device = false; + Slice levels; + for (auto i : args.enumerate_positional()) { + auto t = TensorInfo::create(A, args[i], false); + // something like a mask * rhs, which matrix multiplies don't correctly + // promote + if (!t.tensor->is_floating_point()) { + return run_torch_function(A, orig, args, is_pointwise); + } + has_device = has_device || t.has_device; + for (auto l : t.levels) { + if (!levels.contains(l)) { + levels.append(A, l); + } + } + } + // std::cout << "__torch_function__ " << "delay" << " " << orig << "\n"; + return Tensor::create_delayed( + mpy::object::borrow(orig), args, levels, has_device); + } + } + return run_torch_function(A, orig, args, is_pointwise); +} + +mpy::vector_args as_vector_args( + Arena& A, + mpy::handle args, + mpy::handle kwargs) { + auto pos_args = (mpy::handle*)&PyTuple_GET_ITEM(args.ptr(), 0); + auto pos_n = PyTuple_GET_SIZE(args.ptr()); + if (!kwargs.ptr()) { + return mpy::vector_args(pos_args, pos_n, nullptr); + } + Slice all_args; + Slice kwnames; + all_args.extend(A, pos_args, pos_args + pos_n); + mpy::dict_view dv(kwargs); + Py_ssize_t pos = 0; + mpy::handle key, value; + while (dv.next(&pos, &key, &value)) { + all_args.append(A, value); + kwnames.append(A, key); + } + return mpy::vector_args( + all_args.begin(), pos_n, A.autorelease(slice_to_tuple(kwnames))); +} + +PyObject* py___torch_function__( + PyObject* self, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + Arena A; + PY_BEGIN + maybeInitializeGlobals(); + AT_ASSERT(nargs == 4 || nargs == 5); + auto va = as_vector_args(A, args[3], nargs == 5 ? args[4] : nullptr); + bool is_pointwise = pointwise.contains(args[1]); + return __torch_function__(A, args[1], std::move(va), is_pointwise).release(); + PY_END(nullptr) } mpy::object levels_to_tuple(Slice slice) { - mpy::tuple t(slice.size()); - for (auto i : slice.enumerate()) { - t.set(i, slice[i].is_positional() ? mpy::from_int(slice[i].position()) : mpy::object::borrow(slice[i].dim())); - } - mpy::object r = std::move(t); - return r; + mpy::tuple t(slice.size()); + for (auto i : slice.enumerate()) { + t.set( + i, + slice[i].is_positional() ? mpy::from_int(slice[i].position()) + : mpy::object::borrow(slice[i].dim())); + } + mpy::object r = std::move(t); + return r; } PyObject* Tensor_ndim(Tensor* self, void*) { - Py_ssize_t i = 0; - for (auto l : self->levels()) { - if (l.is_positional()) { - ++i; - } + Py_ssize_t i = 0; + for (auto l : self->levels()) { + if (l.is_positional()) { + ++i; } - return mpy::from_int(i).release(); + } + return mpy::from_int(i).release(); } PyGetSetDef Tensor_getsetters[] = { - {"_has_device", (getter) [](PyObject* self, void*) -> PyObject* { return mpy::from_bool(((Tensor*)self)->has_device()).release(); }, NULL}, - {"_tensor", (getter) [](PyObject* self, void*) -> PyObject* { - Arena A; - return THPVariable_Wrap(((Tensor*)self)->tensor(A)); }, NULL}, - {"_batchtensor", (getter) [](PyObject* self, void*) -> PyObject* { - Arena A; - return THPVariable_Wrap(((Tensor*)self)->batchtensor(A)); }, NULL}, - {"_levels", (getter) [](PyObject* self, void*) -> PyObject* { - PY_BEGIN - return levels_to_tuple(((Tensor*)self)->levels()).release(); - PY_END(nullptr) - }}, - {"ndim", (getter) Tensor_ndim, NULL, "ndim", NULL}, - {NULL} /* Sentinel */ -}; + {"_has_device", + (getter)[](PyObject* self, void*) + ->PyObject* { + return mpy::from_bool(((Tensor*)self)->has_device()).release(); +} // namespace +, NULL +} +, {"_tensor", (getter)[](PyObject* self, void*)->PyObject* {Arena A; +return THPVariable_Wrap(((Tensor*)self)->tensor(A)); +} +, NULL +} +, {"_batchtensor", (getter)[](PyObject* self, void*)->PyObject* {Arena A; +return THPVariable_Wrap(((Tensor*)self)->batchtensor(A)); +} +, NULL +} +, + {"_levels", + (getter)[](PyObject* self, void*) + ->PyObject* {PY_BEGIN return levels_to_tuple(((Tensor*)self)->levels()) + .release(); +PY_END(nullptr) +} +} +, {"ndim", (getter)Tensor_ndim, NULL, "ndim", NULL}, { + NULL +} /* Sentinel */ +} +; PyMethodDef Tensor_methods[] = { - {NULL, NULL, 0, NULL} /* Sentinel */ + {NULL, NULL, 0, NULL} /* Sentinel */ }; } - PyTypeObject Tensor::Type = { PyVarObject_HEAD_INIT(NULL, 0) - "_C.Tensor", /* tp_name */ - sizeof(Tensor), /* tp_basicsize */ - 0, /* tp_itemsize */ - Tensor::dealloc_stub, /* tp_dealloc */ - 0, /* tp_vectorcall_offset */ - 0, /* tp_getattr */ - 0, /* tp_setattr */ - 0, /* tp_as_async */ - 0, /* tp_repr */ - 0, /* tp_as_number */ - 0, /* tp_as_sequence */ - 0, /* tp_as_mapping */ - 0, /* tp_hash */ - 0, /* tp_call */ - 0, /* tp_str */ - 0, /* tp_getattro */ - 0, /* tp_setattro */ - 0, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE , /* tp_flags */ - "Tensor Object", /* tp_doc */ - 0, /* tp_traverse */ - 0, /* tp_clear */ - 0, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - 0, /* tp_iter */ - 0, /* tp_iternext */ - Tensor_methods, /* tp_methods */ - 0, /* tp_members */ - Tensor_getsetters, /* tp_getset */ - 0, /* tp_base */ - 0, /* tp_dict */ - 0, /* tp_descr_get */ - 0, /* tp_descr_set */ - 0, /* tp_dictoffset */ - 0, /* tp_init */ - 0, /* tp_alloc */ - Tensor::new_stub, /* tp_new */ + "_C.Tensor", /* tp_name */ + sizeof(Tensor), /* tp_basicsize */ + 0, /* tp_itemsize */ + Tensor::dealloc_stub, /* tp_dealloc */ + 0, /* tp_vectorcall_offset */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_as_async */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ + "Tensor Object", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + Tensor_methods, /* tp_methods */ + 0, /* tp_members */ + Tensor_getsetters, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + 0, /* tp_init */ + 0, /* tp_alloc */ + Tensor::new_stub, /* tp_new */ }; - // dim() -------------------- static bool relevant_op(_Py_CODEUNIT c) { - switch(c) { - case STORE_NAME: - case STORE_GLOBAL: - case STORE_FAST: - case STORE_DEREF: - return true; - default: - return false; - } + switch (c) { + case STORE_NAME: + case STORE_GLOBAL: + case STORE_FAST: + case STORE_DEREF: + return true; + default: + return false; + } } static mpy::object create_dim(mpy::object name, mpy::handle size) { - auto d = Dim::create(std::move(name)); - if (!mpy::is_none(size)) { - d->set_size(mpy::to_int(size)); - } - return std::move(d); + auto d = Dim::create(std::move(name)); + if (!mpy::is_none(size)) { + d->set_size(mpy::to_int(size)); + } + return std::move(d); } static mpy::object create_dimlist(mpy::object name, mpy::handle size) { - auto d = DimList::create(std::move(name)); - if (!mpy::is_none(size)) { - if (mpy::is_int(size)) { - d->bind_len(mpy::to_int(size)); - } else { - mpy::sequence_view s(size); - d->bind_len(s.size()); - for (auto i : irange(d->size())) { - d->dims_[i]->set_size(mpy::to_int(s[i])); - } - } + auto d = DimList::create(std::move(name)); + if (!mpy::is_none(size)) { + if (mpy::is_int(size)) { + d->bind_len(mpy::to_int(size)); + } else { + mpy::sequence_view s(size); + d->bind_len(s.size()); + for (auto i : irange(d->size())) { + d->dims_[i]->set_size(mpy::to_int(s[i])); + } } - return std::move(d); + } + return std::move(d); } - - -// Python wrappers that make new reflection primitives available for older runtimes +// Python wrappers that make new reflection primitives available for older +// runtimes #if !(IS_PYTHON_3_11_PLUS) #define _PyCode_CODE(CO) ((_Py_CODEUNIT*)PyBytes_AS_STRING((CO)->co_code)) #endif -namespace{ +namespace { struct PyInstDecoder { - PyInstDecoder(PyCodeObject* code_object, int lasti) - : code_object_(code_object), code_(_PyCode_CODE(code_object)), offset_(lasti / sizeof(_Py_CODEUNIT)) {} - // On Windows, _PyOpcode_Caches and _PyOpcode_Deopt are private symbols - // See https://github.com/pytorch/pytorch/issues/93854 - void next() { - #if IS_PYTHON_3_11_PLUS - offset_ += _PyOpcode_Caches[opcode()]; - #endif - offset_ += 1; - } - int opcode() { - auto r = _Py_OPCODE(code_[offset_]); - #if IS_PYTHON_3_11_PLUS - r = _PyOpcode_Deopt[r]; - #endif - return r; - } - int oparg() { - return _Py_OPARG(code_[offset_]); - } - - mpy::object name() { - mpy::object names; - switch(opcode()) { - case STORE_NAME: - case STORE_GLOBAL: - names = mpy::object::borrow(code_object_->co_names); - break; - case STORE_FAST: - names = mpy::object::steal(PyCode_GetVarnames(code_object_)); - break; - case STORE_DEREF: - names = mpy::object::steal(PyCode_GetCellvars(code_object_)); - break; - default: - return mpy::object(); - } - return mpy::object::steal(PySequence_GetItem(names.ptr(), oparg())); - } -private: - PyCodeObject* code_object_; - _Py_CODEUNIT* code_; - int offset_; + PyInstDecoder(PyCodeObject* code_object, int lasti) + : code_object_(code_object), + code_(_PyCode_CODE(code_object)), + offset_(lasti / sizeof(_Py_CODEUNIT)) {} + // On Windows, _PyOpcode_Caches and _PyOpcode_Deopt are private symbols + // See https://github.com/pytorch/pytorch/issues/93854 + void next() { +#if IS_PYTHON_3_11_PLUS + offset_ += _PyOpcode_Caches[opcode()]; +#endif + offset_ += 1; + } + int opcode() { + auto r = _Py_OPCODE(code_[offset_]); +#if IS_PYTHON_3_11_PLUS + r = _PyOpcode_Deopt[r]; +#endif + return r; + } + int oparg() { + return _Py_OPARG(code_[offset_]); + } + + mpy::object name() { + mpy::object names; + switch (opcode()) { + case STORE_NAME: + case STORE_GLOBAL: + names = mpy::object::borrow(code_object_->co_names); + break; + case STORE_FAST: + names = mpy::object::steal(PyCode_GetVarnames(code_object_)); + break; + case STORE_DEREF: + names = mpy::object::steal(PyCode_GetCellvars(code_object_)); + break; + default: + return mpy::object(); + } + return mpy::object::steal(PySequence_GetItem(names.ptr(), oparg())); + } + + private: + PyCodeObject* code_object_; + _Py_CODEUNIT* code_; + int offset_; }; -template -static PyObject* _dims(PyObject *self, - PyObject *const *args, - Py_ssize_t nargs, - PyObject *kwnames) { - PY_BEGIN - Py_ssize_t specified_ndims = -1; - Py_ssize_t found_ndims = 0; - Py_ssize_t sizes = -1; - mpy::handle n = Py_None; - mpy::handle py_sizes = Py_None; - - if (nargs || kwnames) { - mpy::vector_args va(args, nargs, kwnames); - va.parse("dims", {"n", "sizes"}, {&n, &py_sizes}, 0); - if (!mpy::is_none(py_sizes)) { - sizes = mpy::sequence_view(py_sizes).size(); - specified_ndims = sizes; - } - if (!mpy::is_none(n)) { - specified_ndims = mpy::to_int(n); - } - } - - PyThreadState* state = PyThreadState_GET(); - auto f = mpy::obj::steal(PyThreadState_GetFrame(state)); - auto c = mpy::obj::steal(PyFrame_GetCode(f.ptr())); - auto lasti = PyFrame_GetLasti(f.ptr()); - auto decoder = PyInstDecoder(c.ptr(), lasti); - #if IS_PYTHON_3_11_PLUS - // When py3.11 adapts bytecode lasti points to the precall - // rather than the call instruction after it - if (decoder.opcode() == PRECALL) { - decoder.next(); +template +static PyObject* _dims( + PyObject* self, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + PY_BEGIN + Py_ssize_t specified_ndims = -1; + Py_ssize_t found_ndims = 0; + Py_ssize_t sizes = -1; + mpy::handle n = Py_None; + mpy::handle py_sizes = Py_None; + + if (nargs || kwnames) { + mpy::vector_args va(args, nargs, kwnames); + va.parse("dims", {"n", "sizes"}, {&n, &py_sizes}, 0); + if (!mpy::is_none(py_sizes)) { + sizes = mpy::sequence_view(py_sizes).size(); + specified_ndims = sizes; } - #endif - decoder.next(); - - if (relevant_op(decoder.opcode())) { - found_ndims = 1; - } else if (decoder.opcode() == UNPACK_SEQUENCE) { - found_ndims = decoder.oparg(); - decoder.next(); + if (!mpy::is_none(n)) { + specified_ndims = mpy::to_int(n); } + } - if (specified_ndims == -1) { - if (found_ndims == 0) { - mpy::raise_error(PyExc_SyntaxError, "dims() must be assigned to a sequence of variable names or have argument n specified"); - } - specified_ndims = found_ndims; - } - if (found_ndims != specified_ndims) { - found_ndims = 0; // avoid taking the wrong names for dimensions - } + PyThreadState* state = PyThreadState_GET(); + auto f = mpy::obj::steal(PyThreadState_GetFrame(state)); + auto c = mpy::obj::steal(PyFrame_GetCode(f.ptr())); + auto lasti = PyFrame_GetLasti(f.ptr()); + auto decoder = PyInstDecoder(c.ptr(), lasti); +#if IS_PYTHON_3_11_PLUS + // When py3.11 adapts bytecode lasti points to the precall + // rather than the call instruction after it + if (decoder.opcode() == PRECALL) { + decoder.next(); + } +#endif + decoder.next(); - auto genobject = [&](int i) -> mpy::object { - mpy::object name; - if (i < found_ndims) { - name = decoder.name(); - } - if (!name.ptr()) { - name = mpy::unicode_from_format("d%d", i); - found_ndims = 0; // once we fail at finding a name, we can find any more - } else { - decoder.next(); - } - return create_object(std::move(name), sizes != -1 ? mpy::sequence_view(py_sizes)[i] : mpy::handle(Py_None)); - }; - if (sizes != -1 && sizes != specified_ndims) { - mpy::raise_error(PyExc_ValueError, "expected %d sizes but found %d", int(specified_ndims), int(sizes)); - } - if (specified_ndims == 1) { - return genobject(0).release(); - } - mpy::tuple result(specified_ndims); - for (int i = 0; i < specified_ndims; ++i) { - result.set(i, genobject(i)); - } - return result.release(); - PY_END(nullptr) + if (relevant_op(decoder.opcode())) { + found_ndims = 1; + } else if (decoder.opcode() == UNPACK_SEQUENCE) { + found_ndims = decoder.oparg(); + decoder.next(); + } + + if (specified_ndims == -1) { + if (found_ndims == 0) { + mpy::raise_error( + PyExc_SyntaxError, + "dims() must be assigned to a sequence of variable names or have argument n specified"); + } + specified_ndims = found_ndims; + } + if (found_ndims != specified_ndims) { + found_ndims = 0; // avoid taking the wrong names for dimensions + } + + auto genobject = [&](int i) -> mpy::object { + mpy::object name; + if (i < found_ndims) { + name = decoder.name(); + } + if (!name.ptr()) { + name = mpy::unicode_from_format("d%d", i); + found_ndims = 0; // once we fail at finding a name, we can find any more + } else { + decoder.next(); + } + return create_object( + std::move(name), + sizes != -1 ? mpy::sequence_view(py_sizes)[i] : mpy::handle(Py_None)); + }; + if (sizes != -1 && sizes != specified_ndims) { + mpy::raise_error( + PyExc_ValueError, + "expected %d sizes but found %d", + int(specified_ndims), + int(sizes)); + } + if (specified_ndims == 1) { + return genobject(0).release(); + } + mpy::tuple result(specified_ndims); + for (int i = 0; i < specified_ndims; ++i) { + result.set(i, genobject(i)); + } + return result.release(); + PY_END(nullptr) } struct DotPart { - Slice dims; - size_t total_size = 1; - void append(Arena& A, mpy::hdl d) { - total_size *= d->size(); - dims.append(A, d); - } + Slice dims; + size_t total_size = 1; + void append(Arena& A, mpy::hdl d) { + total_size *= d->size(); + dims.append(A, d); + } }; -template +template static at::ArrayRef as_array_ref(Slice t) { - return at::ArrayRef(t.begin(), t.end()); -} - -static TensorRef dot_prepare(Arena& A, std::initializer_list parts, const TensorInfo& t) { - Slice new_levels; - bool needs_reshape = false; - for (auto p : parts) { - if (p.dims.size() != 1) { - needs_reshape = true; - } - new_levels.extend(A, p.dims); - } - auto r = _match_levels(A, t.tensor, t.levels, new_levels, true); - if (!needs_reshape) { - return r; - } - Slice view; - for (auto p : parts) { - view.append(A, p.total_size); - } - return A.autorelease(r->reshape(at::IntArrayRef(view.begin(), view.end()))); -} - -static mpy::object dot_finish(Arena& A, std::initializer_list parts, at::Tensor r) { - Slice result_levels; - bool needs_reshape = false; - for (auto p : parts) { - if (p.dims.size() != 1) { - needs_reshape = true; - } - result_levels.extend(A, p.dims); - } - if (needs_reshape) { - Slice new_size; - for (auto l : result_levels) { - new_size.append(A, l.dim()->size()); - } - r = r.reshape(at::IntArrayRef(new_size.begin(), new_size.end())); - } - return Tensor::from_positional(A, std::move(r), result_levels, true); -} - - - -static mpy::object dot(Arena& A, TensorInfo lhs, TensorInfo rhs, Slice sum) { - auto lhs_strides = lhs.tensor->strides(); - auto rhs_strides = rhs.tensor->strides(); - - DotPart lro_dims; - DotPart lo_dims; - DotPart ro_dims; - DotPart lr_dims; - - auto insert_dim = [&] (mpy::hdl d, std::optional lhs_idx, std::optional rhs_idx) { - bool reduced = sum.contains(d); - int64_t lhs_stride = lhs_idx ? lhs_strides[*lhs_idx] : 0; - int64_t rhs_stride = rhs_idx ? rhs_strides[*rhs_idx] : 0; - if (reduced) { - // lr - lr_dims.append(A, d); - } else { - if ((lhs_stride == 0) == (rhs_stride == 0)) { - // lro - lro_dims.append(A, d); - } else if (lhs_stride != 0) { - // lo - lo_dims.append(A, d); - } else { - AT_ASSERT(rhs_stride != 0); - ro_dims.append(A, d); - } - } - }; - - - auto rhs_seen = A.allocate(rhs.levels.size()); - std::fill(rhs_seen, rhs_seen + rhs.levels.size(), false); - - for (auto i : lhs.levels.enumerate()) { - auto d = lhs.levels[i]; - auto rhs_idx = rhs.levels.index(d); - if (rhs_idx) { - rhs_seen[*rhs_idx] = true; - } - insert_dim(d.dim(), i, rhs_idx); - } - - for (auto i : rhs.levels.enumerate()) { - if (rhs_seen[i]) { - continue; - } - auto d = rhs.levels[i]; - insert_dim(d.dim(), std::nullopt, i); - } - - if (lr_dims.dims.size() != sum.size()) { - for (auto & d : sum) { - if (!lhs.levels.contains(d) && !rhs.levels.contains(d)) { - mpy::raise_error(DimensionBindError(), "summing over non-existent dimension %S", d.dim().ptr()); - } - } - } - - // std::cout << lhs.levels << " " << rhs.levels << " " << sum << "\n"; - // std::cout << lro_dims.dims << " " << lo_dims.dims << " " << ro_dims.dims << " " << lr_dims.dims << "\n"; - - // no batch, just call mm - if (lro_dims.dims.size() != 0) { - auto lhs_ = dot_prepare(A, {lro_dims, lo_dims, lr_dims}, lhs); - auto rhs_ = dot_prepare(A, {lro_dims, lr_dims, ro_dims}, rhs); - return dot_finish(A, {lro_dims, lo_dims, ro_dims}, at::bmm(*lhs_, *rhs_)); + return at::ArrayRef(t.begin(), t.end()); +} + +static TensorRef dot_prepare( + Arena& A, + std::initializer_list parts, + const TensorInfo& t) { + Slice new_levels; + bool needs_reshape = false; + for (auto p : parts) { + if (p.dims.size() != 1) { + needs_reshape = true; + } + new_levels.extend(A, p.dims); + } + auto r = _match_levels(A, t.tensor, t.levels, new_levels, true); + if (!needs_reshape) { + return r; + } + Slice view; + for (auto p : parts) { + view.append(A, p.total_size); + } + return A.autorelease(r->reshape(at::IntArrayRef(view.begin(), view.end()))); +} + +static mpy::object dot_finish( + Arena& A, + std::initializer_list parts, + at::Tensor r) { + Slice result_levels; + bool needs_reshape = false; + for (auto p : parts) { + if (p.dims.size() != 1) { + needs_reshape = true; + } + result_levels.extend(A, p.dims); + } + if (needs_reshape) { + Slice new_size; + for (auto l : result_levels) { + new_size.append(A, l.dim()->size()); + } + r = r.reshape(at::IntArrayRef(new_size.begin(), new_size.end())); + } + return Tensor::from_positional(A, std::move(r), result_levels, true); +} + +static mpy::object dot( + Arena& A, + TensorInfo lhs, + TensorInfo rhs, + Slice sum) { + auto lhs_strides = lhs.tensor->strides(); + auto rhs_strides = rhs.tensor->strides(); + + DotPart lro_dims; + DotPart lo_dims; + DotPart ro_dims; + DotPart lr_dims; + + auto insert_dim = [&](mpy::hdl d, + std::optional lhs_idx, + std::optional rhs_idx) { + bool reduced = sum.contains(d); + int64_t lhs_stride = lhs_idx ? lhs_strides[*lhs_idx] : 0; + int64_t rhs_stride = rhs_idx ? rhs_strides[*rhs_idx] : 0; + if (reduced) { + // lr + lr_dims.append(A, d); } else { - auto lhs_ = dot_prepare(A, {lo_dims, lr_dims}, lhs); - auto rhs_ = dot_prepare(A, {lr_dims, ro_dims}, rhs); - return dot_finish(A, {lo_dims, ro_dims}, at::mm(*lhs_, *rhs_)); - } - -} - -static PyObject* test_c(PyObject *self, - PyObject *const *args, - Py_ssize_t nargs, - PyObject *kwnames) { - PY_BEGIN - - Arena A; - Slice s(A, 3, 4, 5); - AT_ASSERT(s.size() == 3 && s.capacity() == 8); - AT_ASSERT(s[0] == 3 && s[1] == 4 && s[2] == 5); - s.append(A, 6); - AT_ASSERT(s[3] == 6); - for(int i : irange(10)) { - s.append(A, i); - } - AT_ASSERT(s[0] == 3 && s.back() == 9 && s.size() == 14 && s.capacity() == 16); - - Slice s2(A, -1, -2, -3); - AT_ASSERT(s2[1] == -2 && s[0] == 3); - - auto ss = s.slice(1,2); - AT_ASSERT(ss.size() == 1); - AT_ASSERT(ss[0] == 4); - AT_ASSERT(ss.capacity() == 1); - ss.append(A, -4); - AT_ASSERT(ss.size() == 2 && ss[1] == -4); - ss[0] = 3; - AT_ASSERT(s[1] == 4); - - s.insert(A, s.slice(1, 4), ss); - AT_ASSERT(s[1] == 3 && s[2] == -4 && s[3] == 0); - - auto sz = s.size(); - s.insert(A, s.slice(1, 1), 4); - AT_ASSERT(s[1] == 4 && sz + 1 == s.size()); - - - Slice d(A, 0, 1, 2, 3, 4); - - Slice b(A, 0, 1, 2, 3, 4); - b.insert(A, b.slice(1,1), d); - AT_ASSERT(b.size() == 10); - AT_ASSERT(b[1] == 0); - AT_ASSERT(b[5] == 4); - AT_ASSERT(b.back() == 4); - - Py_RETURN_NONE; - - PY_END(nullptr); -} - - -static PyObject* order(PyObject *_, - PyObject *const *args, - Py_ssize_t nargs, - PyObject *kwnames) { - Arena A; - PY_BEGIN - if (kwnames) { - mpy::raise_error(PyExc_TypeError, "unexpected keyword arguments %S", kwnames); - } - AT_ASSERT(nargs-- > 0); - Slice orig_levels; - Slice levels; - TensorRef data; - mpy::handle self = args++[0]; - bool has_device; - if (Tensor::check_exact(self)) { - auto t = Tensor::unchecked_wrap(self); - orig_levels = t->levels(); - data = t->tensor(A); - has_device = t->has_device(); + if ((lhs_stride == 0) == (rhs_stride == 0)) { + // lro + lro_dims.append(A, d); + } else if (lhs_stride != 0) { + // lo + lo_dims.append(A, d); + } else { + AT_ASSERT(rhs_stride != 0); + ro_dims.append(A, d); + } + } + }; + + auto rhs_seen = A.allocate(rhs.levels.size()); + std::fill(rhs_seen, rhs_seen + rhs.levels.size(), false); + + for (auto i : lhs.levels.enumerate()) { + auto d = lhs.levels[i]; + auto rhs_idx = rhs.levels.index(d); + if (rhs_idx) { + rhs_seen[*rhs_idx] = true; + } + insert_dim(d.dim(), i, rhs_idx); + } + + for (auto i : rhs.levels.enumerate()) { + if (rhs_seen[i]) { + continue; + } + auto d = rhs.levels[i]; + insert_dim(d.dim(), std::nullopt, i); + } + + if (lr_dims.dims.size() != sum.size()) { + for (auto& d : sum) { + if (!lhs.levels.contains(d) && !rhs.levels.contains(d)) { + mpy::raise_error( + DimensionBindError(), + "summing over non-existent dimension %S", + d.dim().ptr()); + } + } + } + + // std::cout << lhs.levels << " " << rhs.levels << " " << sum << "\n"; + // std::cout << lro_dims.dims << " " << lo_dims.dims << " " << ro_dims.dims << + // " " << lr_dims.dims << "\n"; + + // no batch, just call mm + if (lro_dims.dims.size() != 0) { + auto lhs_ = dot_prepare(A, {lro_dims, lo_dims, lr_dims}, lhs); + auto rhs_ = dot_prepare(A, {lro_dims, lr_dims, ro_dims}, rhs); + return dot_finish(A, {lro_dims, lo_dims, ro_dims}, at::bmm(*lhs_, *rhs_)); + } else { + auto lhs_ = dot_prepare(A, {lo_dims, lr_dims}, lhs); + auto rhs_ = dot_prepare(A, {lr_dims, ro_dims}, rhs); + return dot_finish(A, {lo_dims, ro_dims}, at::mm(*lhs_, *rhs_)); + } +} + +static PyObject* test_c( + PyObject* self, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + PY_BEGIN + + Arena A; + Slice s(A, 3, 4, 5); + AT_ASSERT(s.size() == 3 && s.capacity() == 8); + AT_ASSERT(s[0] == 3 && s[1] == 4 && s[2] == 5); + s.append(A, 6); + AT_ASSERT(s[3] == 6); + for (int i : irange(10)) { + s.append(A, i); + } + AT_ASSERT(s[0] == 3 && s.back() == 9 && s.size() == 14 && s.capacity() == 16); + + Slice s2(A, -1, -2, -3); + AT_ASSERT(s2[1] == -2 && s[0] == 3); + + auto ss = s.slice(1, 2); + AT_ASSERT(ss.size() == 1); + AT_ASSERT(ss[0] == 4); + AT_ASSERT(ss.capacity() == 1); + ss.append(A, -4); + AT_ASSERT(ss.size() == 2 && ss[1] == -4); + ss[0] = 3; + AT_ASSERT(s[1] == 4); + + s.insert(A, s.slice(1, 4), ss); + AT_ASSERT(s[1] == 3 && s[2] == -4 && s[3] == 0); + + auto sz = s.size(); + s.insert(A, s.slice(1, 1), 4); + AT_ASSERT(s[1] == 4 && sz + 1 == s.size()); + + Slice d(A, 0, 1, 2, 3, 4); + + Slice b(A, 0, 1, 2, 3, 4); + b.insert(A, b.slice(1, 1), d); + AT_ASSERT(b.size() == 10); + AT_ASSERT(b[1] == 0); + AT_ASSERT(b[5] == 4); + AT_ASSERT(b.back() == 4); + + Py_RETURN_NONE; + + PY_END(nullptr); +} + +static PyObject* order( + PyObject* _, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + Arena A; + PY_BEGIN + if (kwnames) { + mpy::raise_error( + PyExc_TypeError, "unexpected keyword arguments %S", kwnames); + } + AT_ASSERT(nargs-- > 0); + Slice orig_levels; + Slice levels; + TensorRef data; + mpy::handle self = args++[0]; + bool has_device; + if (Tensor::check_exact(self)) { + auto t = Tensor::unchecked_wrap(self); + orig_levels = t->levels(); + data = t->tensor(A); + has_device = t->has_device(); + } else { + auto d = Dim::unchecked_wrap(self); + orig_levels.append(A, d); + data = d->range(); + has_device = false; + } + + Slice flat_positional_dims; + Slice> to_flatten; + levels.extend(A, orig_levels); + + int orig_ndim = ndim_of_levels(levels); + auto append = [&](DimEntry d) { + auto midx = levels.index(d); + if (!midx) { + if (d.is_positional()) { + mpy::raise_error( + PyExc_ValueError, + "tensor has %d positional dimensions, but %d specified, or it was specified twice", + int(orig_ndim), + int(d.position() + orig_ndim)); + } else { + mpy::raise_error( + PyExc_ValueError, + "tensor of dimensions %R does not contain dim %R or it was specified twice", + levels_to_tuple(orig_levels).ptr(), + d.dim().ptr()); + } + } + levels[*midx] = DimEntry(); + flat_positional_dims.append(A, d); + }; + + int n_new_positional = 0; + for (auto i : irange(nargs)) { + mpy::handle arg = args[i]; + DimEntry entry = _wrap_dim(arg, orig_ndim, false); + if (!entry.is_none()) { + append(entry); + ++n_new_positional; + } else if (DimList::check(arg)) { + auto dl = DimList::unchecked_wrap(arg); + for (mpy::obj& d : dl->dims_) { + append(mpy::hdl(d)); + ++n_new_positional; + } } else { - auto d = Dim::unchecked_wrap(self); - orig_levels.append(A, d); - data = d->range(); - has_device = false; - } - - Slice flat_positional_dims; - Slice> to_flatten; - levels.extend(A, orig_levels); - - int orig_ndim = ndim_of_levels(levels); - auto append = [&](DimEntry d) { - auto midx = levels.index(d); - if (!midx) { - if (d.is_positional()) { - mpy::raise_error(PyExc_ValueError, "tensor has %d positional dimensions, but %d specified, or it was specified twice", int(orig_ndim), int(d.position() + orig_ndim)); - } else { - mpy::raise_error(PyExc_ValueError, "tensor of dimensions %R does not contain dim %R or it was specified twice", levels_to_tuple(orig_levels).ptr(), d.dim().ptr()); - } - } - levels[*midx] = DimEntry(); - flat_positional_dims.append(A, d); - }; - - int n_new_positional = 0; - for (auto i :irange(nargs)) { - mpy::handle arg = args[i]; - DimEntry entry = _wrap_dim(arg, orig_ndim, false); - if (!entry.is_none()) { - append(entry); - ++n_new_positional; - } else if (DimList::check(arg)) { - auto dl = DimList::unchecked_wrap(arg); - for (mpy::obj & d : dl->dims_) { - append(mpy::hdl(d)); - ++n_new_positional; - } - } else { - ++n_new_positional; - if (!mpy::is_sequence(arg)) { - mpy::raise_error(PyExc_ValueError, "expected a Dim, List[Dim], or Sequence[Dim]"); - } - mpy::sequence_view sq(arg); - auto N = sq.size(); - to_flatten.append(A, std::make_pair(flat_positional_dims.size(), N)); - for (auto j : irange(N)) { - DimEntry e = _wrap_dim(A.autorelease(sq[j]), orig_ndim, false); - if (e.is_none()) { - mpy::raise_error(PyExc_ValueError, "expected a Dim, or int"); - } - append(e); - } - } - } - - int insert_point = -1; - Slice new_levels; - for (auto l : levels) { - if (l.is_none()) { - continue; - } - if (l.is_positional()) { - if (insert_point == -1) { - insert_point = new_levels.size(); - new_levels.extend(A, flat_positional_dims); - } - } - new_levels.append(A, l); - } - if (insert_point == -1) { + ++n_new_positional; + if (!mpy::is_sequence(arg)) { + mpy::raise_error( + PyExc_ValueError, "expected a Dim, List[Dim], or Sequence[Dim]"); + } + mpy::sequence_view sq(arg); + auto N = sq.size(); + to_flatten.append(A, std::make_pair(flat_positional_dims.size(), N)); + for (auto j : irange(N)) { + DimEntry e = _wrap_dim(A.autorelease(sq[j]), orig_ndim, false); + if (e.is_none()) { + mpy::raise_error(PyExc_ValueError, "expected a Dim, or int"); + } + append(e); + } + } + } + + int insert_point = -1; + Slice new_levels; + for (auto l : levels) { + if (l.is_none()) { + continue; + } + if (l.is_positional()) { + if (insert_point == -1) { insert_point = new_levels.size(); new_levels.extend(A, flat_positional_dims); + } } + new_levels.append(A, l); + } + if (insert_point == -1) { + insert_point = new_levels.size(); + new_levels.extend(A, flat_positional_dims); + } - at::Tensor ndata = *_match_levels(A, data, orig_levels, new_levels); + at::Tensor ndata = *_match_levels(A, data, orig_levels, new_levels); - if (to_flatten.size()) { - Slice view; - auto sz = ndata.sizes(); - // before the new positional dims - for (auto i : irange(0, insert_point)) { - view.append(A, sz[i]); - } - int i = 0; - for (auto to_flat : to_flatten) { - for (;i < to_flat.first; ++i) { - view.append(A, sz[insert_point + i]); - } - int64_t new_size = 1; - int last = i + to_flat.second; - for (; i < last; ++i) { - new_size *= sz[insert_point + i]; - } - view.append(A, new_size); - } - for (; i < flat_positional_dims.size(); ++i) { - view.append(A, sz[insert_point + i]); - } - // after the new positional dims - for (auto i : irange(insert_point + flat_positional_dims.size(), levels.size())) { - view.append(A, sz[i]); - } - // we shorted the number of dimension, so remove them from new levels - // we will renumber them later - auto n_to_remove = flat_positional_dims.size() - n_new_positional; - new_levels.insert(A, new_levels.slice(insert_point, insert_point + n_to_remove), Slice()); - ndata = std::move(ndata).reshape(at::IntArrayRef(view.begin(), view.end())); - } - - // renumber the positional dimension - int seen = 0; - for (auto i : new_levels.reversed_enumerate()) { - if (new_levels[i].is_positional() || (i >= insert_point && i < insert_point + n_new_positional)) { - new_levels[i] = --seen; - } - } - return Tensor::from_positional(A, std::move(ndata), new_levels, has_device).release(); - - PY_END(nullptr) -} - -static PyObject* expand(PyObject *_, - PyObject *const *args, - Py_ssize_t nargs, - PyObject *kwnames) { - Arena A; - PY_BEGIN - AT_ASSERT(nargs-- > 0); - auto info = TensorInfo::create(A, args++[0], false); - for (auto i : irange(nargs)) { - if (!Dim::check(args[i])) { - maybeInitializeGlobals(); - mpy::vector_args vargs(args - 1, nargs + 1, kwnames); - if (THPVariable_Check(args[-1])) { - return torch_Tensor_expand.call_vector(vargs).release(); - } else { - return __torch_function__(A, torch_Tensor_expand, vargs, false).release(); - } - } - } - const at::Tensor& data = *info.tensor; - auto levels = info.levels; - Slice new_levels; - Slice sz; - Slice sd; - for (auto i : irange(nargs)) { - auto d = Dim::unchecked_wrap(args[i]); - if (levels.contains(d) || new_levels.contains(d)) { - mpy::raise_error(DimensionBindError(), "expanding dimension %R already exists in tensor with dims", d.ptr()); - } - new_levels.append(A, d); - sz.append(A, d->size()); - sd.append(A, 0); - } - new_levels.extend(A, levels); - at::IntArrayRef osz = data.sizes(); - at::IntArrayRef osd = data.strides(); - sz.extend(A, osz.begin(), osz.end()); - sd.extend(A, osd.begin(), osd.end()); - at::Tensor ndata = data.as_strided(at::IntArrayRef(sz.begin(), sz.end()), at::IntArrayRef(sd.begin(), sd.end()), data.storage_offset()); - return Tensor::from_positional(A, std::move(ndata), new_levels, info.has_device).release(); - PY_END(nullptr) -} - - -static void _bind_dims_to_size(Arena & A, int64_t sz, int64_t sd, - Slice> dims, Slice& nsz, Slice& nsd) { - int64_t rhs_prod = 1; - for (auto i : dims.enumerate()) { - if (!dims[i]->is_bound()) { - for (auto j : irange(i + 1, dims.size())) { - if (!dims[j]->is_bound()) { - mpy::raise_error(DimensionBindError(), "cannot infer the sizes of two dimensions at once %R and %R", dims[i].ptr(), dims[j].ptr()); - } - rhs_prod *= dims[j]->size(); - } - if (sz % rhs_prod != 0) { - mpy::tuple tup(dims.size()); - for (auto j : dims.enumerate()) { - tup.set(j, dims[j]->is_bound() ? mpy::from_int(dims[j]->size()) : mpy::unicode_from_string("?")); - } - mpy::raise_error(DimensionBindError(), "inferred dimension does not evenly fit into larger dimension: %d vs %R", (int) sz, tup.ptr()); - } - int64_t inferred_size = sz / rhs_prod; - dims[i]->set_size(inferred_size); - rhs_prod = sz; - break; - } - rhs_prod *= dims[i]->size(); - } - if (rhs_prod != sz) { + if (to_flatten.size()) { + Slice view; + auto sz = ndata.sizes(); + // before the new positional dims + for (auto i : irange(0, insert_point)) { + view.append(A, sz[i]); + } + int i = 0; + for (auto to_flat : to_flatten) { + for (; i < to_flat.first; ++i) { + view.append(A, sz[insert_point + i]); + } + int64_t new_size = 1; + int last = i + to_flat.second; + for (; i < last; ++i) { + new_size *= sz[insert_point + i]; + } + view.append(A, new_size); + } + for (; i < flat_positional_dims.size(); ++i) { + view.append(A, sz[insert_point + i]); + } + // after the new positional dims + for (auto i : + irange(insert_point + flat_positional_dims.size(), levels.size())) { + view.append(A, sz[i]); + } + // we shorted the number of dimension, so remove them from new levels + // we will renumber them later + auto n_to_remove = flat_positional_dims.size() - n_new_positional; + new_levels.insert( + A, + new_levels.slice(insert_point, insert_point + n_to_remove), + Slice()); + ndata = std::move(ndata).reshape(at::IntArrayRef(view.begin(), view.end())); + } + + // renumber the positional dimension + int seen = 0; + for (auto i : new_levels.reversed_enumerate()) { + if (new_levels[i].is_positional() || + (i >= insert_point && i < insert_point + n_new_positional)) { + new_levels[i] = --seen; + } + } + return Tensor::from_positional(A, std::move(ndata), new_levels, has_device) + .release(); + + PY_END(nullptr) +} + +static PyObject* expand( + PyObject* _, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + Arena A; + PY_BEGIN + AT_ASSERT(nargs-- > 0); + auto info = TensorInfo::create(A, args++[0], false); + for (auto i : irange(nargs)) { + if (!Dim::check(args[i])) { + maybeInitializeGlobals(); + mpy::vector_args vargs(args - 1, nargs + 1, kwnames); + if (THPVariable_Check(args[-1])) { + return torch_Tensor_expand.call_vector(vargs).release(); + } else { + return __torch_function__(A, torch_Tensor_expand, vargs, false) + .release(); + } + } + } + const at::Tensor& data = *info.tensor; + auto levels = info.levels; + Slice new_levels; + Slice sz; + Slice sd; + for (auto i : irange(nargs)) { + auto d = Dim::unchecked_wrap(args[i]); + if (levels.contains(d) || new_levels.contains(d)) { + mpy::raise_error( + DimensionBindError(), + "expanding dimension %R already exists in tensor with dims", + d.ptr()); + } + new_levels.append(A, d); + sz.append(A, d->size()); + sd.append(A, 0); + } + new_levels.extend(A, levels); + at::IntArrayRef osz = data.sizes(); + at::IntArrayRef osd = data.strides(); + sz.extend(A, osz.begin(), osz.end()); + sd.extend(A, osd.begin(), osd.end()); + at::Tensor ndata = data.as_strided( + at::IntArrayRef(sz.begin(), sz.end()), + at::IntArrayRef(sd.begin(), sd.end()), + data.storage_offset()); + return Tensor::from_positional( + A, std::move(ndata), new_levels, info.has_device) + .release(); + PY_END(nullptr) +} + +static void _bind_dims_to_size( + Arena& A, + int64_t sz, + int64_t sd, + Slice> dims, + Slice& nsz, + Slice& nsd) { + int64_t rhs_prod = 1; + for (auto i : dims.enumerate()) { + if (!dims[i]->is_bound()) { + for (auto j : irange(i + 1, dims.size())) { + if (!dims[j]->is_bound()) { + mpy::raise_error( + DimensionBindError(), + "cannot infer the sizes of two dimensions at once %R and %R", + dims[i].ptr(), + dims[j].ptr()); + } + rhs_prod *= dims[j]->size(); + } + if (sz % rhs_prod != 0) { mpy::tuple tup(dims.size()); for (auto j : dims.enumerate()) { - tup.set(j, mpy::object::borrow(dims[j])); - } - mpy::raise_error(DimensionBindError(), "Dimension sizes to do not match (%d != %d) when matching dimension pack %R", (int) sz, (int) rhs_prod, tup.ptr()); - } - auto new_strides = A.allocate(dims.size()); - auto prev_stride = sd; - for (auto i : dims.reversed_enumerate()) { - new_strides[i] = prev_stride; - prev_stride = dims[i]->size()*prev_stride; - } - for (auto i : dims.enumerate()) { - nsd.append(A, new_strides[i]); - nsz.append(A, dims[i]->size()); - } + tup.set( + j, + dims[j]->is_bound() ? mpy::from_int(dims[j]->size()) + : mpy::unicode_from_string("?")); + } + mpy::raise_error( + DimensionBindError(), + "inferred dimension does not evenly fit into larger dimension: %d vs %R", + (int)sz, + tup.ptr()); + } + int64_t inferred_size = sz / rhs_prod; + dims[i]->set_size(inferred_size); + rhs_prod = sz; + break; + } + rhs_prod *= dims[i]->size(); + } + if (rhs_prod != sz) { + mpy::tuple tup(dims.size()); + for (auto j : dims.enumerate()) { + tup.set(j, mpy::object::borrow(dims[j])); + } + mpy::raise_error( + DimensionBindError(), + "Dimension sizes to do not match (%d != %d) when matching dimension pack %R", + (int)sz, + (int)rhs_prod, + tup.ptr()); + } + auto new_strides = A.allocate(dims.size()); + auto prev_stride = sd; + for (auto i : dims.reversed_enumerate()) { + new_strides[i] = prev_stride; + prev_stride = dims[i]->size() * prev_stride; + } + for (auto i : dims.enumerate()) { + nsd.append(A, new_strides[i]); + nsz.append(A, dims[i]->size()); + } } static bool has_dims(mpy::handle d) { - return Dim::check_exact(d) || Tensor::check_exact(d); + return Dim::check_exact(d) || Tensor::check_exact(d); } struct IndexingInfo { - bool can_call_original; // if true, then it is safe to just call getitem or setitem, these objects do not need special handling - bool advanced_indexing; // requires actual lookup - TensorRef self; - Slice flat_inputs; - Slice result_levels; - bool has_device; + bool can_call_original; // if true, then it is safe to just call getitem or + // setitem, these objects do not need special handling + bool advanced_indexing; // requires actual lookup + TensorRef self; + Slice flat_inputs; + Slice result_levels; + bool has_device; }; -} - -IndexingInfo getsetitem_flat(Arena& A, TensorInfo self_info, Slice input, Slice keys, Slice values, bool has_dimpacks_or_none); -namespace{ +} // namespace + +IndexingInfo getsetitem_flat( + Arena& A, + TensorInfo self_info, + Slice input, + Slice keys, + Slice values, + bool has_dimpacks_or_none); +namespace { Slice as_slice(mpy::tuple_view tv) { - PyObject** begin = &PyTuple_GET_ITEM(tv.ptr(),0); - return Slice((mpy::handle*)begin, (mpy::handle*) (begin + tv.size())); + PyObject** begin = &PyTuple_GET_ITEM(tv.ptr(), 0); + return Slice( + (mpy::handle*)begin, (mpy::handle*)(begin + tv.size())); } Slice as_slice(mpy::list_view tv) { - PyObject** begin = &PyList_GET_ITEM(tv.ptr(),0); - return Slice((mpy::handle*)begin, (mpy::handle*) (begin + tv.size())); -} - - -bool maybe_dimpack(Slice& elements, mpy::handle s, bool check_first=true) { - // can we avoid rechecking? - if (mpy::list_view::check(s)) { - mpy::list_view tv(s); - if (!check_first || (tv.size() && Dim::check_exact(tv[0]))) { - elements = as_slice(tv); - return true; - } - } - // can we avoid rechecking? - if (mpy::tuple_view::check(s)) { - mpy::tuple_view tv(s); - if (!check_first || (tv.size() && Dim::check_exact(tv[0]))) { - elements = as_slice(tv); - return true; - } - } - return false; + PyObject** begin = &PyList_GET_ITEM(tv.ptr(), 0); + return Slice( + (mpy::handle*)begin, (mpy::handle*)(begin + tv.size())); +} + +bool maybe_dimpack( + Slice& elements, + mpy::handle s, + bool check_first = true) { + // can we avoid rechecking? + if (mpy::list_view::check(s)) { + mpy::list_view tv(s); + if (!check_first || (tv.size() && Dim::check_exact(tv[0]))) { + elements = as_slice(tv); + return true; + } + } + // can we avoid rechecking? + if (mpy::tuple_view::check(s)) { + mpy::tuple_view tv(s); + if (!check_first || (tv.size() && Dim::check_exact(tv[0]))) { + elements = as_slice(tv); + return true; + } + } + return false; }; bool is_dimpack(mpy::handle s) { - Slice e; - return maybe_dimpack(e, s); + Slice e; + return maybe_dimpack(e, s); } mpy::object invoke_getitem(Arena& A, const IndexingInfo& iinfo) { - at::Tensor rtensor; - if (iinfo.advanced_indexing) { - auto self_hdl = handle_from_tensor(A, iinfo.self); - auto tup = slice_to_tuple(iinfo.flat_inputs); - // std::cout << "calling original getindex " << self_hdl << " " << tup << "\n"; - auto pytensor = mpy::object::checked_steal(THPVariable_getitem(self_hdl.ptr(), tup.ptr())); - rtensor = THPVariable_Unpack(pytensor.ptr()); + at::Tensor rtensor; + if (iinfo.advanced_indexing) { + auto self_hdl = handle_from_tensor(A, iinfo.self); + auto tup = slice_to_tuple(iinfo.flat_inputs); + // std::cout << "calling original getindex " << self_hdl << " " << tup << + // "\n"; + auto pytensor = mpy::object::checked_steal( + THPVariable_getitem(self_hdl.ptr(), tup.ptr())); + rtensor = THPVariable_Unpack(pytensor.ptr()); + } else { + // std::cout << "skipping original getindex\n"; + rtensor = *iinfo.self; + } + // std::cout << "returning (from_positional)\n"; + return Tensor::from_positional( + A, std::move(rtensor), iinfo.result_levels, iinfo.has_device); +} + +mpy::object index( + Arena& A, + mpy::handle self, + mpy::handle dims, + mpy::handle indices) { + maybeInitializeGlobals(); + Slice dims_list; + Slice indices_list; + // we allow for matching single dims to multiple dims, + // so we first have to normalize everything into the case where there is a + // list on lhs and the rhs + bool lhs_list = mpy::tuple_view::check(dims) || mpy::list_view::check(dims); + bool rhs_list = + mpy::tuple_view::check(indices) || mpy::list_view::check(indices); + if (lhs_list && rhs_list) { + mpy::sequence_view dv(dims); + mpy::sequence_view ind(indices); + Py_ssize_t N = dv.size(); + if (N != ind.size()) { + mpy::raise_error( + PyExc_TypeError, + "dims (%d) and indices (%d) must have the same length", + int(N), + int(ind.size())); + } + for (auto i : irange(N)) { + dims_list.append(A, A.autorelease(dv[i])); + indices_list.append(A, A.autorelease(ind[i])); + } + } else { + dims_list.append(A, dims); + indices_list.append(A, indices); + } + + // dims being indexed can be grouped together into a single index space, and + // we have to flatten them int a single dimension before we can index them... + auto self_info = TensorInfo::create(A, self, false); + auto ndim = self_info.ndim(); + Slice new_levels; + Slice to_flatten; + Slice dims_list_flat; + auto parse_dim_entry = [&](mpy::handle s) -> DimEntry { + auto d = _wrap_dim(s, ndim, false); + if (d.is_none()) { + mpy::raise_error( + PyExc_TypeError, + "expected a dimension specifyer but found %R", + s.ptr()); + } + return d; + }; + auto dim_not_present = [&](DimEntry d) { + if (d.is_positional()) { + mpy::raise_error( + PyExc_TypeError, + "dimension %d not in tensor of %d dimensions", + d.position() + ndim, + ndim); } else { - // std::cout << "skipping original getindex\n"; - rtensor = *iinfo.self; - } - // std::cout << "returning (from_positional)\n"; - return Tensor::from_positional(A, std::move(rtensor), iinfo.result_levels, iinfo.has_device); -} - -mpy::object index(Arena& A, mpy::handle self, mpy::handle dims, mpy::handle indices) { - maybeInitializeGlobals(); - Slice dims_list; - Slice indices_list; - // we allow for matching single dims to multiple dims, - // so we first have to normalize everything into the case where there is a list on lhs and the rhs - bool lhs_list = mpy::tuple_view::check(dims) || mpy::list_view::check(dims); - bool rhs_list = mpy::tuple_view::check(indices) || mpy::list_view::check(indices); - if (lhs_list && rhs_list) { - mpy::sequence_view dv(dims); - mpy::sequence_view ind(indices); - Py_ssize_t N = dv.size(); - if (N != ind.size()) { - mpy::raise_error(PyExc_TypeError, "dims (%d) and indices (%d) must have the same length", int(N), int(ind.size())); - } - for (auto i : irange(N)) { - dims_list.append(A, A.autorelease(dv[i])); - indices_list.append(A, A.autorelease(ind[i])); - } + mpy::raise_error( + PyExc_TypeError, "dimension %R not in tensor", d.dim()->ptr()); + } + }; + + for (auto i : dims_list.enumerate()) { + Slice m; + if (maybe_dimpack(m, dims_list[i], /*check_first=*/false)) { + if (m.size() == 0) { + // plausible semantics work for this to have 0 elements (e.g. the index + // will always be 0) + dims_list_flat.append(A, DimEntry()); // value is just dropped + } + auto first = parse_dim_entry(m[0]); + dims_list_flat.append(A, first); + if (m.size() == 1) { + continue; + } + if (to_flatten.size() == 0) { + new_levels.extend(A, self_info.levels); + } + Slice rest; + for (auto i : irange(1, m.size())) { + auto d = parse_dim_entry(m[i]); + if (!new_levels.remove(A, d)) { + dim_not_present(d); + } + rest.append(A, d); + } + + auto first_idx = new_levels.index(first); + if (!first_idx) { + dim_not_present(first); + } + new_levels.insert( + A, new_levels.slice(*first_idx + 1, *first_idx + 1), rest); + to_flatten.extend(A, rest); } else { - dims_list.append(A, dims); - indices_list.append(A, indices); - } - - // dims being indexed can be grouped together into a single index space, and we have to - // flatten them int a single dimension before we can index them... - auto self_info = TensorInfo::create(A, self, false); - auto ndim = self_info.ndim(); - Slice new_levels; - Slice to_flatten; - Slice dims_list_flat; - auto parse_dim_entry = [&](mpy::handle s) -> DimEntry { - auto d = _wrap_dim(s, ndim, false); - if (d.is_none()) { - mpy::raise_error(PyExc_TypeError, "expected a dimension specifyer but found %R", s.ptr()); - } - return d; - }; - auto dim_not_present = [&](DimEntry d) { - if (d.is_positional()) { - mpy::raise_error(PyExc_TypeError, "dimension %d not in tensor of %d dimensions", d.position() + ndim , ndim); - } else { - mpy::raise_error(PyExc_TypeError, "dimension %R not in tensor", d.dim()->ptr()); - } - }; - - for (auto i : dims_list.enumerate()) { - Slice m; - if (maybe_dimpack(m, dims_list[i], /*check_first=*/false)) { - if (m.size() == 0) { - // plausible semantics work for this to have 0 elements (e.g. the index will always be 0) - dims_list_flat.append(A, DimEntry()); // value is just dropped - } - auto first = parse_dim_entry(m[0]); - dims_list_flat.append(A, first); - if (m.size() == 1) { - continue; - } - if (to_flatten.size() == 0) { - new_levels.extend(A, self_info.levels); - } - Slice rest; - for (auto i : irange(1, m.size())) { - auto d = parse_dim_entry(m[i]); - if (!new_levels.remove(A, d)) { - dim_not_present(d); - } - rest.append(A, d); - } - - auto first_idx = new_levels.index(first); - if (!first_idx) { - dim_not_present(first); - } - new_levels.insert(A, new_levels.slice(*first_idx + 1, *first_idx + 1), rest); - to_flatten.extend(A, rest); - } else { - dims_list_flat.append(A, parse_dim_entry(dims_list[i])); - } - } - if (to_flatten.size() > 0) { - TensorRef rearranged = _match_levels(A, self_info.tensor, self_info.levels, new_levels); - at::IntArrayRef sizes = rearranged->sizes(); - Slice new_sizes; - Slice reshape_levels; - for (auto i : new_levels.enumerate()) { - if (to_flatten.contains(new_levels[i])) { - new_sizes.back() *= sizes[i]; - } else { - new_sizes.append(A, sizes[i]); - reshape_levels.append(A, new_levels[i]); - } - } - self_info.tensor = A.autorelease(rearranged->reshape(at::IntArrayRef(new_sizes.begin(), new_sizes.end()))); - - self_info.levels = reshape_levels; // note: we are using the first level in a flattened group to represent the group for the rest of the op - // we need to be careful not to rely the dimensions size because it doesn't match the size of the whole group - } - bool has_dimpacks = false; - for (auto idx : indices_list) { - if (mpy::tuple_view::check(idx) || mpy::list_view::check(idx)) { - has_dimpacks = true; - break; - } - } - IndexingInfo info = getsetitem_flat(A, self_info, Slice(), dims_list_flat, indices_list, has_dimpacks); - return invoke_getitem(A, info); + dims_list_flat.append(A, parse_dim_entry(dims_list[i])); + } + } + if (to_flatten.size() > 0) { + TensorRef rearranged = + _match_levels(A, self_info.tensor, self_info.levels, new_levels); + at::IntArrayRef sizes = rearranged->sizes(); + Slice new_sizes; + Slice reshape_levels; + for (auto i : new_levels.enumerate()) { + if (to_flatten.contains(new_levels[i])) { + new_sizes.back() *= sizes[i]; + } else { + new_sizes.append(A, sizes[i]); + reshape_levels.append(A, new_levels[i]); + } + } + self_info.tensor = A.autorelease(rearranged->reshape( + at::IntArrayRef(new_sizes.begin(), new_sizes.end()))); + + self_info.levels = + reshape_levels; // note: we are using the first level in a flattened + // group to represent the group for the rest of the op + // we need to be careful not to rely the dimensions size + // because it doesn't match the size of the whole group + } + bool has_dimpacks = false; + for (auto idx : indices_list) { + if (mpy::tuple_view::check(idx) || mpy::list_view::check(idx)) { + has_dimpacks = true; + break; + } + } + IndexingInfo info = getsetitem_flat( + A, + self_info, + Slice(), + dims_list_flat, + indices_list, + has_dimpacks); + return invoke_getitem(A, info); } // true -- the indices were flattened out of a tuple, list or sequence... Slice slice_from_sequence(Arena& A, mpy::handle value) { - if (mpy::tuple_view::check(value)) { - return as_slice(mpy::tuple_view(value)); - } else if (mpy::list_view::check(value)) { - return as_slice(mpy::list_view(value)); - } else { - mpy::sequence_view sv(value); - Slice r; - for (auto i : sv.enumerate()) { - r.append(A, A.autorelease(sv[i])); - } - return r; + if (mpy::tuple_view::check(value)) { + return as_slice(mpy::tuple_view(value)); + } else if (mpy::list_view::check(value)) { + return as_slice(mpy::list_view(value)); + } else { + mpy::sequence_view sv(value); + Slice r; + for (auto i : sv.enumerate()) { + r.append(A, A.autorelease(sv[i])); } + return r; + } } bool extractIndices(Arena& A, mpy::handle index, Slice& indices) { - if (mpy::tuple_view::check(index)) { - indices.extend(A, as_slice(mpy::tuple_view(index))); - return true; - } else if (THPVariable_Check(index.ptr())) { - indices.append(A, index); - return false; - } else if (!mpy::is_sequence(index)) { - indices.append(A, index); - return false; - } - // a copy of treatSequenceAsTuple modified to add Dim and our wrapped tensors.. - mpy::sequence_view sv(index); - if (sv.size() >= 32) { - indices.extend(A, slice_from_sequence(A, index)); - return true; - } - for (auto i : sv.enumerate()) { - mpy::handle item; - try { - item = sv[i]; - } catch (mpy::exception_set & e) { - PyErr_Clear(); - indices.append(A, index); - return false; - } - if (THPVariable_Check(item.ptr()) || mpy::is_sequence(item) || PySlice_Check(item.ptr()) || item.ptr() == Py_Ellipsis || mpy::is_none(item) || has_dims(item)) { - indices.extend(A, slice_from_sequence(A, index)); - return true; - } - } + if (mpy::tuple_view::check(index)) { + indices.extend(A, as_slice(mpy::tuple_view(index))); + return true; + } else if (THPVariable_Check(index.ptr())) { indices.append(A, index); return false; -} - -IndexingInfo getsetitem(Arena & A, mpy::handle self, mpy::handle index, bool tensors_have_dims) { - bool can_call_original_getitem = !tensors_have_dims; - - Slice input; - if (has_dims(index)) { - input.append(A, index); + } else if (!mpy::is_sequence(index)) { + indices.append(A, index); + return false; + } + // a copy of treatSequenceAsTuple modified to add Dim and our wrapped + // tensors.. + mpy::sequence_view sv(index); + if (sv.size() >= 32) { + indices.extend(A, slice_from_sequence(A, index)); + return true; + } + for (auto i : sv.enumerate()) { + mpy::handle item; + try { + item = sv[i]; + } catch (mpy::exception_set& e) { + PyErr_Clear(); + indices.append(A, index); + return false; + } + if (THPVariable_Check(item.ptr()) || mpy::is_sequence(item) || + PySlice_Check(item.ptr()) || item.ptr() == Py_Ellipsis || + mpy::is_none(item) || has_dims(item)) { + indices.extend(A, slice_from_sequence(A, index)); + return true; + } + } + indices.append(A, index); + return false; +} + +IndexingInfo getsetitem( + Arena& A, + mpy::handle self, + mpy::handle index, + bool tensors_have_dims) { + bool can_call_original_getitem = !tensors_have_dims; + + Slice input; + if (has_dims(index)) { + input.append(A, index); + } else { + bool is_sequence = extractIndices(A, index, input); + // nothing about first class dims here, fallback to getitem + if (can_call_original_getitem && !is_sequence) { + return {true}; + } + } + + int64_t dims_indexed = 0; + int64_t expanding_object = -1; + DimList* unbound_dim_list = nullptr; + auto check_expanding = [&](int64_t i) { + if (expanding_object != -1) { + mpy::raise_error( + DimensionBindError(), + "at most one ... or unbound dimension list can exist in indexing list but found 2 at offsets %d and %d", + (int)expanding_object, + (int)i); + } + expanding_object = i; + }; + Slice dimlists; + + // calculate how many dimensioned have been indexed in order to compute the + // size of ... or expand a potentially unbound dimension list. + + bool has_dimpacks_or_none = false; + for (auto i : input.enumerate()) { + mpy::handle s = input[i]; + if (Dim::check_exact(s) || Tensor::check_exact(s)) { + can_call_original_getitem = false; + ++dims_indexed; + } else if (s.ptr() == Py_Ellipsis) { + check_expanding(i); + } else if (DimList::check(s)) { + can_call_original_getitem = false; + auto dl = DimList::unchecked_wrap(s); + if (!dl->is_bound()) { + check_expanding(i); + unbound_dim_list = dl.ptr(); + } else { + dims_indexed += dl->dims_.size(); + } + dimlists.append(A, i); + } else if (mpy::is_none(s)) { + has_dimpacks_or_none = true; + } else if (is_dimpack(s)) { + can_call_original_getitem = false; + has_dimpacks_or_none = true; + ++dims_indexed; } else { - bool is_sequence = extractIndices(A, index, input); - // nothing about first class dims here, fallback to getitem - if (can_call_original_getitem && !is_sequence) { - return { true }; - } - } - - int64_t dims_indexed = 0; - int64_t expanding_object = -1; - DimList* unbound_dim_list = nullptr; - auto check_expanding = [&](int64_t i) { - if (expanding_object != -1) { - mpy::raise_error(DimensionBindError(), "at most one ... or unbound dimension list can exist in indexing list but found 2 at offsets %d and %d", (int) expanding_object, (int) i); - } - expanding_object = i; - }; - Slice dimlists; - - // calculate how many dimensioned have been indexed in order to compute the size of ... - // or expand a potentially unbound dimension list. - - bool has_dimpacks_or_none = false; - for (auto i : input.enumerate()) { - mpy::handle s = input[i]; - if (Dim::check_exact(s) || Tensor::check_exact(s)) { - can_call_original_getitem = false; - ++dims_indexed; - } else if (s.ptr() == Py_Ellipsis) { - check_expanding(i); - } else if (DimList::check(s)) { - can_call_original_getitem = false; - auto dl = DimList::unchecked_wrap(s); - if (!dl->is_bound()) { - check_expanding(i); - unbound_dim_list = dl.ptr(); - } else { - dims_indexed += dl->dims_.size(); - } - dimlists.append(A, i); - } else if (mpy::is_none(s)) { - has_dimpacks_or_none = true; - } else if (is_dimpack(s)) { - can_call_original_getitem = false; - has_dimpacks_or_none = true; - ++dims_indexed; - } else { - ++dims_indexed; - } - } - - // at this point if we haven't seen any Dim objects, we also can fallback to the original getitem. - if (can_call_original_getitem) { - return {true}; + ++dims_indexed; + } + } + + // at this point if we haven't seen any Dim objects, we also can fallback to + // the original getitem. + if (can_call_original_getitem) { + return {true}; + } + + // std::cout << "__getitem__ " << self << " " << index << "\n"; + + TensorInfo self_info = TensorInfo::create(A, self, false, true); + auto ndim = self_info.ndim(); + if (dims_indexed > ndim) { + mpy::raise_error( + PyExc_ValueError, + "at least %d indices were supplied but the tensor only has %d dimensions", + (int)dims_indexed, + (int)ndim); + } + // expand any unbound dimension list, or expand ... into individual : slices. + auto expanding_dims = ndim - dims_indexed; + if (expanding_object != -1) { + if (unbound_dim_list) { + unbound_dim_list->bind_len(expanding_dims); + } else { + // ... + Slice no_slices; + for (auto i : irange(expanding_dims)) { + (void)i; + no_slices.append(A, no_slice); + } + input.insert( + A, input.slice(expanding_object, expanding_object + 1), no_slices); + } + } + + // flatten out any dimensions stored in dimlist elements directly into the + // inputs std::cout << dimlists << " <- dim lists!\n"; + for (int64_t i = dimlists.size() - 1; i >= 0; --i) { + auto idx = dimlists[i]; + // we added more elements to input because of ... + // so we need to also adjust the index to get back to where the + // dimlist existed + if (!unbound_dim_list && expanding_object != -1 && idx > expanding_object) { + idx += expanding_dims; + } + auto dl = DimList::unchecked_wrap(input[idx]); + // XXX would be better if we used an OwnedSlice in DimList + Slice more_dims( + (mpy::handle*)&*dl->dims_.begin(), (mpy::handle*)&*dl->dims_.end()); + input.insert(A, input.slice(idx, idx + 1), more_dims); + } + + return getsetitem_flat( + A, + self_info, + input, + Slice(), + Slice(), + has_dimpacks_or_none); +} +} // namespace +IndexingInfo getsetitem_flat( + Arena& A, + TensorInfo self_info, + Slice input, + Slice keys, + Slice values, + bool has_dimpacks_or_none) { + // At this point: + // ..., DimList have been eliminated + // Dim, Tensor, Tuple[Dim,...], int, slice still remain + + // we have to count how many times we see a dimension. + // A[i,j] is a simple binding operation, but A[i, i+j] or A[i, i] requires + // advanced indexing. + Slice> seen_dims; + Slice seen_dims_nuses; + auto add_dim = [&](mpy::hdl entry) { + auto midx = seen_dims.index(entry); + if (!midx) { + seen_dims.append(A, entry); + seen_dims_nuses.append(A, 1); + } else { + ++seen_dims_nuses[*midx]; } + }; - // std::cout << "__getitem__ " << self << " " << index << "\n"; + Slice input_it = input; - TensorInfo self_info = TensorInfo::create(A, self, false, true); - auto ndim = self_info.ndim(); - if (dims_indexed > ndim) { - mpy::raise_error(PyExc_ValueError, "at least %d indices were supplied but the tensor only has %d dimensions", (int) dims_indexed, (int) ndim); - } - // expand any unbound dimension list, or expand ... into individual : slices. - auto expanding_dims = ndim - dims_indexed; - if (expanding_object != -1) { - if (unbound_dim_list) { - unbound_dim_list->bind_len(expanding_dims); - } else { - // ... - Slice no_slices; - for (auto i : irange(expanding_dims)) { - (void) i; - no_slices.append(A, no_slice); - } - input.insert(A, input.slice(expanding_object, expanding_object + 1), no_slices); - } - } + Slice flat_inputs; + // flat inputs will start with an empty mpy::handle if the + // actual value is in the tensor-like object in the tensor info + Slice tensor_inputs; - // flatten out any dimensions stored in dimlist elements directly into the inputs - // std::cout << dimlists << " <- dim lists!\n"; - for (int64_t i = dimlists.size() - 1; i >=0; --i) { - auto idx = dimlists[i]; - // we added more elements to input because of ... - // so we need to also adjust the index to get back to where the - // dimlist existed - if (!unbound_dim_list && expanding_object != -1 && idx > expanding_object) { - idx += expanding_dims; - } - auto dl = DimList::unchecked_wrap(input[idx]); - // XXX would be better if we used an OwnedSlice in DimList - Slice more_dims((mpy::handle*) &*dl->dims_.begin(), (mpy::handle*) &*dl->dims_.end()); - input.insert(A, input.slice(idx, idx + 1), more_dims); + auto append_flat_handle = [&](mpy::handle h) { + flat_inputs.append(A, h); + tensor_inputs.append(A, TensorInfo()); + }; + TensorRef device_holding_tensor; + auto append_tensor_input = [&](TensorInfo ti) { + flat_inputs.append(A, mpy::handle()); + tensor_inputs.append(A, ti); + if (ti.has_device && !device_holding_tensor) { + device_holding_tensor = ti.tensor; } + }; - return getsetitem_flat(A, self_info, input, Slice(), Slice(), has_dimpacks_or_none); -} -} -IndexingInfo getsetitem_flat(Arena& A, TensorInfo self_info, Slice input, Slice keys, Slice values, bool has_dimpacks_or_none) { - // At this point: - // ..., DimList have been eliminated - // Dim, Tensor, Tuple[Dim,...], int, slice still remain - - - // we have to count how many times we see a dimension. - // A[i,j] is a simple binding operation, but A[i, i+j] or A[i, i] requires advanced indexing. - Slice> seen_dims; - Slice seen_dims_nuses; - auto add_dim = [&](mpy::hdl entry) { - auto midx = seen_dims.index(entry); - if (!midx) { - seen_dims.append(A, entry); - seen_dims_nuses.append(A, 1); - } else { - ++seen_dims_nuses[*midx]; - } - }; - - Slice input_it = input; - - Slice flat_inputs; - // flat inputs will start with an empty mpy::handle if the - // actual value is in the tensor-like object in the tensor info - Slice tensor_inputs; - - auto append_flat_handle = [&](mpy::handle h) { - flat_inputs.append(A, h); - tensor_inputs.append(A, TensorInfo()); - }; - TensorRef device_holding_tensor; - auto append_tensor_input = [&](TensorInfo ti) { - flat_inputs.append(A, mpy::handle()); - tensor_inputs.append(A, ti); - if (ti.has_device && !device_holding_tensor) { - device_holding_tensor = ti.tensor; - } - }; - - Slice nsz; - Slice nsd; - at::IntArrayRef sz = self_info.tensor->sizes(); - at::IntArrayRef sd = self_info.tensor->strides(); - - auto append_size = [&](int i) { - if (has_dimpacks_or_none) { - nsz.append(A, sz[i]); - nsd.append(A, sd[i]); - } - }; - // std::cout << "self levels: " << self_info.levels << "\n"; - - auto parse_nones = [&]() { - while (input_it.size() && mpy::is_none(input_it[0])) { - append_flat_handle(no_slice); - nsz.append(A, 1); - nsd.append(A, 0); - input_it = input_it.slice(1); - } - }; - + Slice nsz; + Slice nsd; + at::IntArrayRef sz = self_info.tensor->sizes(); + at::IntArrayRef sd = self_info.tensor->strides(); - auto append_item = [&](int i, mpy::handle arg) { - if (Dim::check_exact(arg)) { - auto d = Dim::unchecked_wrap(arg); - d->set_size(sz[i]); - add_dim(d); - append_size(i); - append_flat_handle(arg); - return; - } - auto info = TensorInfo::create(A, arg, false, false); - if (info) { - append_size(i); - append_tensor_input(info); - for (auto il : info.levels) { - if (!il.is_positional()) { - add_dim(il.dim()); - } - } - return; - } - - if (has_dimpacks_or_none) { - Slice mp; - if (maybe_dimpack(mp, arg)) { - // dim pack - Slice> dim_pack; - for (auto d : mp) { - dim_pack.append(A, Dim::wrap(d)); - add_dim(dim_pack.back()); - append_flat_handle(dim_pack.back()); - } - _bind_dims_to_size(A, sz[i], sd[i], dim_pack, nsz, nsd); - return; - } - } - - append_size(i); - append_flat_handle(arg); - }; - - // pair up the indexing expressions with dimension of self it indexes - // self may have first-class dims, which do not participate the indexing. - for (auto i : self_info.levels.enumerate()) { - auto l = self_info.levels[i]; - auto idx = keys.index(l); - if (idx) { - append_item(i, values[*idx]); - } else if (l.is_positional()) { - // grab and index from the positional list - parse_nones(); - if (!input_it.size()) { - // we might have fewer indices than tensor dimensions, - // which implicitly indexes the remaining dimensions with : - append_flat_handle(no_slice); - append_size(i); - } else { - mpy::handle arg = input_it[0]; - input_it = input_it.slice(1); - append_item(i, arg); - } - } else { - add_dim(l.dim()); - append_flat_handle(l.dim()); - append_size(i); - } - } - // any training Nones may have no existing dimension associated with them in self. - parse_nones(); - - // we have to restride the tensor to collapse dimension packs and introduce our none dimensions. + auto append_size = [&](int i) { if (has_dimpacks_or_none) { - self_info.tensor = A.autorelease(self_info.tensor->as_strided(at::IntArrayRef(nsz.begin(), nsz.end()),at::IntArrayRef(nsd.begin(), nsd.end()), self_info.tensor->storage_offset())); - } - - - // figure out what the shape of the indexing tensors will be - // and what the shape of the resulting tensor will be - Slice result_levels; - Slice index_levels; - int64_t tensor_insert_point = -1; - bool requires_getindex = false; - auto mark_tensor_index = [&] { - if (tensor_insert_point == -1) { - tensor_insert_point = result_levels.size(); - } else if (tensor_insert_point != result_levels.size()) { - tensor_insert_point = 0; - } - }; - for (auto i : flat_inputs.enumerate()) { - auto inp = flat_inputs[i]; - if(tensor_inputs[i]) { - requires_getindex = true; - mark_tensor_index(); - for (auto l : tensor_inputs[i].levels) { - // std::cout << "Consider to add " << l << "\n"; - if (!index_levels.contains(l)) { - index_levels.append(A, l); - } - } - } else if (Dim::check_exact(inp)) { - auto d = Dim::unchecked_wrap(inp); - // dimensions used once are just binding operations - if (1 == seen_dims_nuses[*seen_dims.index(d)]) { - flat_inputs[i] = no_slice; - result_levels.append(A, d); - } else { - requires_getindex = true; - flat_inputs[i] = mpy::handle(); - tensor_inputs[i] = TensorInfo {d->range(), Slice(A, DimEntry(d)), false, TensorRef()}; - if (!index_levels.contains(d)) { - index_levels.append(A, d); - } - mark_tensor_index(); - } - } else { - if (inp.ptr() != no_slice.ptr()) { - requires_getindex = true; - } - if (!mpy::is_int(inp)) { - // note: actual positional indexes are accurately computed later - result_levels.append(A, -1); - } - } - } - - // indexing dimensions appear in the tensor at the _first use of a tensor_ in the indexing. So insert - // the indexing leveles into the result klevels at this spot - if (tensor_insert_point != -1) { - result_levels.insert(A, result_levels.slice(tensor_insert_point, tensor_insert_point), index_levels); - } - - // std::cout << "flat inputs: " << flat_inputs << "\n"; - // std::cout << "result_levels: " << result_levels << "\n"; - // std::cout << "index_levels: " << index_levels << "\n"; - - // get all the tensors to be the right shape for indexing - if (requires_getindex) { - for (auto i : flat_inputs.enumerate()) { - if (tensor_inputs[i]) { - AT_ASSERT(!flat_inputs[i].ptr()); - // std::cout << "tensor " << i << " " << tensor_inputs[i].levels << "\n"; - TensorRef t = tensor_inputs[i].tensor; - if (!tensor_inputs[i].has_device && device_holding_tensor) { - t = A.autorelease(t->to(device_holding_tensor->device())); - } - flat_inputs[i] = handle_from_tensor(A, _match_levels(A, t, tensor_inputs[i].levels, index_levels)); - } - } - } - - // previously we didn't know how many positional dimensions there would be so we couldn't number them right - // so fill it in now. - auto seen_positionals = 0; - for (auto i : result_levels.reversed_enumerate()) { - if (result_levels[i].is_positional()) { - result_levels[i] = -(++seen_positionals); - } - } - - return IndexingInfo {false, requires_getindex, self_info.tensor, flat_inputs, result_levels, self_info.has_device}; -} -namespace{ -mpy::object __getitem__(Arena & A, mpy::handle self, mpy::handle index) { - maybeInitializeGlobals(); - auto iinfo = getsetitem(A, self, index, has_dims(self)); - if (iinfo.can_call_original) { - return mpy::object::checked_steal(THPVariable_getitem(self.ptr(), index.ptr())); + nsz.append(A, sz[i]); + nsd.append(A, sd[i]); + } + }; + // std::cout << "self levels: " << self_info.levels << "\n"; + + auto parse_nones = [&]() { + while (input_it.size() && mpy::is_none(input_it[0])) { + append_flat_handle(no_slice); + nsz.append(A, 1); + nsd.append(A, 0); + input_it = input_it.slice(1); + } + }; + + auto append_item = [&](int i, mpy::handle arg) { + if (Dim::check_exact(arg)) { + auto d = Dim::unchecked_wrap(arg); + d->set_size(sz[i]); + add_dim(d); + append_size(i); + append_flat_handle(arg); + return; + } + auto info = TensorInfo::create(A, arg, false, false); + if (info) { + append_size(i); + append_tensor_input(info); + for (auto il : info.levels) { + if (!il.is_positional()) { + add_dim(il.dim()); + } + } + return; } - return invoke_getitem(A, iinfo); -} - - - -void __setitem__(Arena & A, mpy::handle self, mpy::handle index, mpy::handle rhs) { - maybeInitializeGlobals(); - auto iinfo = getsetitem(A, self, index, has_dims(self) || has_dims(rhs)); - if (iinfo.can_call_original) { - if (-1 == THPVariable_setitem(self.ptr(), index.ptr(), rhs.ptr())) { - throw mpy::exception_set(); - } + if (has_dimpacks_or_none) { + Slice mp; + if (maybe_dimpack(mp, arg)) { + // dim pack + Slice> dim_pack; + for (auto d : mp) { + dim_pack.append(A, Dim::wrap(d)); + add_dim(dim_pack.back()); + append_flat_handle(dim_pack.back()); + } + _bind_dims_to_size(A, sz[i], sd[i], dim_pack, nsz, nsd); return; - } - - auto rhs_info = TensorInfo::create(A, rhs, false, false); - if (rhs_info) { // otherwise rhs can be a scalar... - for (auto l : rhs_info.levels) { - if (!iinfo.result_levels.contains(l)) { - if (l.is_positional()) { - mpy::raise_error(DimensionBindError(), "rhs contains too many dimensions (%d) compared to indexed value (%d)", ndim_of_levels(iinfo.result_levels), rhs_info.ndim()); - } else { - auto tup = levels_to_tuple(iinfo.result_levels); - mpy::raise_error(DimensionBindError(), "rhs of setitem contains dimension %R which is not in the dimension on the left (%R)", l.dim().ptr(), tup.ptr()); - } - } + } + } + + append_size(i); + append_flat_handle(arg); + }; + + // pair up the indexing expressions with dimension of self it indexes + // self may have first-class dims, which do not participate the indexing. + for (auto i : self_info.levels.enumerate()) { + auto l = self_info.levels[i]; + auto idx = keys.index(l); + if (idx) { + append_item(i, values[*idx]); + } else if (l.is_positional()) { + // grab and index from the positional list + parse_nones(); + if (!input_it.size()) { + // we might have fewer indices than tensor dimensions, + // which implicitly indexes the remaining dimensions with : + append_flat_handle(no_slice); + append_size(i); + } else { + mpy::handle arg = input_it[0]; + input_it = input_it.slice(1); + append_item(i, arg); + } + } else { + add_dim(l.dim()); + append_flat_handle(l.dim()); + append_size(i); + } + } + // any training Nones may have no existing dimension associated with them in + // self. + parse_nones(); + + // we have to restride the tensor to collapse dimension packs and introduce + // our none dimensions. + if (has_dimpacks_or_none) { + self_info.tensor = A.autorelease(self_info.tensor->as_strided( + at::IntArrayRef(nsz.begin(), nsz.end()), + at::IntArrayRef(nsd.begin(), nsd.end()), + self_info.tensor->storage_offset())); + } + + // figure out what the shape of the indexing tensors will be + // and what the shape of the resulting tensor will be + Slice result_levels; + Slice index_levels; + int64_t tensor_insert_point = -1; + bool requires_getindex = false; + auto mark_tensor_index = [&] { + if (tensor_insert_point == -1) { + tensor_insert_point = result_levels.size(); + } else if (tensor_insert_point != result_levels.size()) { + tensor_insert_point = 0; + } + }; + for (auto i : flat_inputs.enumerate()) { + auto inp = flat_inputs[i]; + if (tensor_inputs[i]) { + requires_getindex = true; + mark_tensor_index(); + for (auto l : tensor_inputs[i].levels) { + // std::cout << "Consider to add " << l << "\n"; + if (!index_levels.contains(l)) { + index_levels.append(A, l); + } + } + } else if (Dim::check_exact(inp)) { + auto d = Dim::unchecked_wrap(inp); + // dimensions used once are just binding operations + if (1 == seen_dims_nuses[*seen_dims.index(d)]) { + flat_inputs[i] = no_slice; + result_levels.append(A, d); + } else { + requires_getindex = true; + flat_inputs[i] = mpy::handle(); + tensor_inputs[i] = TensorInfo{ + d->range(), Slice(A, DimEntry(d)), false, TensorRef()}; + if (!index_levels.contains(d)) { + index_levels.append(A, d); + } + mark_tensor_index(); + } + } else { + if (inp.ptr() != no_slice.ptr()) { + requires_getindex = true; + } + if (!mpy::is_int(inp)) { + // note: actual positional indexes are accurately computed later + result_levels.append(A, -1); + } + } + } + + // indexing dimensions appear in the tensor at the _first use of a tensor_ in + // the indexing. So insert the indexing leveles into the result klevels at + // this spot + if (tensor_insert_point != -1) { + result_levels.insert( + A, + result_levels.slice(tensor_insert_point, tensor_insert_point), + index_levels); + } + + // std::cout << "flat inputs: " << flat_inputs << "\n"; + // std::cout << "result_levels: " << result_levels << "\n"; + // std::cout << "index_levels: " << index_levels << "\n"; + + // get all the tensors to be the right shape for indexing + if (requires_getindex) { + for (auto i : flat_inputs.enumerate()) { + if (tensor_inputs[i]) { + AT_ASSERT(!flat_inputs[i].ptr()); + // std::cout << "tensor " << i << " " << tensor_inputs[i].levels << + // "\n"; + TensorRef t = tensor_inputs[i].tensor; + if (!tensor_inputs[i].has_device && device_holding_tensor) { + t = A.autorelease(t->to(device_holding_tensor->device())); + } + flat_inputs[i] = handle_from_tensor( + A, _match_levels(A, t, tensor_inputs[i].levels, index_levels)); + } + } + } + + // previously we didn't know how many positional dimensions there would be so + // we couldn't number them right so fill it in now. + auto seen_positionals = 0; + for (auto i : result_levels.reversed_enumerate()) { + if (result_levels[i].is_positional()) { + result_levels[i] = -(++seen_positionals); + } + } + + return IndexingInfo{ + false, + requires_getindex, + self_info.tensor, + flat_inputs, + result_levels, + self_info.has_device}; +} +namespace { +mpy::object __getitem__(Arena& A, mpy::handle self, mpy::handle index) { + maybeInitializeGlobals(); + auto iinfo = getsetitem(A, self, index, has_dims(self)); + if (iinfo.can_call_original) { + return mpy::object::checked_steal( + THPVariable_getitem(self.ptr(), index.ptr())); + } + + return invoke_getitem(A, iinfo); +} + +void __setitem__( + Arena& A, + mpy::handle self, + mpy::handle index, + mpy::handle rhs) { + maybeInitializeGlobals(); + auto iinfo = getsetitem(A, self, index, has_dims(self) || has_dims(rhs)); + if (iinfo.can_call_original) { + if (-1 == THPVariable_setitem(self.ptr(), index.ptr(), rhs.ptr())) { + throw mpy::exception_set(); + } + return; + } + + auto rhs_info = TensorInfo::create(A, rhs, false, false); + if (rhs_info) { // otherwise rhs can be a scalar... + for (auto l : rhs_info.levels) { + if (!iinfo.result_levels.contains(l)) { + if (l.is_positional()) { + mpy::raise_error( + DimensionBindError(), + "rhs contains too many dimensions (%d) compared to indexed value (%d)", + ndim_of_levels(iinfo.result_levels), + rhs_info.ndim()); + } else { + auto tup = levels_to_tuple(iinfo.result_levels); + mpy::raise_error( + DimensionBindError(), + "rhs of setitem contains dimension %R which is not in the dimension on the left (%R)", + l.dim().ptr(), + tup.ptr()); } - auto rhs_matched = _match_levels(A, rhs_info.tensor, rhs_info.levels, iinfo.result_levels); - rhs = handle_from_tensor(A, rhs_matched); + } } - self = handle_from_tensor(A, iinfo.self); + auto rhs_matched = + _match_levels(A, rhs_info.tensor, rhs_info.levels, iinfo.result_levels); + rhs = handle_from_tensor(A, rhs_matched); + } + self = handle_from_tensor(A, iinfo.self); - if (iinfo.advanced_indexing) { - auto tup = slice_to_tuple(iinfo.flat_inputs); - if (-1 == THPVariable_setitem(self.ptr(), tup.ptr(), rhs.ptr())) { - throw mpy::exception_set(); - } - } else { - torch_Tensor_copy_.call(self, rhs); + if (iinfo.advanced_indexing) { + auto tup = slice_to_tuple(iinfo.flat_inputs); + if (-1 == THPVariable_setitem(self.ptr(), tup.ptr(), rhs.ptr())) { + throw mpy::exception_set(); } + } else { + torch_Tensor_copy_.call(self, rhs); + } } -} +} // namespace PyObject* Tensor_getitem(PyObject* self, PyObject* index) { - Arena A; - PY_BEGIN - return __getitem__(A, self, index).release(); - PY_END(nullptr); + Arena A; + PY_BEGIN + return __getitem__(A, self, index).release(); + PY_END(nullptr); } int Tensor_setitem(PyObject* self, PyObject* index, PyObject* value) { - Arena A; - PY_BEGIN - __setitem__(A, self, index, value); - return 0; - PY_END(-1); -} - -namespace{ -PyObject* py___getitem__(PyObject *_, - PyObject *const *args, - Py_ssize_t nargs, - PyObject *kwnames) { - Arena A; - PY_BEGIN - AT_ASSERT(nargs == 2); - return __getitem__(A, args[0], args[1]).release(); - PY_END(nullptr) -} - -PyObject* py___setitem__(PyObject *_, - PyObject *const *args, - Py_ssize_t nargs, - PyObject *kwnames) { - Arena A; - PY_BEGIN - AT_ASSERT(nargs == 3); - __setitem__(A, args[0], args[1], args[2]); - Py_RETURN_NONE; - PY_END(nullptr) -} - - -PyObject* py_index(PyObject *_, - PyObject *const *args, - Py_ssize_t nargs, - PyObject *kwnames) { - Arena A; - PY_BEGIN - mpy::vector_args va(args, nargs, kwnames); - mpy::handle self, dims, indices; - va.parse("index", {"self", "dims", "indices"}, {&self, &dims, &indices}, 3); - return index(A, self, dims, indices).release(); - PY_END(nullptr) -} - - -PyObject* py_stack(PyObject *_, - PyObject *const *args, - Py_ssize_t nargs, - PyObject *kwnames) { - Arena A; - PY_BEGIN - mpy::vector_args va(args, nargs, kwnames); - mpy::handle tensors, new_dim, dim; - va.parse("stack", {"tensors", "new_dim", "dim"}, {&tensors, &new_dim, &dim}, 2); - - Slice result_levels; - Slice infos; - mpy::sequence_view sv(tensors); - auto new_dim_d = Dim::wrap(new_dim); - for (auto i : sv.enumerate()) { - infos.append(A, TensorInfo::create(A, A.autorelease(sv[i]), false)); - for (auto l : infos.back().levels) { - if (!result_levels.contains(l)) { - result_levels.append(A, l); - } - } - } - new_dim_d->set_size(infos.size()); - std::vector inputs; - inputs.reserve(infos.size()); - for (auto in : infos) { - inputs.emplace_back(*_match_levels(A, in.tensor, in.levels, result_levels)); - } - auto ndim = ndim_of_levels(result_levels); - int64_t rawdim = 0; - if (dim.ptr()) { - auto d = _wrap_dim(dim, ndim, false); - auto idx = result_levels.index(d); - if (!idx) { - mpy::raise_error(PyExc_TypeError, "Dimension %R does not exist in inputs", dim.ptr()); - } - rawdim = *idx; - } - auto result = at::stack(inputs, rawdim); - result_levels.insert(A, rawdim, new_dim_d); - return Tensor::from_positional(A, std::move(result), result_levels, true).release(); - PY_END(nullptr) -} - -PyObject* py_split(PyObject *_, - PyObject *const *args, - Py_ssize_t nargs, - PyObject *kwnames) { - Arena A; - PY_BEGIN - maybeInitializeGlobals(); - mpy::vector_args va(args, nargs, kwnames); - mpy::handle self, split_size_or_sections, dim; - va.parse("split", {"self", "split_size_or_sections", "dim"}, {&self, &split_size_or_sections, &dim}, 2); - bool dim_is_object = dim.ptr() && Dim::check_exact(dim); - Slice sizes; - - bool all_dims = true; - bool all_ints = true; - - if (!mpy::is_int(split_size_or_sections)) { - mpy::sequence_view sv(split_size_or_sections); - for (auto i : sv.enumerate()) { - sizes.append(A, A.autorelease(sv[i])); - if (Dim::check_exact(sizes.back())) { - all_ints = false; - } else { - all_dims = false; - } - } - } - if (all_ints) { - if (dim_is_object) { - mpy::raise_error(PyExc_TypeError, "when dim is specified as a Dim object, split sizes must also be dimensions."); - } - // call original split (if self has dimensions this will use torch function to do the split) - return torch_Tensor_split.call_vector(mpy::vector_args(args, nargs, kwnames)).release(); - } - if (!all_dims) { - mpy::raise_error(PyExc_TypeError, "split list must be ints or dims but got a mix"); - } - - auto self_info = TensorInfo::create(A, self, false); - auto ndim = self_info.ndim(); - if (!dim_is_object&& ndim == 0) { - mpy::raise_error(PyExc_TypeError, "split expects at least a 1-dimension tensor"); - } - DimEntry dim_l = dim.ptr() ? _wrap_dim(dim, ndim, false) : -ndim; - - auto idx = self_info.levels.index(dim_l); + Arena A; + PY_BEGIN + __setitem__(A, self, index, value); + return 0; + PY_END(-1); +} + +namespace { +PyObject* py___getitem__( + PyObject* _, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + Arena A; + PY_BEGIN + AT_ASSERT(nargs == 2); + return __getitem__(A, args[0], args[1]).release(); + PY_END(nullptr) +} + +PyObject* py___setitem__( + PyObject* _, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + Arena A; + PY_BEGIN + AT_ASSERT(nargs == 3); + __setitem__(A, args[0], args[1], args[2]); + Py_RETURN_NONE; + PY_END(nullptr) +} + +PyObject* py_index( + PyObject* _, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + Arena A; + PY_BEGIN + mpy::vector_args va(args, nargs, kwnames); + mpy::handle self, dims, indices; + va.parse("index", {"self", "dims", "indices"}, {&self, &dims, &indices}, 3); + return index(A, self, dims, indices).release(); + PY_END(nullptr) +} + +PyObject* py_stack( + PyObject* _, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + Arena A; + PY_BEGIN + mpy::vector_args va(args, nargs, kwnames); + mpy::handle tensors, new_dim, dim; + va.parse( + "stack", {"tensors", "new_dim", "dim"}, {&tensors, &new_dim, &dim}, 2); + + Slice result_levels; + Slice infos; + mpy::sequence_view sv(tensors); + auto new_dim_d = Dim::wrap(new_dim); + for (auto i : sv.enumerate()) { + infos.append(A, TensorInfo::create(A, A.autorelease(sv[i]), false)); + for (auto l : infos.back().levels) { + if (!result_levels.contains(l)) { + result_levels.append(A, l); + } + } + } + new_dim_d->set_size(infos.size()); + std::vector inputs; + inputs.reserve(infos.size()); + for (auto in : infos) { + inputs.emplace_back(*_match_levels(A, in.tensor, in.levels, result_levels)); + } + auto ndim = ndim_of_levels(result_levels); + int64_t rawdim = 0; + if (dim.ptr()) { + auto d = _wrap_dim(dim, ndim, false); + auto idx = result_levels.index(d); if (!idx) { - if (!dim.ptr()) { - dim = A.autorelease(mpy::from_int(0)); - } - mpy::raise_error(PyExc_TypeError, "tensor does not contain dimension %R", dim.ptr()); - } - Slice indices; - - int64_t total_size = 0; - Slice unbound; - for (auto i : sizes.enumerate()) { - auto d = Dim::unchecked_wrap(sizes[i]); - if (d->is_bound()) { - indices.append(A, d->size()); - total_size += indices.back(); - } else { - indices.append(A, 0); - unbound.append(A, i); - } - } - auto tensor_size = self_info.tensor->sizes()[*idx]; - - if (unbound.size()) { - if (total_size > tensor_size) { - mpy::raise_error(PyExc_TypeError, "sizes of target dimensions add up to more (%d) than source dim (%d)", int(total_size), int(tensor_size)); - } - auto remaining_size = tensor_size - total_size; - auto chunk_size = (remaining_size + unbound.size() - 1) / unbound.size(); - for (auto u : unbound) { - auto sz = std::min(chunk_size, remaining_size); - Dim::unchecked_wrap(sizes[u])->set_size(sz); - indices[u] = sz; - remaining_size -= sz; - } - } else if (tensor_size != total_size) { - mpy::raise_error(PyExc_TypeError, "sum of sizes of target dimensions (%d) do not match the than source dim (%d)", int(total_size), int(tensor_size)); - } - - auto result_tensors = self_info.tensor->split_with_sizes(at::IntArrayRef(indices.begin(), indices.end()), *idx); - mpy::tuple result(result_tensors.size()); - Slice new_levels; - new_levels.extend(A, self_info.levels); - for (auto i : sizes.enumerate()) { - new_levels[*idx] = Dim::unchecked_wrap(sizes[i]); - result.set(i, Tensor::from_positional(A, std::move(result_tensors[i]), new_levels, true)); - } - - return result.release(); - - PY_END(nullptr) + mpy::raise_error( + PyExc_TypeError, "Dimension %R does not exist in inputs", dim.ptr()); + } + rawdim = *idx; + } + auto result = at::stack(inputs, rawdim); + result_levels.insert(A, rawdim, new_dim_d); + return Tensor::from_positional(A, std::move(result), result_levels, true) + .release(); + PY_END(nullptr) +} + +PyObject* py_split( + PyObject* _, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + Arena A; + PY_BEGIN + maybeInitializeGlobals(); + mpy::vector_args va(args, nargs, kwnames); + mpy::handle self, split_size_or_sections, dim; + va.parse( + "split", + {"self", "split_size_or_sections", "dim"}, + {&self, &split_size_or_sections, &dim}, + 2); + bool dim_is_object = dim.ptr() && Dim::check_exact(dim); + Slice sizes; + + bool all_dims = true; + bool all_ints = true; + + if (!mpy::is_int(split_size_or_sections)) { + mpy::sequence_view sv(split_size_or_sections); + for (auto i : sv.enumerate()) { + sizes.append(A, A.autorelease(sv[i])); + if (Dim::check_exact(sizes.back())) { + all_ints = false; + } else { + all_dims = false; + } + } + } + if (all_ints) { + if (dim_is_object) { + mpy::raise_error( + PyExc_TypeError, + "when dim is specified as a Dim object, split sizes must also be dimensions."); + } + // call original split (if self has dimensions this will use torch function + // to do the split) + return torch_Tensor_split + .call_vector(mpy::vector_args(args, nargs, kwnames)) + .release(); + } + if (!all_dims) { + mpy::raise_error( + PyExc_TypeError, "split list must be ints or dims but got a mix"); + } + + auto self_info = TensorInfo::create(A, self, false); + auto ndim = self_info.ndim(); + if (!dim_is_object && ndim == 0) { + mpy::raise_error( + PyExc_TypeError, "split expects at least a 1-dimension tensor"); + } + DimEntry dim_l = dim.ptr() ? _wrap_dim(dim, ndim, false) : -ndim; + + auto idx = self_info.levels.index(dim_l); + if (!idx) { + if (!dim.ptr()) { + dim = A.autorelease(mpy::from_int(0)); + } + mpy::raise_error( + PyExc_TypeError, "tensor does not contain dimension %R", dim.ptr()); + } + Slice indices; + + int64_t total_size = 0; + Slice unbound; + for (auto i : sizes.enumerate()) { + auto d = Dim::unchecked_wrap(sizes[i]); + if (d->is_bound()) { + indices.append(A, d->size()); + total_size += indices.back(); + } else { + indices.append(A, 0); + unbound.append(A, i); + } + } + auto tensor_size = self_info.tensor->sizes()[*idx]; + + if (unbound.size()) { + if (total_size > tensor_size) { + mpy::raise_error( + PyExc_TypeError, + "sizes of target dimensions add up to more (%d) than source dim (%d)", + int(total_size), + int(tensor_size)); + } + auto remaining_size = tensor_size - total_size; + auto chunk_size = (remaining_size + unbound.size() - 1) / unbound.size(); + for (auto u : unbound) { + auto sz = std::min(chunk_size, remaining_size); + Dim::unchecked_wrap(sizes[u])->set_size(sz); + indices[u] = sz; + remaining_size -= sz; + } + } else if (tensor_size != total_size) { + mpy::raise_error( + PyExc_TypeError, + "sum of sizes of target dimensions (%d) do not match the than source dim (%d)", + int(total_size), + int(tensor_size)); + } + + auto result_tensors = self_info.tensor->split_with_sizes( + at::IntArrayRef(indices.begin(), indices.end()), *idx); + mpy::tuple result(result_tensors.size()); + Slice new_levels; + new_levels.extend(A, self_info.levels); + for (auto i : sizes.enumerate()) { + new_levels[*idx] = Dim::unchecked_wrap(sizes[i]); + result.set( + i, + Tensor::from_positional( + A, std::move(result_tensors[i]), new_levels, true)); + } + + return result.release(); + + PY_END(nullptr) } Slice _wrap_dims(Arena& A, mpy::handle d, size_t N, bool keepdim) { - auto de = _wrap_dim(d, N, keepdim); - Slice r; - if (!de.is_none()) { - r.append(A, de); - } else { - mpy::sequence_view sq(d); - for (auto i : sq.enumerate()) { - r.append(A, _wrap_dim(A.autorelease(sq[i]), N, keepdim)); - } + auto de = _wrap_dim(d, N, keepdim); + Slice r; + if (!de.is_none()) { + r.append(A, de); + } else { + mpy::sequence_view sq(d); + for (auto i : sq.enumerate()) { + r.append(A, _wrap_dim(A.autorelease(sq[i]), N, keepdim)); } - return r; + } + return r; } struct WrappedOperator : public mpy::base { - mpy::object orig; - PyMethodDef method_def; - mpy::object name, doc; - - bool is_pointwise = false; - int64_t dim_offset = 0; - int64_t keepdim_offset = 1; - std::string dim_name; - bool single_dim = false; - bool reduce = true; - - static PyTypeObject Type; - - void init(mpy::object orig_, PyCFunction wrapper_implementation, std::string dim_name_="") { - orig = std::move(orig_); - method_def.ml_meth = wrapper_implementation; - name = orig.attr("__name__"); - doc = orig.attr("__doc__"); - dim_name = std::move(dim_name_); - if (!mpy::is_none(doc) && !dim_name.empty()) { - doc = mpy::unicode_from_format("%S\nArgument '%s' can be either an integer or a torchdim.Dim object.\n", doc.ptr(), dim_name.c_str()); - } - method_def.ml_name = mpy::is_none(name) ? "" : PyUnicode_AsUTF8(name.ptr()); - method_def.ml_doc = mpy::is_none(doc) ? "" : PyUnicode_AsUTF8(doc.ptr()); - method_def.ml_flags = METH_FASTCALL | METH_KEYWORDS; - } - - mpy::object function() { - return mpy::object::checked_steal(PyCFunction_New(&method_def, ptr())); - } - + mpy::object orig; + PyMethodDef method_def; + mpy::object name, doc; + + bool is_pointwise = false; + int64_t dim_offset = 0; + int64_t keepdim_offset = 1; + std::string dim_name; + bool single_dim = false; + bool reduce = true; + + static PyTypeObject Type; + + void init( + mpy::object orig_, + PyCFunction wrapper_implementation, + std::string dim_name_ = "") { + orig = std::move(orig_); + method_def.ml_meth = wrapper_implementation; + name = orig.attr("__name__"); + doc = orig.attr("__doc__"); + dim_name = std::move(dim_name_); + if (!mpy::is_none(doc) && !dim_name.empty()) { + doc = mpy::unicode_from_format( + "%S\nArgument '%s' can be either an integer or a torchdim.Dim object.\n", + doc.ptr(), + dim_name.c_str()); + } + method_def.ml_name = mpy::is_none(name) ? "" : PyUnicode_AsUTF8(name.ptr()); + method_def.ml_doc = mpy::is_none(doc) ? "" : PyUnicode_AsUTF8(doc.ptr()); + method_def.ml_flags = METH_FASTCALL | METH_KEYWORDS; + } + + mpy::object function() { + return mpy::object::checked_steal(PyCFunction_New(&method_def, ptr())); + } }; -} +} // namespace PyTypeObject WrappedOperator::Type = { PyVarObject_HEAD_INIT(NULL, 0) - "_C.WrappedOperator", /* tp_name */ - sizeof(WrappedOperator), /* tp_basicsize */ - 0, /* tp_itemsize */ - WrappedOperator::dealloc_stub, /* tp_dealloc */ - 0, /* tp_vectorcall_offset */ - 0, /* tp_getattr */ - 0, /* tp_setattr */ - 0, /* tp_as_async */ - 0, /* tp_repr */ - 0, /* tp_as_number */ - 0, /* tp_as_sequence */ - 0, /* tp_as_mapping */ - 0, /* tp_hash */ - 0, /* tp_call */ - 0, /* tp_str */ - 0, /* tp_getattro */ - 0, /* tp_setattro */ - 0, /* tp_as_buffer */ + "_C.WrappedOperator", /* tp_name */ + sizeof(WrappedOperator), /* tp_basicsize */ + 0, /* tp_itemsize */ + WrappedOperator::dealloc_stub, /* tp_dealloc */ + 0, /* tp_vectorcall_offset */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_as_async */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ Py_TPFLAGS_DEFAULT, /* tp_flags */ - "Wrapped Object Holder", /* tp_doc */ - 0, /* tp_traverse */ - 0, /* tp_clear */ - 0, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - 0, /* tp_iter */ - 0, /* tp_iternext */ - 0, /* tp_methods */ - 0, /* tp_members */ - 0, /* tp_getset */ - 0, /* tp_base */ - 0, /* tp_dict */ - 0, /* tp_descr_get */ - 0, /* tp_descr_set */ - 0, /* tp_dictoffset */ - 0, /* tp_init */ - 0, /* tp_alloc */ - WrappedOperator::new_stub, /* tp_new */ + "Wrapped Object Holder", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + 0, /* tp_methods */ + 0, /* tp_members */ + 0, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + 0, /* tp_init */ + 0, /* tp_alloc */ + WrappedOperator::new_stub, /* tp_new */ }; -namespace{ -PyObject* patched_dim_method(PyObject * self_, - PyObject *const *args, - Py_ssize_t nargs, - PyObject *kwnames) { - Arena A; - auto self = WrappedOperator::unchecked_wrap(self_); - PY_BEGIN - - mpy::vector_args va(args, nargs, kwnames); - - auto _getarg = [&](const char* name, int64_t offset_) -> mpy::handle { - auto offset = offset_ + 1; // do not include self - auto idx = va.index(name, offset); - return idx == -1 ? mpy::handle() : va[idx]; - }; - Slice patched_args; - patched_args.extend(A, va.begin(), va.end()); - auto _patcharg = [&](const char* name, int64_t offset_, mpy::handle value) { - auto offset = offset_ + 1; // do not include self - auto idx = va.index(name, offset); - if (idx == -1) { - mpy::raise_error(PyExc_ValueError, "Missing argument %s", name); - } - patched_args[idx] = value; - }; - - auto dim = _getarg(self->dim_name.c_str(), self->dim_offset); - if (!dim.ptr()) { - auto info = TensorInfo::create(A, args[0], true); - EnableAllLayers l(A, info.levels); - l.inplace_update_layers(info.batchedtensor, info.levels); - patched_args[0] = handle_from_tensor(A, info.batchedtensor); - auto r = self->orig.call_vector(patched_args.begin(), nargs, kwnames); - return l.from_batched(A, THPVariable_Unpack(r.ptr()), info.has_device).release(); - } - - auto info = TensorInfo::create(A, args[0]); - auto keepdim = false; - if (self->reduce) { - auto py_keepdim = _getarg("keepdim", self->keepdim_offset); - if (py_keepdim.ptr()) { - keepdim = mpy::to_bool(py_keepdim); - } - } - - auto ndim = info.ndim(); - auto dims = _wrap_dims(A, dim, ndim, keepdim); - Slice dim_indices; - auto seen = A.allocate(info.levels.size()); - std::fill(seen, seen + info.levels.size(), false); - - for (auto d : dims) { - auto midx = info.levels.index(d); - if (!midx) { - auto tup = levels_to_tuple(info.levels); - mpy::raise_error(PyExc_ValueError, "Tensor with dimensions %R does not contain one of %R\n", tup.ptr(), dim.ptr()); - } - seen[*midx] = true; - dim_indices.append(A, *midx); - } - Slice new_levels; - if (self->reduce && !keepdim) { - for (auto i : info.levels.enumerate()) { - if (!seen[i]) { - new_levels.append(A, info.levels[i]); - } - } - } else { - new_levels = info.levels; - } - mpy::object py_indices; - if (dim_indices.size() == 1) { - py_indices = mpy::from_int(dim_indices[0]); - } else { - mpy::tuple tup(dim_indices.size()); - for (auto i : dim_indices.enumerate()) { - tup.set(i, mpy::from_int(dim_indices[i])); - } - py_indices = std::move(tup); - } - _patcharg(self->dim_name.c_str(), self->dim_offset, py_indices); - patched_args[0] = handle_from_tensor(A, info.tensor); +namespace { +PyObject* patched_dim_method( + PyObject* self_, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + Arena A; + auto self = WrappedOperator::unchecked_wrap(self_); + PY_BEGIN + + mpy::vector_args va(args, nargs, kwnames); + + auto _getarg = [&](const char* name, int64_t offset_) -> mpy::handle { + auto offset = offset_ + 1; // do not include self + auto idx = va.index(name, offset); + return idx == -1 ? mpy::handle() : va[idx]; + }; + Slice patched_args; + patched_args.extend(A, va.begin(), va.end()); + auto _patcharg = [&](const char* name, int64_t offset_, mpy::handle value) { + auto offset = offset_ + 1; // do not include self + auto idx = va.index(name, offset); + if (idx == -1) { + mpy::raise_error(PyExc_ValueError, "Missing argument %s", name); + } + patched_args[idx] = value; + }; + + auto dim = _getarg(self->dim_name.c_str(), self->dim_offset); + if (!dim.ptr()) { + auto info = TensorInfo::create(A, args[0], true); + EnableAllLayers l(A, info.levels); + l.inplace_update_layers(info.batchedtensor, info.levels); + patched_args[0] = handle_from_tensor(A, info.batchedtensor); auto r = self->orig.call_vector(patched_args.begin(), nargs, kwnames); - auto wrap = [&](mpy::handle h) { - if (THPVariable_Check(h.ptr())) { - return A.autorelease(Tensor::from_positional(A, THPVariable_Unpack(h.ptr()), new_levels, info.has_device)); - } - return h; - }; - return tree_map(A, wrap, r).release(); - PY_END(nullptr) -} - -PyObject* _wrap(PyObject * self_, - PyObject *const *args, - Py_ssize_t nargs, - PyObject *kwnames) { - Arena A; - PY_BEGIN - - #define ARGS(_) _(mpy::handle, orig) _(mpy::handle, dim_offset) _(mpy::handle, keepdim_offset) \ - _(mpy::handle, dim_name) _(mpy::handle, single_dim) _(mpy::handle, reduce) - MPY_PARSE_ARGS_KWNAMES("O|OOOOO", ARGS) - - std::string dim_name_str; - if (dim_name.ptr()) { - dim_name_str = PyUnicode_AsUTF8(dim_name.ptr()); - } else { - dim_name_str = "dim"; - } - auto info = WrappedOperator::create(mpy::object::borrow(orig), (PyCFunction)(void*) patched_dim_method, std::move(dim_name_str)); - if (dim_offset.ptr()) { - info->dim_offset = mpy::to_int(dim_offset); - } - if (keepdim_offset.ptr()) { - info->keepdim_offset = mpy::to_int(keepdim_offset); - } - - if (single_dim.ptr()) { - info->single_dim = mpy::to_bool(single_dim); - } - if (reduce.ptr()) { - info->reduce = mpy::to_bool(reduce); - } - return info->function().release(); - #undef ARGS - - PY_END(nullptr) -} - -PyObject* call_torch_function(PyObject *self, - PyObject *const *args, - Py_ssize_t nargs, - PyObject *kwnames) { - PY_BEGIN - Arena A; - maybeInitializeGlobals(); - auto info = WrappedOperator::unchecked_wrap(self); - return __torch_function__(A, info->orig, mpy::vector_args(args, nargs, kwnames), info->is_pointwise).release(); - PY_END(nullptr) -} - -PyObject* _wrap_method(PyObject *self, - PyObject *const *args, - Py_ssize_t nargs, - PyObject *kwnames) { - PY_BEGIN - AT_ASSERT(nargs == 2); - // XXX - ignore python function wrapped, we will call torch function directly - mpy::handle orig = args[0]; - if (!pointwise.ptr()) { - auto dim = mpy::import("functorch.dim"); - pointwise = dim.attr("pointwise"); - } - auto info = WrappedOperator::create(mpy::object::borrow(orig), (PyCFunction)(void*) call_torch_function); - info->is_pointwise = pointwise.contains(orig); - return PyInstanceMethod_New(info->function().release()); - PY_END(nullptr); -} - - -PyObject* Tensor_sum(PyObject * self_, - PyObject *const *args, - Py_ssize_t nargs, - PyObject *kwnames) { - Arena A; - PY_BEGIN - maybeInitializeGlobals(); - mpy::vector_args va(args, nargs, kwnames); - auto self_ = Tensor::unchecked_wrap(args[0]); - auto d = self_->delayed(); - if (!d) { - return _Tensor_sum.call_vector(va).release(); - } - mpy::handle self, dim, keepdim, dtype; - va.parse("sum", {"self", "dim", "keepdim", "dtype"}, {&self, &dim, &keepdim, &dtype}, 1, 1); - - if (dtype.ptr() || (keepdim.ptr() && mpy::to_bool(keepdim))) { - // std::cout << "SKIPPING fusion because dtype or keepdim=True specified\n"; - return _Tensor_sum.call_vector(va).release(); - } - auto levels = self_->levels(); - - auto N = ndim_of_levels(levels); - auto reduced_dims = _wrap_dims(A, dim, N, false); - - return dot(A, TensorInfo::create(A, d->args[0], false), TensorInfo::create(A, d->args[1], false), reduced_dims).release(); - PY_END(nullptr) -} - -PyObject* _parse_test(PyObject * self_, - PyObject *const *args, - Py_ssize_t nargs, - PyObject *kwnames) { - PY_BEGIN - maybeInitializeGlobals(); - - int required = mpy::to_int(args[0]); - int kwonly = mpy::to_int(args[1]); - - mpy::vector_args va(args + 2, nargs - 2, kwnames); - - - mpy::handle a, b, c, d; - va.parse("_parse_test", {"a", "b", "c", "d"}, {&a, &b, &c, &d}, required, kwonly); - mpy::tuple r(4); - r.set(0, mpy::object::borrow(a.ptr() ? a : Py_None)); - r.set(1, mpy::object::borrow(b.ptr() ? b : Py_None)); - r.set(2, mpy::object::borrow(c.ptr() ? c : Py_None)); - r.set(3, mpy::object::borrow(d.ptr() ? d : Py_None)); - return r.release(); - - PY_END(nullptr) -} - -PyObject* _set_pointwise_optimize(PyObject * self_, - PyObject *const *args, - Py_ssize_t nargs, - PyObject *kwnames) { - PY_BEGIN - mpy::handle value; - mpy::vector_args va(args, nargs, kwnames); - va.parse("_set_pointwise_optimization", {"value"}, {&value}, 1); - pointwise_optimize = mpy::to_bool(value); - Py_RETURN_NONE; - PY_END(nullptr) -} - -PyObject* _patch_tensor_class(PyObject * self_, - PyObject *const *args, - Py_ssize_t nargs, - PyObject *kwnames) { - PY_BEGIN - - auto torch = mpy::import("torch"); - auto py_TensorBase = torch.attr("_C").attr("TensorBase"); - replaceMappingIfMatches(py_TensorBase); - - Py_RETURN_NONE; - PY_END(nullptr) + return l.from_batched(A, THPVariable_Unpack(r.ptr()), info.has_device) + .release(); + } + + auto info = TensorInfo::create(A, args[0]); + auto keepdim = false; + if (self->reduce) { + auto py_keepdim = _getarg("keepdim", self->keepdim_offset); + if (py_keepdim.ptr()) { + keepdim = mpy::to_bool(py_keepdim); + } + } + + auto ndim = info.ndim(); + auto dims = _wrap_dims(A, dim, ndim, keepdim); + Slice dim_indices; + auto seen = A.allocate(info.levels.size()); + std::fill(seen, seen + info.levels.size(), false); + + for (auto d : dims) { + auto midx = info.levels.index(d); + if (!midx) { + auto tup = levels_to_tuple(info.levels); + mpy::raise_error( + PyExc_ValueError, + "Tensor with dimensions %R does not contain one of %R\n", + tup.ptr(), + dim.ptr()); + } + seen[*midx] = true; + dim_indices.append(A, *midx); + } + Slice new_levels; + if (self->reduce && !keepdim) { + for (auto i : info.levels.enumerate()) { + if (!seen[i]) { + new_levels.append(A, info.levels[i]); + } + } + } else { + new_levels = info.levels; + } + mpy::object py_indices; + if (dim_indices.size() == 1) { + py_indices = mpy::from_int(dim_indices[0]); + } else { + mpy::tuple tup(dim_indices.size()); + for (auto i : dim_indices.enumerate()) { + tup.set(i, mpy::from_int(dim_indices[i])); + } + py_indices = std::move(tup); + } + _patcharg(self->dim_name.c_str(), self->dim_offset, py_indices); + patched_args[0] = handle_from_tensor(A, info.tensor); + auto r = self->orig.call_vector(patched_args.begin(), nargs, kwnames); + auto wrap = [&](mpy::handle h) { + if (THPVariable_Check(h.ptr())) { + return A.autorelease(Tensor::from_positional( + A, THPVariable_Unpack(h.ptr()), new_levels, info.has_device)); + } + return h; + }; + return tree_map(A, wrap, r).release(); + PY_END(nullptr) +} + +PyObject* _wrap( + PyObject* self_, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + Arena A; + PY_BEGIN + +#define ARGS(_) \ + _(mpy::handle, orig) \ + _(mpy::handle, dim_offset) _(mpy::handle, keepdim_offset) \ + _(mpy::handle, dim_name) _(mpy::handle, single_dim) \ + _(mpy::handle, reduce) + MPY_PARSE_ARGS_KWNAMES("O|OOOOO", ARGS) + + std::string dim_name_str; + if (dim_name.ptr()) { + dim_name_str = PyUnicode_AsUTF8(dim_name.ptr()); + } else { + dim_name_str = "dim"; + } + auto info = WrappedOperator::create( + mpy::object::borrow(orig), + (PyCFunction)(void*)patched_dim_method, + std::move(dim_name_str)); + if (dim_offset.ptr()) { + info->dim_offset = mpy::to_int(dim_offset); + } + if (keepdim_offset.ptr()) { + info->keepdim_offset = mpy::to_int(keepdim_offset); + } + + if (single_dim.ptr()) { + info->single_dim = mpy::to_bool(single_dim); + } + if (reduce.ptr()) { + info->reduce = mpy::to_bool(reduce); + } + return info->function().release(); +#undef ARGS + + PY_END(nullptr) +} + +PyObject* call_torch_function( + PyObject* self, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + PY_BEGIN + Arena A; + maybeInitializeGlobals(); + auto info = WrappedOperator::unchecked_wrap(self); + return __torch_function__( + A, + info->orig, + mpy::vector_args(args, nargs, kwnames), + info->is_pointwise) + .release(); + PY_END(nullptr) +} + +PyObject* _wrap_method( + PyObject* self, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + PY_BEGIN + AT_ASSERT(nargs == 2); + // XXX - ignore python function wrapped, we will call torch function directly + mpy::handle orig = args[0]; + if (!pointwise.ptr()) { + auto dim = mpy::import("functorch.dim"); + pointwise = dim.attr("pointwise"); + } + auto info = WrappedOperator::create( + mpy::object::borrow(orig), (PyCFunction)(void*)call_torch_function); + info->is_pointwise = pointwise.contains(orig); + return PyInstanceMethod_New(info->function().release()); + PY_END(nullptr); +} + +PyObject* Tensor_sum( + PyObject* self_, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + Arena A; + PY_BEGIN + maybeInitializeGlobals(); + mpy::vector_args va(args, nargs, kwnames); + auto self_ = Tensor::unchecked_wrap(args[0]); + auto d = self_->delayed(); + if (!d) { + return _Tensor_sum.call_vector(va).release(); + } + mpy::handle self, dim, keepdim, dtype; + va.parse( + "sum", + {"self", "dim", "keepdim", "dtype"}, + {&self, &dim, &keepdim, &dtype}, + 1, + 1); + + if (dtype.ptr() || (keepdim.ptr() && mpy::to_bool(keepdim))) { + // std::cout << "SKIPPING fusion because dtype or keepdim=True specified\n"; + return _Tensor_sum.call_vector(va).release(); + } + auto levels = self_->levels(); + + auto N = ndim_of_levels(levels); + auto reduced_dims = _wrap_dims(A, dim, N, false); + + return dot(A, + TensorInfo::create(A, d->args[0], false), + TensorInfo::create(A, d->args[1], false), + reduced_dims) + .release(); + PY_END(nullptr) +} + +PyObject* _parse_test( + PyObject* self_, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + PY_BEGIN + maybeInitializeGlobals(); + + int required = mpy::to_int(args[0]); + int kwonly = mpy::to_int(args[1]); + + mpy::vector_args va(args + 2, nargs - 2, kwnames); + + mpy::handle a, b, c, d; + va.parse( + "_parse_test", {"a", "b", "c", "d"}, {&a, &b, &c, &d}, required, kwonly); + mpy::tuple r(4); + r.set(0, mpy::object::borrow(a.ptr() ? a : Py_None)); + r.set(1, mpy::object::borrow(b.ptr() ? b : Py_None)); + r.set(2, mpy::object::borrow(c.ptr() ? c : Py_None)); + r.set(3, mpy::object::borrow(d.ptr() ? d : Py_None)); + return r.release(); + + PY_END(nullptr) +} + +PyObject* _set_pointwise_optimize( + PyObject* self_, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + PY_BEGIN + mpy::handle value; + mpy::vector_args va(args, nargs, kwnames); + va.parse("_set_pointwise_optimization", {"value"}, {&value}, 1); + pointwise_optimize = mpy::to_bool(value); + Py_RETURN_NONE; + PY_END(nullptr) +} + +PyObject* _patch_tensor_class( + PyObject* self_, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + PY_BEGIN + + auto torch = mpy::import("torch"); + auto py_TensorBase = torch.attr("_C").attr("TensorBase"); + replaceMappingIfMatches(py_TensorBase); + + Py_RETURN_NONE; + PY_END(nullptr) } - const char* dims_doc = R"""( dims(n=None, sizes=None) -> torchdim.Dim or Tuple[torchdim.Dim, ...] @@ -3196,54 +3579,79 @@ Example:: )"""; PyMethodDef methods[] = { - {"dims", (PyCFunction)(void*) _dims, METH_FASTCALL | METH_KEYWORDS, dims_doc}, - {"dimlists", (PyCFunction)(void*) _dims, METH_FASTCALL | METH_KEYWORDS}, - {"_test_c", (PyCFunction)(void*) test_c, METH_FASTCALL | METH_KEYWORDS}, - {"_wrap_method", (PyCFunction)(void*) _wrap_method, METH_FASTCALL | METH_KEYWORDS}, - {"Tensor_from_positional", (PyCFunction)(void*) py_Tensor_from_positional, METH_FASTCALL | METH_KEYWORDS}, - {"__torch_function__", (PyCFunction)(void*) py___torch_function__, METH_FASTCALL | METH_KEYWORDS}, - {"tree_flatten", (PyCFunction)(void*) py_tree_flatten, METH_FASTCALL | METH_KEYWORDS}, - {"order", (PyCFunction)(void*) order, METH_FASTCALL | METH_KEYWORDS}, - {"index", (PyCFunction)(void*) py_index, METH_FASTCALL | METH_KEYWORDS}, - {"stack", (PyCFunction)(void*) py_stack, METH_FASTCALL | METH_KEYWORDS}, - {"split", (PyCFunction)(void*) py_split, METH_FASTCALL | METH_KEYWORDS}, - {"expand", (PyCFunction)(void*) expand, METH_FASTCALL | METH_KEYWORDS}, - {"__getitem__", (PyCFunction)(void*) py___getitem__, METH_FASTCALL | METH_KEYWORDS}, - {"__setitem__", (PyCFunction)(void*) py___setitem__, METH_FASTCALL | METH_KEYWORDS}, - {"_wrap", (PyCFunction)(void*) _wrap, METH_FASTCALL | METH_KEYWORDS}, - {"Tensor_sum", (PyCFunction)(void*) Tensor_sum, METH_FASTCALL | METH_KEYWORDS}, - {"_parse_test", (PyCFunction)(void*) _parse_test, METH_FASTCALL | METH_KEYWORDS}, - {"_set_pointwise_optimize", (PyCFunction)(void*) _set_pointwise_optimize, METH_FASTCALL | METH_KEYWORDS}, - {"_patch_tensor_class", (PyCFunction)(void*) _patch_tensor_class, METH_FASTCALL | METH_KEYWORDS}, - {NULL, NULL, 0, NULL} /* Sentinel */ + {"dims", + (PyCFunction)(void*)_dims, + METH_FASTCALL | METH_KEYWORDS, + dims_doc}, + {"dimlists", + (PyCFunction)(void*)_dims, + METH_FASTCALL | METH_KEYWORDS}, + {"_test_c", (PyCFunction)(void*)test_c, METH_FASTCALL | METH_KEYWORDS}, + {"_wrap_method", + (PyCFunction)(void*)_wrap_method, + METH_FASTCALL | METH_KEYWORDS}, + {"Tensor_from_positional", + (PyCFunction)(void*)py_Tensor_from_positional, + METH_FASTCALL | METH_KEYWORDS}, + {"__torch_function__", + (PyCFunction)(void*)py___torch_function__, + METH_FASTCALL | METH_KEYWORDS}, + {"tree_flatten", + (PyCFunction)(void*)py_tree_flatten, + METH_FASTCALL | METH_KEYWORDS}, + {"order", (PyCFunction)(void*)order, METH_FASTCALL | METH_KEYWORDS}, + {"index", (PyCFunction)(void*)py_index, METH_FASTCALL | METH_KEYWORDS}, + {"stack", (PyCFunction)(void*)py_stack, METH_FASTCALL | METH_KEYWORDS}, + {"split", (PyCFunction)(void*)py_split, METH_FASTCALL | METH_KEYWORDS}, + {"expand", (PyCFunction)(void*)expand, METH_FASTCALL | METH_KEYWORDS}, + {"__getitem__", + (PyCFunction)(void*)py___getitem__, + METH_FASTCALL | METH_KEYWORDS}, + {"__setitem__", + (PyCFunction)(void*)py___setitem__, + METH_FASTCALL | METH_KEYWORDS}, + {"_wrap", (PyCFunction)(void*)_wrap, METH_FASTCALL | METH_KEYWORDS}, + {"Tensor_sum", + (PyCFunction)(void*)Tensor_sum, + METH_FASTCALL | METH_KEYWORDS}, + {"_parse_test", + (PyCFunction)(void*)_parse_test, + METH_FASTCALL | METH_KEYWORDS}, + {"_set_pointwise_optimize", + (PyCFunction)(void*)_set_pointwise_optimize, + METH_FASTCALL | METH_KEYWORDS}, + {"_patch_tensor_class", + (PyCFunction)(void*)_patch_tensor_class, + METH_FASTCALL | METH_KEYWORDS}, + {NULL, NULL, 0, NULL} /* Sentinel */ }; struct PyModuleDef module_def = { PyModuleDef_HEAD_INIT, - "_C", /* name of module */ + "_C", /* name of module */ NULL, /* module documentation, may be NULL */ - -1, /* size of per-interpreter state of the module, - or -1 if the module keeps state in global variables. */ - methods -}; -} + -1, /* size of per-interpreter state of the module, + or -1 if the module keeps state in global variables. */ + methods}; +} // namespace PyObject* Dim_init() { - Arena A; - try { - mpy::object mod = mpy::object::checked_steal(PyModule_Create(&module_def)); - Dim::ready(mod, "Dim"); - DimList::ready(mod, "DimList"); - Tensor::ready(mod, "Tensor"); - WrappedOperator::ready(mod, "_WrappedOperator"); - Py_INCREF(&PyInstanceMethod_Type); - PyModule_AddObject(mod.ptr(), "_instancemethod", (PyObject *)&PyInstanceMethod_Type); - - initializeGlobals(A); - return mod.release(); - } catch(mpy::exception_set& err) { - return nullptr; - } + Arena A; + try { + mpy::object mod = mpy::object::checked_steal(PyModule_Create(&module_def)); + Dim::ready(mod, "Dim"); + DimList::ready(mod, "DimList"); + Tensor::ready(mod, "Tensor"); + WrappedOperator::ready(mod, "_WrappedOperator"); + Py_INCREF(&PyInstanceMethod_Type); + PyModule_AddObject( + mod.ptr(), "_instancemethod", (PyObject*)&PyInstanceMethod_Type); + + initializeGlobals(A); + return mod.release(); + } catch (mpy::exception_set& err) { + return nullptr; + } } #endif diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 7df18543ddc44..d4c98a0f6b151 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -582,7 +582,6 @@ "torch._C._dispatch_has_kernel", "torch._C._dispatch_is_alias_key", "torch._C._dispatch_is_included_in_alias", - "torch._C._dispatch_is_main_interpreter", "torch._C._dispatch_isTensorSubclassLike", "torch._C._dispatch_key_for_device", "torch._C._dispatch_key_name", diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 5a0f4a59abe30..8aabf24c4c1b3 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -407,10 +407,10 @@ static PyObject* THPModule_swap_tensor_impl(PyObject* _unused, PyObject* args) { // associated with the TensorImpl. Swap this field as well. std::optional mb_obj_a = a->cdata->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj( - getPyInterpreter(), /*ignore_hermetic_tls=*/false); + /*ignore_hermetic_tls=*/false); std::optional mb_obj_b = b->cdata->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj( - getPyInterpreter(), /*ignore_hermetic_tls=*/false); + /*ignore_hermetic_tls=*/false); TORCH_INTERNAL_ASSERT( mb_obj_a.has_value() && mb_obj_b.has_value(), "Both tensors should have PyObjects tagged by the current python interpreter"); @@ -420,10 +420,8 @@ static PyObject* THPModule_swap_tensor_impl(PyObject* _unused, PyObject* args) { a->cdata = b->cdata; b->cdata = tmp; - a->cdata->unsafeGetTensorImpl()->pyobj_slot()->init_pyobj( - getPyInterpreter(), a_, c10::impl::PyInterpreterStatus::TAGGED_BY_US); - b->cdata->unsafeGetTensorImpl()->pyobj_slot()->init_pyobj( - getPyInterpreter(), b_, c10::impl::PyInterpreterStatus::TAGGED_BY_US); + a->cdata->unsafeGetTensorImpl()->pyobj_slot()->init_pyobj(a_); + b->cdata->unsafeGetTensorImpl()->pyobj_slot()->init_pyobj(b_); Py_RETURN_NONE; END_HANDLE_TH_ERRORS diff --git a/torch/csrc/PyInterpreter.cpp b/torch/csrc/PyInterpreter.cpp index f944bb5c5461e..f289a286b19c7 100644 --- a/torch/csrc/PyInterpreter.cpp +++ b/torch/csrc/PyInterpreter.cpp @@ -586,7 +586,7 @@ static void set_tensor_attr_with_capsule( py::capsule& capsule, const char* attr_name) { std::optional mb_obj = tensor->pyobj_slot()->check_pyobj( - getPyInterpreter(), /*ignore_hermetic_tls=*/false); + /*ignore_hermetic_tls=*/false); TORCH_CHECK( mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value"); auto obj = mb_obj.value(); @@ -987,7 +987,3 @@ py::handle getTorchApiFunction(const c10::OperatorHandle& op) { c10::impl::PyInterpreter* getPyInterpreter() { return torch::detail::self_interpreter.get(); } - -bool isMainPyInterpreter() { - return torch::detail::self_interpreter.is_main_interpreter(); -} diff --git a/torch/csrc/PyInterpreter.h b/torch/csrc/PyInterpreter.h index 82ca11e2c5d0c..0ff9f79d02c27 100644 --- a/torch/csrc/PyInterpreter.h +++ b/torch/csrc/PyInterpreter.h @@ -10,4 +10,4 @@ TORCH_PYTHON_API py::handle getTorchApiFunction(const c10::OperatorHandle& op); // TODO: Move these to a proper namespace TORCH_PYTHON_API c10::impl::PyInterpreter* getPyInterpreter(); -TORCH_PYTHON_API bool isMainPyInterpreter(); +TORCH_PYTHON_API void initializeGlobalPyInterpreter(); diff --git a/torch/csrc/PyInterpreterHooks.cpp b/torch/csrc/PyInterpreterHooks.cpp new file mode 100644 index 0000000000000..fd1c997be0a08 --- /dev/null +++ b/torch/csrc/PyInterpreterHooks.cpp @@ -0,0 +1,20 @@ +#include +#include + +namespace torch::detail { + +PyInterpreterHooks::PyInterpreterHooks(c10::impl::PyInterpreterHooksArgs) {} + +c10::impl::PyInterpreter* PyInterpreterHooks::getPyInterpreter() const { + // Delegate to the existing implementation + return ::getPyInterpreter(); +} + +} // namespace torch::detail + +// Sigh, the registry doesn't support namespaces :( +using c10::impl::PyInterpreterHooksRegistry; +using c10::impl::RegistererPyInterpreterHooksRegistry; +using PyInterpreterHooks = torch::detail::PyInterpreterHooks; +// Register the implementation +REGISTER_PYTHON_HOOKS(PyInterpreterHooks); diff --git a/torch/csrc/PyInterpreterHooks.h b/torch/csrc/PyInterpreterHooks.h new file mode 100644 index 0000000000000..1def7b8c55ae6 --- /dev/null +++ b/torch/csrc/PyInterpreterHooks.h @@ -0,0 +1,15 @@ +#pragma once + +#include + +namespace torch::detail { + +// Concrete implementation of PyInterpreterHooks +class PyInterpreterHooks : public c10::impl::PyInterpreterHooksInterface { + public: + explicit PyInterpreterHooks(c10::impl::PyInterpreterHooksArgs); + + c10::impl::PyInterpreter* getPyInterpreter() const override; +}; + +} // namespace torch::detail diff --git a/torch/csrc/Storage.cpp b/torch/csrc/Storage.cpp index cc682a2644af2..08112b41aaaed 100644 --- a/torch/csrc/Storage.cpp +++ b/torch/csrc/Storage.cpp @@ -35,7 +35,6 @@ PyTypeObject* THPStorageClass = nullptr; PyObject* THPStorage_NewWithStorage( PyTypeObject* type, c10::Storage _storage, - c10::impl::PyInterpreterStatus status, bool allow_preexisting_pyobj) { TORCH_CHECK( PyType_IsSubtype(type, &THPStorageType), @@ -43,7 +42,7 @@ PyObject* THPStorage_NewWithStorage( "Storage is not possible. Make sure your class inherits from Storage."); auto maybe_pyobj = _storage.unsafeGetStorageImpl()->pyobj_slot()->check_pyobj( - getPyInterpreter(), /*ignore_hermetic_tls=*/false); + /*ignore_hermetic_tls=*/false); if (maybe_pyobj.has_value() && maybe_pyobj.value()) { TORCH_CHECK( allow_preexisting_pyobj, @@ -78,8 +77,7 @@ PyObject* THPStorage_NewWithStorage( if (!c10::impl::HermeticPyObjectTLS::get_state()) { s->is_hermetic = false; const auto& storage = THPStorage_Unpack(s); - storage.unsafeGetStorageImpl()->pyobj_slot()->init_pyobj( - getPyInterpreter(), obj, status); + storage.unsafeGetStorageImpl()->pyobj_slot()->init_pyobj(obj); } else { s->is_hermetic = true; } @@ -91,17 +89,12 @@ PyObject* THPStorage_NewWithStorage( PyObject* THPStorage_Wrap(c10::Storage storage) { c10::StorageImpl* storage_impl = storage.unsafeGetStorageImpl(); if (c10::impl::HermeticPyObjectTLS::get_state()) { - return THPStorage_NewWithStorage( - THPStorageClass, - std::move(storage), - c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED); + return THPStorage_NewWithStorage(THPStorageClass, std::move(storage)); } c10::impl::PyObjectSlot* pyobj_slot = storage_impl->pyobj_slot(); std::optional maybe_pyobj = pyobj_slot->check_pyobj( - getPyInterpreter(), /*ignore_hermetic_tls=*/false); - c10::impl::PyInterpreterStatus status = - c10::impl::PyInterpreterStatus::TAGGED_BY_US; + /*ignore_hermetic_tls=*/false); if (maybe_pyobj.has_value()) { auto obj = *maybe_pyobj; if (obj) { @@ -120,15 +113,8 @@ PyObject* THPStorage_Wrap(c10::Storage storage) { return obj; } } - status = c10::impl::PyInterpreterStatus::TAGGED_BY_US; - } else { - if (storage.use_count() <= 1) { - status = c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED; - } else { - status = c10::impl::PyInterpreterStatus::MAYBE_UNINITIALIZED; - } } - return THPStorage_NewWithStorage(THPStorageClass, std::move(storage), status); + return THPStorage_NewWithStorage(THPStorageClass, std::move(storage)); } static bool THPStorage_isPreservable(THPStorage* self) { @@ -142,8 +128,7 @@ static bool THPStorage_isPreservable(THPStorage* self) { } if (storage.unsafeGetStorageImpl()->pyobj_slot()->check_pyobj( - getPyInterpreter(), /*ignore_hermetic_tls=*/true) != - (PyObject*)self) { + /*ignore_hermetic_tls=*/true) != (PyObject*)self) { return false; } if (storage.use_count() <= 1) { @@ -161,11 +146,10 @@ static bool THPStorage_tryPreserve(THPStorage* self) { c10::StorageImpl* storage_impl = storage.unsafeGetStorageImpl(); auto maybe_pyobj = storage_impl->pyobj_slot()->check_pyobj( - getPyInterpreter(), /*ignore_hermetic_tls=*/true); // NOTE: It is possible to just set the PyObjectSlot here, but the point is - // that we should have already set PyObjectSlot when the storage PyObject was - // created. + // that we should have already set PyObjectSlot when the storage PyObject + // was created. TORCH_INTERNAL_ASSERT( maybe_pyobj.has_value(), "Trying to preserve a Python storage whose PyObjectSlot does not have a PyObject"); @@ -373,8 +357,7 @@ static PyObject* THPStorage_pynew( at::DataPtr(), allocator, /*resizable=*/true, - device_opt), - c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED); + device_opt)); // torch.Storage(size, *, ...) } else if (r.idx == 1) { @@ -387,8 +370,7 @@ static PyObject* THPStorage_pynew( at::DataPtr(), allocator, /*resizable=*/true, - device_opt), - c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED); + device_opt)); // torch.Storage(sequence, *, ...) } else if (r.idx == 2) { @@ -412,8 +394,7 @@ static PyObject* THPStorage_pynew( at::DataPtr(), allocator, /*resizable=*/true, - device_opt), - c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED); + device_opt)); THPObjectPtr item; try { const auto& storage = THPStorage_Unpack(self); @@ -509,10 +490,8 @@ static PyObject* THPStorage_get(THPStorage* self, PyObject* index) { /* resizable */ false, device_opt); - PyObject* _ret = THPStorage_NewWithStorage( - Py_TYPE(self), - std::move(new_storage_impl), - c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED); + PyObject* _ret = + THPStorage_NewWithStorage(Py_TYPE(self), std::move(new_storage_impl)); return _ret; } diff --git a/torch/csrc/Storage.h b/torch/csrc/Storage.h index ce86475d6a952..698cd80548efa 100644 --- a/torch/csrc/Storage.h +++ b/torch/csrc/Storage.h @@ -19,7 +19,6 @@ TORCH_PYTHON_API PyObject* THPStorage_Wrap(c10::Storage storage); TORCH_PYTHON_API PyObject* THPStorage_NewWithStorage( PyTypeObject* type, c10::Storage _storage, - c10::impl::PyInterpreterStatus status, bool allow_preexisting_pyobj = false); TORCH_PYTHON_API extern PyTypeObject* THPStorageClass; diff --git a/torch/csrc/StorageMethods.cpp b/torch/csrc/StorageMethods.cpp index 8e5a99e4da7f7..da64bcfbd5008 100644 --- a/torch/csrc/StorageMethods.cpp +++ b/torch/csrc/StorageMethods.cpp @@ -390,10 +390,7 @@ static PyObject* THPStorage_fromFile( storage->set_nbytes(actual_nbytes); } - return THPStorage_NewWithStorage( - THPStorageClass, - std::move(storage), - c10::impl::PyInterpreterStatus::TAGGED_BY_US); + return THPStorage_NewWithStorage(THPStorageClass, std::move(storage)); END_HANDLE_TH_ERRORS } diff --git a/torch/csrc/StorageSharing.cpp b/torch/csrc/StorageSharing.cpp index 9f7d667613dc5..e58865bb60a8a 100644 --- a/torch/csrc/StorageSharing.cpp +++ b/torch/csrc/StorageSharing.cpp @@ -86,8 +86,7 @@ static PyObject* THPStorage_pyNewFilenameStorage( THManagedMapAllocator::makeDataPtr( "", handle.c_str(), flags, static_cast(size)), /*allocator=*/nullptr, - /*resizable=*/false), - c10::impl::PyInterpreterStatus::TAGGED_BY_US); + /*resizable=*/false)); END_HANDLE_TH_ERRORS } @@ -182,8 +181,7 @@ static PyObject* THPStorage_newSharedFilename( THManagedMapAllocator::makeDataPtr( manager_handle, object_handle, flags, size), /*allocator=*/nullptr, - /*resizable=*/false), - c10::impl::PyInterpreterStatus::TAGGED_BY_US); + /*resizable=*/false)); END_HANDLE_TH_ERRORS } @@ -197,9 +195,7 @@ static PyObject* THPStorage_pyNewFdStorage(PyObject* _unused, PyObject* args) { return nullptr; } return THPStorage_NewWithStorage( - THPStorageClass, - at::new_shm_fd_storage(size), - c10::impl::PyInterpreterStatus::TAGGED_BY_US); + THPStorageClass, at::new_shm_fd_storage(size)); END_HANDLE_TH_ERRORS } @@ -278,8 +274,7 @@ static PyObject* THPStorage_newSharedFd(PyObject* _unused, PyObject* args) { at::MapAllocator::makeDataPtr( at::WITH_FD, "", fd, flags, size, nullptr), /*allocator=*/nullptr, - /*resizable=*/false), - c10::impl::PyInterpreterStatus::TAGGED_BY_US); + /*resizable=*/false)); END_HANDLE_TH_ERRORS } @@ -560,10 +555,7 @@ static PyObject* THPStorage_newSharedCuda(PyObject* _unused, PyObject* args) { base->set_resizable(false); base->set_received_cuda(true); - return THPStorage_NewWithStorage( - THPStorageClass, - std::move(base), - c10::impl::PyInterpreterStatus::TAGGED_BY_US); + return THPStorage_NewWithStorage(THPStorageClass, std::move(base)); #else TORCH_CHECK(false, "CUDA is not available"); #endif diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp index b0235da869fbc..c184dd63d2949 100644 --- a/torch/csrc/autograd/python_variable.cpp +++ b/torch/csrc/autograd/python_variable.cpp @@ -209,7 +209,6 @@ PyObject* ParameterClass = nullptr; static PyObject* THPVariable_NewWithVar( PyTypeObject* type, const at::TensorBase& _var, - c10::impl::PyInterpreterStatus status, bool allow_preexisting_pyobj = false); // clang-tidy gets confused by static const @@ -261,16 +260,12 @@ PyObject* THPVariable_Wrap(const at::TensorBase& var) { } if (c10::impl::HermeticPyObjectTLS::get_state()) { - return THPVariable_NewWithVar( - (PyTypeObject*)THPVariableClass, - var, - c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED); + return THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, var); } std::optional mb_obj = var.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj( - getPyInterpreter(), /*ignore_hermetic_tls=*/false); - c10::impl::PyInterpreterStatus status{}; + /*ignore_hermetic_tls=*/false); if (mb_obj.has_value()) { auto obj = *mb_obj; if (obj) { @@ -295,27 +290,17 @@ PyObject* THPVariable_Wrap(const at::TensorBase& var) { // (https://github.com/pytorch/pytorch/pull/56017). Prior to this PR // being a thing, the PyObject field will get cleared when all references // to the Python object are removed. - status = c10::impl::PyInterpreterStatus::TAGGED_BY_US; - } else { - // Assumption: if a Tensor has been shared across threads, this induces - // a refcount bump. Therefore, if the use count 1, we are the sole thread - // with access to this tensor and no race is possible. - if (var.use_count() <= 1) { - status = c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED; - } else { - status = c10::impl::PyInterpreterStatus::MAYBE_UNINITIALIZED; - } } if (C10_LIKELY(var.device().type() != c10::kXLA)) { - return THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, var, status); + return THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, var); } if (auto clazz = getPythonTensorClass(var.device())) { - return THPVariable_NewWithVar((PyTypeObject*)clazz, var, status); + return THPVariable_NewWithVar((PyTypeObject*)clazz, var); } - return THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, var, status); + return THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, var); } static bool isResurrectable(THPVariable* self) { @@ -344,8 +329,7 @@ static bool isResurrectable(THPVariable* self) { } // Check if this is hermetic. If it is, no resurrection. if (tensor.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj( - getPyInterpreter(), /*ignore_hermetic_tls=*/false) != - (PyObject*)self) { + /*ignore_hermetic_tls=*/false) != (PyObject*)self) { return false; } return true; @@ -371,7 +355,6 @@ static bool THPVariable_tryResurrect(THPVariable* self) { c10::TensorImpl* tensor_impl = tensor.unsafeGetTensorImpl(); auto maybe_pyobj = tensor_impl->pyobj_slot()->check_pyobj( - getPyInterpreter(), /*ignore_hermetic_tls=*/false); TORCH_INTERNAL_ASSERT( @@ -587,10 +570,7 @@ static PyObject* THPVariable_as_subclass( // stack torch_dispatch_mode::StashTorchDispatchStackGuard td_g; c10::impl::DisablePythonDispatcher dpd_g; - return THPVariable_NewWithVar( - (PyTypeObject*)cls, - self.alias(), - c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED); + return THPVariable_NewWithVar((PyTypeObject*)cls, self.alias()); END_HANDLE_TH_ERRORS } @@ -642,10 +622,7 @@ static PyObject* THPVariable_make_subclass( data.unsafeGetTensorImpl()->_change_backend_component_keys(r.device(6)); } - return THPVariable_NewWithVar( - (PyTypeObject*)cls, - data, - c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED); + return THPVariable_NewWithVar((PyTypeObject*)cls, data); END_HANDLE_TH_ERRORS } @@ -790,10 +767,7 @@ static PyObject* THPVariable_make_wrapper_subclass( tensor.unsafeGetTensorImpl()->set_python_custom_layout(true); } - return THPVariable_NewWithVar( - (PyTypeObject*)cls, - tensor, - c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED); + return THPVariable_NewWithVar((PyTypeObject*)cls, tensor); END_HANDLE_TH_ERRORS } @@ -1821,7 +1795,6 @@ PyObject* THPVariable_pynew( return THPVariable_NewWithVar( type, tensor, - c10::impl::PyInterpreterStatus::MAYBE_UNINITIALIZED, /*allow_preexisting_pyobj=*/true); END_HANDLE_TH_ERRORS } @@ -1874,8 +1847,7 @@ static int THPVariable_subclass_clear(THPVariable* self) { if (!self->cdata.unsafeIsBorrowed() && tensor.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj( - getPyInterpreter(), /*ignore_hermetic_tls=*/false) == - (PyObject*)self) { + /*ignore_hermetic_tls=*/false) == (PyObject*)self) { // TODO: empirically, on OS X this assert appears to be untrue // In test_py_tensors_multi_async_call - ProcessGroupRpcTestWithSpawn // distributed/rpc/test_process_group_agent.py @@ -2047,17 +2019,10 @@ static void THPVariable_subclass_dealloc(PyObject* self) { Py_DECREF(type); } -// Creates a new Python object for a Variable. The status parameter -// specifies what the interpreter tag status on the object is; for -// example, if you ran check_pyobj, the return optional of this object -// tells you if the tensor was already tagged or not so you can pass -// TAGGED_BY_US or MAYBE_UNINITIALIZED; in other cases, you know where -// var came from and can directly assert that it's DEFINITELY_UNINITIALIZED. -// It's ALWAYS safe (albeit slower) to call this with MAYBE_UNINITIALIZED. +// Creates a new Python object for a Variable. static PyObject* THPVariable_NewWithVar( PyTypeObject* type, const at::TensorBase& _var, - c10::impl::PyInterpreterStatus status, bool allow_preexisting_pyobj) { // Make sure that the reinterpret into a THPVariable* will be valid TORCH_CHECK( @@ -2068,7 +2033,7 @@ static PyObject* THPVariable_NewWithVar( // This function overwrite the Tensor's pyobj field without extra checks // Make sure it is not set otherwise we would leak memory auto mb_obj = _var.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj( - getPyInterpreter(), /*ignore_hermetic_tls=*/false); + /*ignore_hermetic_tls=*/false); // Under some circumstances, we may attempt to create a new Python // object for a variable that already has a Python object. The most common @@ -2150,8 +2115,7 @@ static PyObject* THPVariable_NewWithVar( // Normal codepath v->cdata = MaybeOwned::owned(Variable(_var)); const auto& var = THPVariable_Unpack(v); - var.unsafeGetTensorImpl()->pyobj_slot()->init_pyobj( - getPyInterpreter(), obj, status); + var.unsafeGetTensorImpl()->pyobj_slot()->init_pyobj(obj); if (check_has_torch_dispatch(obj)) { var.unsafeGetTensorImpl()->set_python_dispatch(true); } diff --git a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp index b2b0e848a7e79..019ce2070634d 100644 --- a/torch/csrc/utils/python_dispatch.cpp +++ b/torch/csrc/utils/python_dispatch.cpp @@ -209,12 +209,10 @@ class PythonKernelHolder : public c10::OperatorKernel { } }; +// @todo sahanp: Afait only register is used in the codebase. This can be +// removed / simplified static torch::_RegisterOrVerify register_or_verify() { - if (isMainPyInterpreter()) { - return torch::_RegisterOrVerify::REGISTER; - } else { - return torch::_RegisterOrVerify::VERIFY; - } + return torch::_RegisterOrVerify::REGISTER; } static py::object ophandle_call_boxed( @@ -287,7 +285,6 @@ void initDispatchBindings(PyObject* module) { .def( "reset", [](const py::object& self) { - TORCH_INTERNAL_ASSERT(isMainPyInterpreter()); self.cast().reset(); return; }, @@ -297,7 +294,6 @@ void initDispatchBindings(PyObject* module) { .def( "def_", [](py::object self, const char* schema, const char* alias) { - TORCH_INTERNAL_ASSERT(isMainPyInterpreter()); self.cast().def( torch::schema(schema, parseAliasAnalysisKind(alias))); return self; @@ -311,7 +307,6 @@ void initDispatchBindings(PyObject* module) { .def( "def_legacy", [](py::object self, const char* schema) { - TORCH_INTERNAL_ASSERT(isMainPyInterpreter()); self.cast().def(torch::jit::parseSchema(schema)); return self; }, @@ -331,7 +326,6 @@ void initDispatchBindings(PyObject* module) { const char* name, const char* dispatch, const char* debug) { - TORCH_INTERNAL_ASSERT(isMainPyInterpreter()); self.cast().def( name, dispatch_str(dispatch, [](const at::Tensor& a) { return a; @@ -349,7 +343,6 @@ void initDispatchBindings(PyObject* module) { const char* dispatch, const char* alias, const char* debug) { - TORCH_INTERNAL_ASSERT(isMainPyInterpreter()); self.cast().def( torch::schema(schema, parseAliasAnalysisKind(alias)), dispatch_str(dispatch, [](const at::Tensor& a) { @@ -370,7 +363,6 @@ void initDispatchBindings(PyObject* module) { const char* name, const char* dispatch, const char* debug) { - TORCH_INTERNAL_ASSERT(isMainPyInterpreter()); self.cast().impl( name, dispatch_str(dispatch, [](const at::Tensor& a) { return a; @@ -465,7 +457,6 @@ void initDispatchBindings(PyObject* module) { .def( "fallback_fallthrough", [](py::object self, const char* dispatch) { - TORCH_INTERNAL_ASSERT(isMainPyInterpreter()); self.cast().fallback( dispatch_str(dispatch, CppFunction::makeFallthrough())); return self; @@ -480,7 +471,6 @@ void initDispatchBindings(PyObject* module) { bool with_keyset) { HANDLE_TH_ERRORS auto& lib = self.cast(); - TORCH_INTERNAL_ASSERT(isMainPyInterpreter()); if (func.is(py::module::import("torch.library") .attr("fallthrough_kernel"))) { lib.fallback( @@ -913,8 +903,6 @@ void initDispatchBindings(PyObject* module) { handle.setReportErrorCallback_(std::move(callback_obj)); }); - m.def( - "_dispatch_is_main_interpreter", []() { return isMainPyInterpreter(); }); m.def("_dispatch_pystub", [](const char* name, const char* overload) { return c10::Dispatcher::singleton().getPyStub( c10::OperatorName(name, overload)); From a00cd8cf252a0b061f2eef6b5b42ae967acf5f64 Mon Sep 17 00:00:00 2001 From: drisspg Date: Thu, 17 Jul 2025 18:02:09 -0700 Subject: [PATCH 233/457] Add a way to disable compile for debugging flex-attention (#158534) Finally got around to doing this, this flag lets us do: ```Python #!/usr/bin/env python3 """ FlexAttention Debug: Using breakpoints and unwrap """ import torch import torch.nn.attention.flex_attention as fa unwrap = torch._C._functorch.get_unwrapped def score_mod(score, batch, head, q_idx, kv_idx): # Set breakpoint here to debug breakpoint() # In debugger, unwrap to see actual tensor values: # >>> actual_score = unwrap(unwrap(unwrap(unwrap(score)))) # >>> actual_batch = unwrap(batch) # >>> actual_head = unwrap(head) # >>> actual_q_idx = unwrap(q_idx) # >>> actual_kv_idx = unwrap(kv_idx) # >>> print(actual_score) # >>> print(f"q_idx: {actual_q_idx}, kv_idx: {actual_kv_idx}") return torch.where(q_idx >= kv_idx, score, torch.tensor(float('-inf'))) def main(): # Enable debug mode fa._FLEX_ATTENTION_DISABLE_COMPILE_DEBUG = True # Small example B, H, S, D = 1, 2, 4, 8 q = torch.randn(B, H, S, D) k = torch.randn(B, H, S, D) v = torch.randn(B, H, S, D) # Run - will hit breakpoint output = fa.flex_attention(q, k, v, score_mod=score_mod) # Disable debug mode fa._FLEX_ATTENTION_DISABLE_COMPILE_DEBUG = False if __name__ == "__main__": main() ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/158534 Approved by: https://github.com/Chillee, https://github.com/zou3519 --- test/inductor/test_flex_attention.py | 47 +++++++++++++++++++++ torch/nn/attention/flex_attention.py | 62 ++++++++++++++++++++++++---- 2 files changed, 102 insertions(+), 7 deletions(-) diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index fa6400dd9c272..e14afcea81b05 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -5,6 +5,7 @@ import random import string import unittest +import warnings from collections import namedtuple from contextlib import contextmanager from dataclasses import dataclass @@ -4235,6 +4236,52 @@ def test_large_batch_heads_grid_dimension(self, device): self.assertEqual(key.grad.shape, key.shape) self.assertEqual(value.grad.shape, value.shape) + @supported_platform + def test_debug_flag_disables_internal_compilation(self, device): + """Test that _FLEX_ATTENTION_DISABLE_COMPILE_DEBUG flag bypasses internal compilation.""" + import torch.nn.attention.flex_attention as fa + + original_flag = fa._FLEX_ATTENTION_DISABLE_COMPILE_DEBUG + original_warnings_shown = fa._WARNINGS_SHOWN.copy() + + try: + B, H, S, D = 1, 1, 128, 64 + query = torch.randn(B, H, S, D, device=device, dtype=torch.float32) + key = torch.randn(B, H, S, D, device=device, dtype=torch.float32) + value = torch.randn(B, H, S, D, device=device, dtype=torch.float32) + + def simple_score_mod(score, b, h, q_idx, kv_idx): + return score + + # Test with debug flag False - should warn + fa._FLEX_ATTENTION_DISABLE_COMPILE_DEBUG = False + fa._WARNINGS_SHOWN.clear() + + with self.assertWarns(UserWarning) as cm: + out_compiled = fa.flex_attention( + query, key, value, score_mod=simple_score_mod + ) + + self.assertIn( + "flex_attention called without torch.compile", str(cm.warning) + ) + + # Test with debug flag True - should NOT warn + fa._FLEX_ATTENTION_DISABLE_COMPILE_DEBUG = True + + # Should not error + with warnings.catch_warnings(): + warnings.simplefilter("error") + out_debug = fa.flex_attention( + query, key, value, score_mod=simple_score_mod + ) + + torch.testing.assert_close(out_compiled, out_debug, rtol=1e-4, atol=1e-4) + + finally: + fa._FLEX_ATTENTION_DISABLE_COMPILE_DEBUG = original_flag + fa._WARNINGS_SHOWN = original_warnings_shown + class TestBlockMask(InductorTestCase): def setUp(self): diff --git a/torch/nn/attention/flex_attention.py b/torch/nn/attention/flex_attention.py index 160dba68f0ccf..7dc66696d1108 100644 --- a/torch/nn/attention/flex_attention.py +++ b/torch/nn/attention/flex_attention.py @@ -36,6 +36,36 @@ from torch.utils._pytree import tree_map_only +# Private debug flag to disable internal compilation wrapping for debugging purposes. +# WARNING: This is intended ONLY for debugging score_mod and mask_mod functions. +# When enabled, this bypasses the required internal compilation that ensures correctness +# and performance. Only use this temporarily when you need to set breakpoints +# in your score_mod/mask_mod functions during development. +# +# This flag only affects the internal compilation when flex_attention is called directly. +# If you have already wrapped flex_attention in torch.compile(), this flag has no effect +# and the user's compilation will still occur. +# +# Usage: +# import torch.nn.attention.flex_attention as fa +# fa._FLEX_ATTENTION_DISABLE_COMPILE_DEBUG = True +# # Now you can set breakpoints in your score_mod/mask_mod +# output = fa.flex_attention(q, k, v, score_mod=my_score_mod) +# +_FLEX_ATTENTION_DISABLE_COMPILE_DEBUG = False + +_WARNINGS_SHOWN: set[str] = set() + + +def _warn_once( + warning_id: str, message: str, category: type[Warning] = UserWarning +) -> None: + """Helper to ensure each warning is shown only once per process.""" + if warning_id not in _WARNINGS_SHOWN: + warnings.warn(message, category, stacklevel=2) + _WARNINGS_SHOWN.add(warning_id) + + __all__ = [ "BlockMask", "flex_attention", @@ -1548,6 +1578,18 @@ def score_mod( else: return out + if not _FLEX_ATTENTION_DISABLE_COMPILE_DEBUG: + _warn_once( + warning_id="flex_attention_performance", + message=( + "flex_attention called without torch.compile() - this will use an unfused implementation that materializes the full scores matrix instead of generating a fused kernel.\n\n" + "SOLUTION: Use torch.compile(flex_attention)(...)\n\n" + "If you want to debug your score_mod/mask_mod, you can set:\n" + "torch.nn.attention.flex_attention._FLEX_ATTENTION_DISABLE_COMPILE_DEBUG = True\n\n" + "This will allow you to use print statements or breakpoints. Note: This doesn't work with the backwards pass and may produce incorrect results." + ), + ) + if not torch._dynamo.is_dynamo_supported(): raise RuntimeError("flex_attention requires dynamo support") @@ -1570,9 +1612,15 @@ def _flex_attention_hop_wrapper(*args, **kwargs): ) else: backend = "eager" - out, lse = torch.compile( - _flex_attention_hop_wrapper, backend=backend, fullgraph=True - )( + + if _FLEX_ATTENTION_DISABLE_COMPILE_DEBUG: + flex_fn = _flex_attention_hop_wrapper + else: + flex_fn = torch.compile( + _flex_attention_hop_wrapper, backend=backend, fullgraph=True + ) + + out, lse = flex_fn( query, key, value, @@ -1581,7 +1629,7 @@ def _flex_attention_hop_wrapper(*args, **kwargs): scale, kernel_options, ) - if return_lse: - return out, lse * math.log(2) - else: - return out + if return_lse: + return out, lse * math.log(2) + else: + return out From fda3f3b2ec6c6dc11100cc8ddff07059692d697e Mon Sep 17 00:00:00 2001 From: Yidi Wu Date: Tue, 15 Jul 2025 14:06:17 -0700 Subject: [PATCH 234/457] [while_loop] fix constant tensor used as carried inputs (#158381) Address second part of #158366, where torch.tensor(0), is treated as a constant tensor and its .item() gets specailized to 0 which causes a silent specialization. The fix is to unspecialize the constant carries and make them non-constant. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158381 Approved by: https://github.com/zou3519 --- test/export/test_export.py | 35 ++++++++++++++++++ torch/_dynamo/variables/higher_order_ops.py | 38 +++++++++---------- torch/_higher_order_ops/while_loop.py | 41 +++++++++++++++++---- 3 files changed, 87 insertions(+), 27 deletions(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index dea000556960d..0f436d3af91f0 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -36,6 +36,7 @@ from torch._higher_order_ops.associative_scan import associative_scan from torch._higher_order_ops.hints_wrap import hints_wrapper from torch._higher_order_ops.scan import scan +from torch._higher_order_ops.while_loop import while_loop from torch._inductor.compile_fx import split_const_gm from torch._subclasses import FakeTensorMode from torch.export import ( @@ -1813,6 +1814,40 @@ def forward(self, x): ): export(M(), (torch.randn(2, 3),), strict=False) + @testing.expectedFailureTrainingIRToRunDecomp # Could not guard on data-dependent expression -u0 > 16 (unhinted: -u0 > 16) + @testing.expectedFailureTrainingIRToRunDecompNonStrict # Could not guard on data-dependent expression -u0 > 16 (unhinted: -u0 > 16) + @testing.expectedFailureRetraceability # Could not guard on data-dependent expression -u0 > 16 (unhinted: -u0 > 16) + @testing.expectedFailureRetraceabilityNonStrict # Could not guard on data-dependent expression -u0 > 16 (unhinted: -u0 > 16) + @torch._dynamo.config.patch(capture_scalar_outputs=True) + def test_while_loop_tensor_constant_idx(self): + def while_loop_decomp(x, y0): + out = torch.zeros_like(x) + + def cond_fn(idx, out, y0): + return idx < out.size(0) + + def body_fn(idx, out, y0): + i = idx.item() + torch._check_is_size(i, max=x.size(0) - 1) + y0 = x[i] + y0 + out = out.clone() + out[i] = y0 + return idx + 1, out, y0 + + cnt = torch.tensor(0) + _, out, _ = while_loop(cond_fn, body_fn, [cnt, out, y0]) + return out + + class TestModel(torch.nn.Module): + def forward(self, x, y0): + return while_loop_decomp(x, y0) + + x, y0 = torch.randn(16, 8), torch.randn(8) + exp_out = TestModel()(x, y0) + ep = export(TestModel(), (x, y0)) + out = ep.module()(x, y0) + self.assertEqual(exp_out, out) + def test_malformed_fqn_from_source_name(self): # See https://github.com/pytorch/pytorch/issues/141939 from types import MethodType diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py index 7064d63945ebb..fbef41574a4f7 100644 --- a/torch/_dynamo/variables/higher_order_ops.py +++ b/torch/_dynamo/variables/higher_order_ops.py @@ -1221,28 +1221,28 @@ def call_function( additional_inputs_seq = additional_inputs.unpack_var_sequence(tx) with discard_graph_changes(tx): - # See NOTE [unspecialize int carry with unbacked symints] # Note: this must be run under discard graph changes. - def create_unbacked_sym_node_var(tx) -> SymNodeVariable: - example_value = _create_unbacked_symint( - tx.output.fake_mode, ignore_fresh_unbacked_symbols=True - ) - proxy = tx.output.current_tracer.create_graph_input( - "unbacked_symint", type(example_value), example_value - ) - return SymNodeVariable.create(tx, proxy, example_value) + def unspecialize_carried_inputs(tx, carry) -> VariableTracker: + # See NOTE [unspecialize int carry with unbacked symints] + if ( + isinstance(carry, ConstantVariable) and carry.python_type() is int + ) or isinstance(carry, SymNodeVariable): + example_value = _create_unbacked_symint( + tx.output.fake_mode, ignore_fresh_unbacked_symbols=True + ) + proxy = tx.output.current_tracer.create_graph_input( + "unbacked_symint", type(example_value), example_value + ) + return SymNodeVariable.create(tx, proxy, example_value) + else: + # See NOTE [unspecialize constant tensor carry] + assert isinstance(carry, TensorVariable) + cloned_carry = carry.clone() + cloned_carry.proxy.node.meta["example_value"].constant = None + return cloned_carry new_operands_seq = [ - ( - create_unbacked_sym_node_var(tx) - if ( - isinstance(carry, ConstantVariable) - and carry.python_type() is int - ) - or (isinstance(carry, SymNodeVariable)) - else carry - ) - for carry in operands_seq + unspecialize_carried_inputs(tx, carry) for carry in operands_seq ] # create cond subgrpahs diff --git a/torch/_higher_order_ops/while_loop.py b/torch/_higher_order_ops/while_loop.py index d94ccf16d2168..68a8747ab4b82 100644 --- a/torch/_higher_order_ops/while_loop.py +++ b/torch/_higher_order_ops/while_loop.py @@ -17,6 +17,7 @@ from torch._subclasses.fake_tensor import FakeTensorMode from torch.fx.experimental.proxy_tensor import ( _temp_remove_metadata_torch_function_mode, + disable_proxy_modes_tracing, ProxyTorchDispatchMode, track_tensor_tree, ) @@ -285,14 +286,38 @@ def _trace_while_loop( # iteration. Ideally, we should know that the final output is >= 0 but we didn't constrain the # unbacked symint output of subgraph as of today because this requires a smart range analysis. fake_mode: FakeTensorMode = _find_or_create_fake_mode() - unspecialized_carried_inputs = pytree.tree_map_only( - (int, torch.SymInt), - # For temporarily created unbacked symints, we don't need to bind them to any proxy - lambda _: _create_unbacked_symint( - fake_mode, ignore_fresh_unbacked_symbols=True - ), - carried_inputs, - ) + + def _unspecialize_carried_inputs(x): + if isinstance(x, (int, torch.SymInt)): + return _create_unbacked_symint( + fake_mode, ignore_fresh_unbacked_symbols=True + ) + # Note: [unspecialize constant tensor carry] + # We need to disable constant specialization for tensor inputs that become loop carries. + # Here's the problem: when a user creates a constant tensor e.g. torch.tensor(0), PyTorch calls aten.lift_fresh_copy + # to create a safe copy (avoiding aliasing issues), which creates a FakeTensor with constant=True. + # But when this FakeTensor becomes a loop carry, we have a problem: + # - Operations like .item() will read the constant value and bake it into the traced code + # - This is incorrect because carry variables change between loop iterations + # - The traced code would use the wrong constant value for all iterations + # Solution: We clone the constant tensors and mark the cloned tensor as non-constant so they won't + # be specialized to fixed values during tracing body_fn or cond_fn. + elif ( + isinstance(x, torch.Tensor) + and hasattr(x, "constant") + and x.constant is not None + ): + x = x.clone() + x.constant = None + return x + + with disable_proxy_modes_tracing(): + unspecialized_carried_inputs = pytree.tree_map_only( + (int, torch.SymInt, torch.Tensor), + # For temporarily created unbacked symints, we don't need to bind them to any proxy + lambda x: _unspecialize_carried_inputs(x), + carried_inputs, + ) cond_graph = reenter_make_fx(cond_fn)( *unspecialized_carried_inputs, *additional_inputs From a3396a9b855cc759d6669729dff3e0dc379bc79f Mon Sep 17 00:00:00 2001 From: Yidi Wu Date: Wed, 16 Jul 2025 22:50:59 -0700 Subject: [PATCH 235/457] [hop] set capture_scalar_outputs=True by default for compiled hops (#158480) We want to do it for two reasons: 1. It's tedious for users to manually turn on capture_scalar_outputs=True when compiling map and scan with inductor, where we decomposing them into while_loop and use the idx tensor.item() to select a slice of output buffer and write into it. This pr turns on the flag by default. 2. a graph break caused by capture_scalar_outputs=False would cause the hop to fail, and we should turn it on by default so that the error message is more meaningful. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158480 Approved by: https://github.com/zou3519 --- test/functorch/test_control_flow.py | 3 --- torch/_higher_order_ops/utils.py | 3 +++ 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py index fbcb8fc2e19b9..9c65aeedd8d97 100644 --- a/test/functorch/test_control_flow.py +++ b/test/functorch/test_control_flow.py @@ -1409,7 +1409,6 @@ def f(x, y): f, (torch.ones(3, 4, 5), torch.ones(4, 4, 5)), torch.ones(5) ) - @torch._dynamo.config.patch(capture_scalar_outputs=True) def test_map_illegal_outputs(self): def f(x, y): return x.item() @@ -8034,7 +8033,6 @@ def forward(self, arg0_1: "Sym(u15)", arg1_1: "Sym(u16)", arg2_1: "Sym(u17)", ar @skipIfTorchDynamo("Graph is not captured correctly when test with dynamo") @parametrize("dynamic", [True, False]) @parametrize("backend", ["eager", "aot_eager"]) - @torch._dynamo.config.patch(capture_scalar_outputs=True) def test_while_loop_op_pytree_int_carry_compile(self, dynamic, backend): m, args = WHILE_LOOP_TESTS["pytree_int_carry"] if backend == "eager": @@ -8196,7 +8194,6 @@ def _check_export_ret_graph_str(self, fn, args, dynamic_shapes=None) -> str: return normalize_gm(non_strict_ep.module().print_readable(print_output=False)) @skipIfTorchDynamo("Skip because dynamo cannot trace torch.export.") - @torch._dynamo.config.patch(capture_scalar_outputs=True) def test_cond_eager_run_with_item(self): class M(torch.nn.Module): def forward(self, a, b1, b2, c): diff --git a/torch/_higher_order_ops/utils.py b/torch/_higher_order_ops/utils.py index 2000571f60574..bf3dc83f8608f 100644 --- a/torch/_higher_order_ops/utils.py +++ b/torch/_higher_order_ops/utils.py @@ -236,6 +236,7 @@ def diff_device( def _set_compilation_env(): _old_is_tracing = torch.fx._symbolic_trace._is_fx_tracing_flag _old_allow_empty_graphs = torch._dynamo.config.allow_empty_graphs + _old_capture_scalar_outputs = torch._dynamo.config.capture_scalar_outputs # The issue is tracked in https://github.com/pytorch/pytorch/issues/144360: when dynamo finds # the top-level frame produces no graph, the default behavior is to fallback to eager. # Then when it encounters an inner function, it will try to trace that function again, which is unnecessary. @@ -249,10 +250,12 @@ def _set_compilation_env(): # once we are confident fx tracing works with dynamo. torch.fx._symbolic_trace._is_fx_tracing_flag = False torch._dynamo.config.allow_empty_graphs = True + torch._dynamo.config.capture_scalar_outputs = True yield finally: torch.fx._symbolic_trace._is_fx_tracing_flag = _old_is_tracing torch._dynamo.config.allow_empty_graphs = _old_allow_empty_graphs + torch._dynamo.config.capture_scalar_outputs = _old_capture_scalar_outputs # The invariant here is that we always trace the branch with fake tensor From be896d6b41f560e59c87f9d28df109b1553139a4 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 18 Jul 2025 07:44:14 +0000 Subject: [PATCH 236/457] Revert "Forward-fix unused variables warning/error (#158549)" This reverts commit eeda1a75ace75ce8a6763050fb91d236a6d3287b. Reverted https://github.com/pytorch/pytorch/pull/158549 on behalf of https://github.com/jithunnair-amd due to Sorry, need to revert this first, so we can revert PR 158037, which broke ROCm CI ([comment](https://github.com/pytorch/pytorch/pull/158549#issuecomment-3087942475)) --- aten/src/ATen/cuda/CUDABlas.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp index cf403365b2df2..acb1d5ed8b0da 100644 --- a/aten/src/ATen/cuda/CUDABlas.cpp +++ b/aten/src/ATen/cuda/CUDABlas.cpp @@ -1993,8 +1993,8 @@ void scaled_gemm( // The SCALE_MODE attrs only exist in cuBLAS 12.8+ or in recent hipblaslt, // but we must invoke get_scale_mode anyways to trigger the version checks. - [[maybe_unused]] int a_scale_mode = get_scale_mode(mat1_scaling_type, mat1_scale_dtype, use_fast_accum); - [[maybe_unused]] int b_scale_mode = get_scale_mode(mat2_scaling_type, mat2_scale_dtype, use_fast_accum); + int a_scale_mode = get_scale_mode(mat1_scaling_type, mat1_scale_dtype, use_fast_accum); + int b_scale_mode = get_scale_mode(mat2_scaling_type, mat2_scale_dtype, use_fast_accum); #if CUDA_VERSION >= 12080 || (defined(USE_ROCM) && defined(HIPBLASLT_OUTER_VEC)) computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_MODE, a_scale_mode); computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, b_scale_mode); From 32aade9d8d39d58c33215f50afe5382458d70821 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 18 Jul 2025 07:47:46 +0000 Subject: [PATCH 237/457] Revert "Support DeepSeek-style blockwise scaling scaled-mm for fp8 on Hopper+ (#158037)" This reverts commit 39ac189808c61588f3594dbc2fc1d69bb6194c47. Reverted https://github.com/pytorch/pytorch/pull/158037 on behalf of https://github.com/jithunnair-amd due to Ignored ROCm failures while ROCm was unstable, but HUD clearly shows this PR introduced failures on trunk ([comment](https://github.com/pytorch/pytorch/pull/158037#issuecomment-3087982975)) --- aten/src/ATen/ceil_div.h | 17 +- aten/src/ATen/cuda/CUDABlas.cpp | 116 +++------- aten/src/ATen/cuda/CUDABlas.h | 14 +- aten/src/ATen/cuda/tunable/GemmCommon.h | 8 +- aten/src/ATen/cuda/tunable/GemmHipblaslt.h | 63 ++---- aten/src/ATen/cuda/tunable/TunableGemm.h | 5 +- aten/src/ATen/native/cuda/Blas.cpp | 243 ++++++++++----------- test/test_matmul_cuda.py | 101 ++------- 8 files changed, 204 insertions(+), 363 deletions(-) diff --git a/aten/src/ATen/ceil_div.h b/aten/src/ATen/ceil_div.h index 777fc09a7049d..37d67b232a22c 100644 --- a/aten/src/ATen/ceil_div.h +++ b/aten/src/ATen/ceil_div.h @@ -7,15 +7,8 @@ namespace at { /** Computes ceil(a / b) */ -template < - typename Res = void, - typename T, - typename U, - typename = std::enable_if_t< - std::conjunction_v, std::is_integral>>> -C10_ALWAYS_INLINE C10_HOST_DEVICE - std::conditional_t, std::common_type_t, Res> - ceil_div(T a, U b) { +template >> +C10_ALWAYS_INLINE C10_HOST_DEVICE T ceil_div(T a, T b) { return (a + b - 1) / b; } @@ -23,10 +16,8 @@ C10_ALWAYS_INLINE C10_HOST_DEVICE Computes ceil(a / b) * b; i.e., rounds up `a` to the next highest multiple of b */ -template -C10_ALWAYS_INLINE C10_HOST_DEVICE - std::conditional_t, std::common_type_t, Res> - round_up(T a, U b) { +template +C10_ALWAYS_INLINE C10_HOST_DEVICE T round_up(T a, T b) { return ceil_div(a, b) * b; } diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp index acb1d5ed8b0da..d009520d05ab8 100644 --- a/aten/src/ATen/cuda/CUDABlas.cpp +++ b/aten/src/ATen/cuda/CUDABlas.cpp @@ -1843,69 +1843,6 @@ template bool gemm_and_bias( int64_t result_ld, GEMMAndBiasActivationEpilogue activation); -int get_scale_mode(ScalingType scaling_type, ScalarType scale_dtype, bool use_fast_accum) { - switch (scaling_type) { - case ScalingType::BlockWise1x32: - TORCH_CHECK(scale_dtype == kFloat8_e8m0fnu); -#if CUDA_VERSION >= 12080 - return CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0; -#else - TORCH_CHECK(false, "scaled_gemm with `torch.float8_e8m0fnu` scales of 1x32 blocks is only supported for CUDA 12.8 and above"); -#endif // if CUDA_VERSION >= 12080 - - case ScalingType::BlockWise1x16: - TORCH_CHECK(scale_dtype == kFloat8_e4m3fn); -#if CUDA_VERSION >= 12080 - return CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3; -#else - TORCH_CHECK(false, "scaled_gemm with `torch.float8_e4m3fn` scales of 1x16 blocks is only supported for CUDA 12.8 and above"); -#endif // if CUDA_VERSION >= 12080 - - case ScalingType::RowWise: - TORCH_CHECK(scale_dtype == kFloat); -#if CUDA_VERSION >= 12090 || (defined(USE_ROCM) && defined(HIPBLASLT_OUTER_VEC)) - return CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F; -#elif defined(USE_ROCM) && defined(HIPBLASLT_VEC_EXT) - // Return the default, since in old hipblaslt this is activated via - // the SCALE_POINTER_VEC_EXT attributed. - return 0; -#else - TORCH_CHECK(false, "scaled_gemm with rowwise scaling is only supported for CUDA 12.9 and above"); -#endif // if CUDA_VERSION >= 12090 - - case ScalingType::BlockWise1x128: - TORCH_CHECK(scale_dtype == kFloat); - TORCH_CHECK(!use_fast_accum, "scaled_gemm doesn't support fast accum with 1x128 blockwise scaling") -#if CUDA_VERSION >= 12090 - return CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F; -#else - TORCH_CHECK(false, "scaled_gemm with 1x128 blockwise scaling is only supported for CUDA 12.9 and above"); -#endif // if CUDA_VERSION >= 12090 - - case ScalingType::BlockWise128x128: - TORCH_CHECK(scale_dtype == kFloat); - TORCH_CHECK(!use_fast_accum, "scaled_gemm doesn't support fast accum with 128x128 blockwise scaling") -#if CUDA_VERSION >= 12090 - return CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F; -#else - TORCH_CHECK(false, "scaled_gemm with 128x128 blockwise scaling is only supported for CUDA 12.9 and above"); -#endif // if CUDA_VERSION >= 12090 - -case ScalingType::TensorWise: - TORCH_CHECK(scale_dtype == kFloat); -#if CUDA_VERSION >= 12080 - return CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F; -#else - // The macro isn't defined, thus we inline its value. - return 0; -#endif // if CUDA_VERSION >= 12080 - - default: - TORCH_CHECK(false); - return -1; - } -} - void scaled_gemm( char transa, char transb, @@ -1917,20 +1854,19 @@ void scaled_gemm( int64_t mat1_ld, ScalarType mat1_dtype, ScalarType mat1_scale_dtype, - ScalingType mat1_scaling_type, const void* mat2_ptr, const void* mat2_scale_ptr, int64_t mat2_ld, ScalarType mat2_dtype, ScalarType mat2_scale_dtype, - ScalingType mat2_scaling_type, const void* bias_ptr, ScalarType bias_dtype, void* result_ptr, const void *result_scale_ptr, int64_t result_ld, ScalarType result_dtype, - bool use_fast_accum) { + bool use_fast_accum, + bool use_rowwise) { // Note: see `cublasCommonArgs` for various non-intuitive manupulations // of input arguments to this function. #if CUDA_VERSION >= 11080 || defined(USE_ROCM) @@ -1943,15 +1879,19 @@ void scaled_gemm( computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, _cublasOpFromChar(transb)); cublasLtMatmulDescAttributes_t matmulDescA = CUBLASLT_MATMUL_DESC_A_SCALE_POINTER; cublasLtMatmulDescAttributes_t matmulDescB = CUBLASLT_MATMUL_DESC_B_SCALE_POINTER; - // hipblaslt supported row-wise before cublas, and did so their own way (via - // the SCALE_POINTERSs), but then migrated to match how cublas does it (via - // the SCALE_MODEs). Here we check for this early custom mode. -#if defined(USE_ROCM) && !defined(HIPBLASLT_OUTER_VEC) && defined(HIPBLASLT_VEC_EXT) - if (mat1_scaling_type == ScalingType::RowWise && mat2_scaling_type == ScalingType::RowWise) { +#if defined(USE_ROCM) +#if defined(HIPBLASLT_OUTER_VEC) + // this case is handled later as hipified CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F +#elif defined(HIPBLASLT_VEC_EXT) + if (use_rowwise) { matmulDescA = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT; matmulDescB = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT; } -#endif // if defined(USE_ROCM) && !defined(HIPBLASLT_OUTER_VEC) && defined(HIPBLASLT_VEC_EXT) +#else + // rowwise isn't supported using older hipblaslt + TORCH_INTERNAL_ASSERT(use_rowwise == false, "rowwise scaled_gemm not supported with older hipblaslt"); +#endif +#endif // defined(USE_ROCM) computeDesc.setAttribute(matmulDescA, mat1_scale_ptr); computeDesc.setAttribute(matmulDescB, mat2_scale_ptr); if (result_scale_ptr != nullptr) { @@ -1991,14 +1931,30 @@ void scaled_gemm( computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, ScalarTypeToCudaDataType(bias_dtype)); } - // The SCALE_MODE attrs only exist in cuBLAS 12.8+ or in recent hipblaslt, - // but we must invoke get_scale_mode anyways to trigger the version checks. - int a_scale_mode = get_scale_mode(mat1_scaling_type, mat1_scale_dtype, use_fast_accum); - int b_scale_mode = get_scale_mode(mat2_scaling_type, mat2_scale_dtype, use_fast_accum); -#if CUDA_VERSION >= 12080 || (defined(USE_ROCM) && defined(HIPBLASLT_OUTER_VEC)) - computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_MODE, a_scale_mode); - computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, b_scale_mode); -#endif + if (mat1_scale_dtype == kFloat8_e8m0fnu && mat2_scale_dtype == kFloat8_e8m0fnu) { +#if CUDA_VERSION >= 12080 + computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0); + computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0); +#else + TORCH_CHECK(false, "scaled_gemm with `torch.float8_e8m0fnu` scales is only supported for CUDA 12.8 and above"); +#endif // if CUDA_VERSION >= 12080 + } else if (mat1_scale_dtype == kFloat8_e4m3fn && mat2_scale_dtype == kFloat8_e4m3fn) { +#if CUDA_VERSION >= 12080 + computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3); + computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3); +#else + TORCH_CHECK(false, "scaled_gemm with `torch.float8_e4m3fn` scales is only supported for CUDA 12.8 and above"); +#endif // if CUDA_VERSION >= 12080 + } else if (mat1_scale_dtype == kFloat && mat2_scale_dtype == kFloat && use_rowwise) { +#if CUDA_VERSION >= 12090 || (defined(USE_ROCM) && defined(HIPBLASLT_OUTER_VEC)) + computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F); + computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F); +#elif defined(USE_ROCM) && defined(HIPBLASLT_VEC_EXT) + // no-op here for older hipblaslt ext enums, to avoid TORCH_CHECK below +#else + TORCH_CHECK(false, "scaled_gemm with `torch.float` outer vector scaling is only supported for CUDA 12.9 and above"); +#endif // if CUDA_VERSION >= 12090 + } CuBlasLtMatmulPreference preference; auto ltworkspace = CublasLtWorkspace(); diff --git a/aten/src/ATen/cuda/CUDABlas.h b/aten/src/ATen/cuda/CUDABlas.h index 5021917fe0950..b1dac2162dc42 100644 --- a/aten/src/ATen/cuda/CUDABlas.h +++ b/aten/src/ATen/cuda/CUDABlas.h @@ -136,15 +136,6 @@ void int8_gemm( int32_t* result_ptr, int64_t result_ld); -enum class ScalingType : std::uint8_t { - TensorWise, // fp32 scales - RowWise, // fp32 scales - BlockWise1x16, // fp8_e4m3fn scales - BlockWise1x32, // fp8_e8m0fnu scales - BlockWise1x128, // fp32 scales - BlockWise128x128, // fp32 scales -}; - void scaled_gemm( char transa, char transb, @@ -156,20 +147,19 @@ void scaled_gemm( int64_t mat1_ld, ScalarType mat1_dtype, ScalarType mat1_scale_dtype, - ScalingType mat1_scaling_type, const void* mat2_ptr, const void* mat2_scale_ptr, int64_t mat2_ld, ScalarType mat2_dtype, ScalarType mat2_scale_dtype, - ScalingType mat2_scaling_type, const void* bias_ptr, ScalarType bias_dtype, void* result_ptr, const void* result_scale_ptr, int64_t result_ld, ScalarType result_dtype, - bool use_fast_accum); + bool use_fast_accum, + bool use_rowwise); #define CUDABLAS_BGEMM_ARGTYPES(Dtype) CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(Dtype, Dtype) diff --git a/aten/src/ATen/cuda/tunable/GemmCommon.h b/aten/src/ATen/cuda/tunable/GemmCommon.h index 6d19907aba4ad..6f896f1a22bfc 100644 --- a/aten/src/ATen/cuda/tunable/GemmCommon.h +++ b/aten/src/ATen/cuda/tunable/GemmCommon.h @@ -29,8 +29,6 @@ namespace at::cuda::tunable { -using at::cuda::blas::ScalingType; - enum class BlasOp { N = 0, T = 1 @@ -600,8 +598,7 @@ struct ScaledGemmParams : OpParams { // // In TunableOp, we must distinguish in param signature these two cases: with and without a bias vector. return fmt::sprintf("%c%c_%ld_%ld_%ld_ld_%ld_%ld_%ld_rw_%d_bias_%s", - transa, transb, m, n, k, lda, ldb, ldc, - a_scaling_type == ScalingType::RowWise && b_scaling_type == ScalingType::RowWise, + transa, transb, m, n, k, lda, ldb, ldc, use_rowwise, bias_ptr == nullptr ? "None" : at::toString(bias_dtype)); } @@ -676,13 +673,11 @@ struct ScaledGemmParams : OpParams { int64_t lda{}; ScalarType a_dtype{}; ScalarType a_scale_dtype{}; - ScalingType a_scaling_type{}; const void* b{}; const void* b_scale_ptr{}; int64_t ldb{}; ScalarType b_dtype{}; ScalarType b_scale_dtype{}; - ScalingType b_scaling_type{}; const void* bias_ptr{}; ScalarType bias_dtype{}; void* c{}; @@ -691,6 +686,7 @@ struct ScaledGemmParams : OpParams { ScalarType c_dtype{}; void* amax_ptr{}; bool use_fast_accum{}; + bool use_rowwise{}; private: bool duplicate_inputs_{false}; }; diff --git a/aten/src/ATen/cuda/tunable/GemmHipblaslt.h b/aten/src/ATen/cuda/tunable/GemmHipblaslt.h index 809ba51009f0a..32fb7c2774fff 100644 --- a/aten/src/ATen/cuda/tunable/GemmHipblaslt.h +++ b/aten/src/ATen/cuda/tunable/GemmHipblaslt.h @@ -206,43 +206,23 @@ float GetBetaFromParams(const ScaledGemmParams* params) { } template -ScalingType GetAScalingTypeFromParams(const GemmParams* params) { - return ScalingType::TensorWise; +bool GetUseRowwiseFromParams(const GemmParams* params) { + return false; } template -ScalingType GetBScalingTypeFromParams(const GemmParams* params) { - return ScalingType::TensorWise; +bool GetUseRowwiseFromParams(const GemmAndBiasParams* params) { + return false; } template -ScalingType GetAScalingTypeFromParams(const GemmAndBiasParams* params) { - return ScalingType::TensorWise; +bool GetUseRowwiseFromParams(const GemmStridedBatchedParams* params) { + return false; } template -ScalingType GetBScalingTypeFromParams(const GemmAndBiasParams* params) { - return ScalingType::TensorWise; -} - -template -ScalingType GetAScalingTypeFromParams(const GemmStridedBatchedParams* params) { - return ScalingType::TensorWise; -} - -template -ScalingType GetBScalingTypeFromParams(const GemmStridedBatchedParams* params) { - return ScalingType::TensorWise; -} - -template -ScalingType GetAScalingTypeFromParams(const ScaledGemmParams* params) { - return params->a_scaling_type; -} - -template -ScalingType GetBScalingTypeFromParams(const ScaledGemmParams* params) { - return params->b_scaling_type; +bool GetUseRowwiseFromParams(const ScaledGemmParams* params) { + return params->use_rowwise; } template @@ -509,24 +489,23 @@ class HipblasltGemmOp : public Callable { const void* mat2_scale_ptr = GetBScalePointerFromParams(params); const void* result_scale_ptr = GetDScalePointerFromParams(params); if (mat1_scale_ptr && mat2_scale_ptr) { - hipblasLtMatmulDescAttributes_t a_scale_ptr_desc = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER; - hipblasLtMatmulDescAttributes_t b_scale_ptr_desc = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER; - if (GetAScalingTypeFromParams(params) == ScalingType::RowWise) { -#if defined(HIPBLASLT_OUTER_VEC) - matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_MODE, HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F); -#elif defined(HIPBLASLT_VEC_EXT) - a_scale_ptr_desc = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT; +#ifdef HIPBLASLT_VEC_EXT + if (GetUseRowwiseFromParams(params)) { + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT, mat1_scale_ptr); + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT, mat2_scale_ptr); + } + else #endif + { + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER, mat1_scale_ptr); + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER, mat2_scale_ptr); } - if (GetBScalingTypeFromParams(params) == ScalingType::RowWise) { -#if defined(HIPBLASLT_OUTER_VEC) +#ifdef HIPBLASLT_OUTER_VEC + if (GetUseRowwiseFromParams(params)) { + matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_MODE, HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F); matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_MODE, HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F); -#elif defined(HIPBLASLT_VEC_EXT) - b_scale_ptr_desc = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT; -#endif } - matmul.setAttribute(a_scale_ptr_desc, mat1_scale_ptr); - matmul.setAttribute(b_scale_ptr_desc, mat2_scale_ptr); +#endif } if (result_scale_ptr) { matmul.setAttribute(HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr); diff --git a/aten/src/ATen/cuda/tunable/TunableGemm.h b/aten/src/ATen/cuda/tunable/TunableGemm.h index d941c230630c4..d7e2835b1b109 100644 --- a/aten/src/ATen/cuda/tunable/TunableGemm.h +++ b/aten/src/ATen/cuda/tunable/TunableGemm.h @@ -96,20 +96,19 @@ class DefaultScaledGemmOp : public Callable> { params->lda, params->a_dtype, params->a_scale_dtype, - params->a_scaling_type, params->b, params->b_scale_ptr, params->ldb, params->b_dtype, params->b_scale_dtype, - params->b_scaling_type, params->bias_ptr, params->bias_dtype, params->c, params->c_scale_ptr, params->ldc, params->c_dtype, - params->use_fast_accum); + params->use_fast_accum, + params->use_rowwise); return OK; } }; diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp index 60becebfb81e5..c46e1cc633119 100644 --- a/aten/src/ATen/native/cuda/Blas.cpp +++ b/aten/src/ATen/native/cuda/Blas.cpp @@ -19,7 +19,6 @@ #include #include #include -#include #ifndef AT_PER_OPERATOR_HEADERS #include @@ -100,7 +99,6 @@ c10::MaybeOwned inline prepare_matrix_for_cublas(const Tensor& tensor, b } } -using at::cuda::blas::ScalingType; /** * @brief Prepares matrices for CUBLAS operation @@ -142,9 +140,7 @@ struct cublasCommonArgs { Tensor& c, const std::optional& scale_a = std::nullopt, const std::optional& scale_b = std::nullopt, - const std::optional& scale_result = std::nullopt, - const std::optional& scaling_choice_a = std::nullopt, - const std::optional& scaling_choice_b = std::nullopt) { + const std::optional& scale_result = std::nullopt) { bool transpose_result = false, transpose_a = false, transpose_b = false; result = prepare_matrix_for_cublas(c, transpose_result); mata = prepare_matrix_for_cublas(transpose_result ? mat2 : mat1, transpose_a, transpose_result); @@ -156,10 +152,8 @@ struct cublasCommonArgs { // as B.T @ A.T, check transpose_result to determine if we flip the scales scale_mata_ptr = transpose_result ? scale_b->data_ptr() : scale_a->data_ptr(); scale_mata_dtype = transpose_result ? scale_b->scalar_type() : scale_a->scalar_type(); - scaling_mata_type = transpose_result ? scaling_choice_b : scaling_choice_a; scale_matb_ptr = transpose_result ? scale_a->data_ptr() : scale_b->data_ptr(); scale_matb_dtype = transpose_result ? scale_a->scalar_type() : scale_b->scalar_type(); - scaling_matb_type = transpose_result ? scaling_choice_a : scaling_choice_b; } if (scale_result) { @@ -205,9 +199,7 @@ struct cublasCommonArgs { void* scale_matb_ptr = nullptr; void* scale_result_ptr = nullptr; std::optional scale_mata_dtype; - std::optional scaling_mata_type; std::optional scale_matb_dtype; - std::optional scaling_matb_type; std::optional scale_result_dtype; }; } // namespace @@ -1083,114 +1075,133 @@ static bool _scaled_mm_is_fnuz() { namespace{ +enum class ScalingType : std::uint8_t { + TensorWise, + RowWise, + BlockWise, + Error +}; /* * Scaling Type Determination: * --------------------------- * Conditions and corresponding Scaling Types: * - * - If scale tensor is `Float8_e8m0fnu` or `Float8_e4m3fn`: + * - If scale tensors are both `Float8_e8m0fnu` or `Float8_e4m3fn`: * - Returns BlockWise (with additional size checks). * - * - Else if scale.numel() == 1: + * - If scale_a.numel() == 1 && scale_b.numel() == 1: * - Returns TensorWise. * - * - Else if scale.dim() == 2 && scale.size(0) == outer_dim && scale.size(1) == 1: + * - Else if scale_a.dim() == 2 && scale_a.size(0) == dim_m && scale_b.size(0) == dim_n: * - Returns RowWise. * - * - Else if scale.dim() == 2 && scale.size(0) == outer_dim && scale.size(1) == inner_dim / 128: - * - Returns BlockWise 1x128. - * - * - Else if scale.dim() == 2 && scale.size(0) == outer_dim / 128 && scale.size(1) == inner_dim / 128: - * - Returns BlockWise 128x128. - * * - Otherwise: * - Returns Error. */ -using at::cuda::blas::ScalingType; - -bool is_tensorwise_scaling(const at::Tensor& t, const at::Tensor& scale) { - return isFloat8Type(t.scalar_type()) && scale.scalar_type() == kFloat && scale.numel() == 1; -} - -bool is_rowwise_scaling(const at::Tensor& t, const at::Tensor& scale) { - return (isFloat8Type(t.scalar_type()) && scale.scalar_type() == kFloat && scale.dim() == 2 - && scale.size(0) == t.size(0) && scale.size(1) == 1 - && scale.is_contiguous()); -} - -// 1x16 blocks for packed nvfp4 data and fp8_e4m3fn scales -bool is_blockwise_1x16_scaling(const at::Tensor& t, const at::Tensor& scale) { - // Multiply t.size(1) by 2 to adjust for fp4x2 packing - // TODO: We might want to enforce some structure on the shapes of the scale - // tensors - return (t.scalar_type() == ScalarType::Float4_e2m1fn_x2 && scale.scalar_type() == at::kFloat8_e4m3fn - && scale.numel() == round_up(t.size(0), 128) * round_up(ceil_div(t.size(1) * 2, 16), 4) - && scale.is_contiguous()); -} - -// 1x32 blocks for microscaled fp8 data and fp8_e8m0fnu scales -bool is_blockwise_1x32_scaling(const at::Tensor& t, const at::Tensor& scale) { - // TODO: We might want to enforce some structure on the shapes of the scale - // tensors - return (isFloat8Type(t.scalar_type()) && scale.scalar_type() == at::kFloat8_e8m0fnu - && scale.numel() == round_up(t.size(0), 128) * round_up(ceil_div(t.size(1), 32), 4) - && scale.is_contiguous()); -} - -bool is_blockwise_1x128_scaling(const at::Tensor& t, const at::Tensor& scale) { - return (isFloat8Type(t.scalar_type()) && scale.scalar_type() == kFloat && scale.dim() == 2 - && scale.size(0) == t.size(0) && scale.size(1) == ceil_div(t.size(1), 128) - && scale.stride(0) == 1 && scale.stride(1) == t.size(0)); -} - -bool is_blockwise_128x128_scaling(const at::Tensor& t, const at::Tensor& scale) { - return (isFloat8Type(t.scalar_type()) && scale.scalar_type() == kFloat && scale.dim() == 2 - && scale.size(0) == ceil_div(t.size(0), 128) && scale.size(1) == ceil_div(t.size(1), 128) - && scale.stride(0) == round_up(ceil_div(t.size(1), 128), 4) && scale.stride(1) == 1); -} +// Validates the scale tensors to scaled_mm +// And returns the type of scaling/which kernel to use +ScalingType get_scaling_type( + const at::Tensor& scale_a, + const at::Tensor& scale_b, + int64_t dim_m, + int64_t dim_k, + int64_t dim_n) { + // Check for BlockWise scaling (FP8_E8M0 and FP8_E4M3 types) + if ((scale_a.scalar_type() == scale_b.scalar_type()) && + ((scale_a.scalar_type() == at::kFloat8_e8m0fnu) || (scale_a.scalar_type() == at::kFloat8_e4m3fn))) { + const bool is_nvfp4 = scale_a.scalar_type() == at::kFloat8_e4m3fn; + + // cuBLAS's mxfp8 gemm: block_size is 1 scale per 32 elements + // cuBLAS's nvfp4 gemm: block_size is 1 scale per 16 unpacked elements. + const auto BLOCK_SIZE_K = is_nvfp4 ? 16 : 32; + + constexpr int64_t BLOCK_SIZE_MN = 128; + + // adjust for fp4x2 packing if necessary + const auto dim_k_unpacked = is_nvfp4 ? dim_k * 2 : dim_k; + + auto ceil_div = [](auto a, auto b) { return (a + b - 1) / b; }; + auto num_k_blocks = ceil_div(dim_k_unpacked, BLOCK_SIZE_K); + auto padded_num_k_blocks = ceil_div(num_k_blocks, 4) * 4; + + // TODO: We might want to enforce some structure on the shapes of the scale + // tensors + + // Check expected sizes for block-wise scaling + auto expected_a_size = + BLOCK_SIZE_MN * ceil_div(dim_m, BLOCK_SIZE_MN) * padded_num_k_blocks; + auto expected_b_size = + BLOCK_SIZE_MN * ceil_div(dim_n, BLOCK_SIZE_MN) * padded_num_k_blocks; + + TORCH_CHECK(scale_a.numel() == expected_a_size, + "For BlockWise scaling: Expected scale_a size to be ", + expected_a_size, " but got ", scale_a.numel()); + TORCH_CHECK(scale_b.numel() == expected_b_size, + "For BlockWise scaling: Expected scale_b size to be ", + expected_b_size, " but got ", scale_b.numel()); + + TORCH_CHECK( + scale_a.is_contiguous() && scale_b.is_contiguous(), + "For BlockWise scaling: Both scale_a and scale_b must be contiguous"); + + return ScalingType::BlockWise; + } + // Both Per-Tensor and Row-wise scaling expect fp32 tensors + TORCH_CHECK( + scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat, + "Both scale_a and scale_b must be float (fp32) tensors."); -bool is_desired_scaling(const at::Tensor& t, const at::Tensor& scale, ScalingType desired_scaling) { - switch (desired_scaling) { - case ScalingType::TensorWise: - return is_tensorwise_scaling(t, scale); - case ScalingType::RowWise: - return is_rowwise_scaling(t, scale); - case ScalingType::BlockWise1x16: - return is_blockwise_1x16_scaling(t, scale); - case ScalingType::BlockWise1x32: - return is_blockwise_1x32_scaling(t, scale); - case ScalingType::BlockWise1x128: - return is_blockwise_1x128_scaling(t, scale); - case ScalingType::BlockWise128x128: - return is_blockwise_128x128_scaling(t, scale); - default: - TORCH_CHECK(false); - return false; + // Check the singluar scale case for per-tensor scaling + if (scale_a.numel() == 1 && scale_b.numel() == 1) { + return ScalingType::TensorWise; } -} -std::pair get_joint_scaling( - std::initializer_list> options, - const at::Tensor& a, const at::Tensor& b, - const at::Tensor& scale_a, const at::Tensor& scale_b) { - for (auto [lhs, rhs] : options) { - if (is_desired_scaling(a, scale_a, lhs) && is_desired_scaling(b.t(), scale_b.t(), rhs)) { - return {lhs, rhs}; - } + // For non-TensorWise scaling, enforce 2D input tensors + TORCH_CHECK( + scale_a.dim() == 2 && scale_b.dim() == 2, + "For non-TensorWise scaling, scale tensors must be 2-dimensional, " + "but got scale_a.dim()=", + scale_a.dim(), + " and scale_b.dim()=", + scale_b.dim()); + + // Check for RowWise scaling + if (scale_a.size(0) == dim_m && scale_a.size(1) == 1 && + scale_b.size(0) == 1 && scale_b.size(1) == dim_n) { +#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || \ + (defined(USE_ROCM) && (defined(HIPBLASLT_VEC_EXT) || defined(HIPBLASLT_OUTER_VEC))) + TORCH_CHECK( + scale_a.is_contiguous() && scale_b.is_contiguous(), + "Both scale_a and scale_b must be contiguous for RowWise scaling."); + return ScalingType::RowWise; +#else + TORCH_CHECK(false, "Per-row scaling is not supported for this platform!"); + return ScalingType::Error; +#endif } + + // If we reach here, the input doesn't match any valid scaling type TORCH_CHECK( - false, - "Invalid scaling configuration.\n" - "- For TensorWise scaling, a and b should be float8, scales should be float and singletons.\n" - "- For RowWise scaling, a and b should be float8, scales should be float, scale_a should be (", a.size(0), ", 1) and scale_b should be (1, ", b.size(1), "), and both should be contiguous.\n" - "- For BlockWise 1x128 scaling, a and b should be float8, scales should be float, scale_a should be (", a.size(0), ", ", ceil_div(a.size(1), 128), ") and scale_b should be (", ceil_div(b.size(0), 128), ", ", b.size(1), "), and both should be outer-dim-major.\n" - "- For BlockWise 128x128 scaling, a and b should be float8, scales should be float, scale_a should be (", ceil_div(a.size(0), 128), ", ", ceil_div(a.size(1), 128), ") and scale_b should be (", ceil_div(b.size(0), 128), ", ", ceil_div(b.size(1), 128), "), and both should be near-inner-dim-major (with 16-byte aligned strides).\n" - "- For Blockwise 1x32 scaling, a and b should be float8, scales should be float8_e8m0fnu, scale_a should have ", round_up(a.size(0), 128) * round_up(ceil_div(a.size(1), 32), 4), " elements and scale_b should have ", round_up(b.size(1), 128) * round_up(ceil_div(b.size(0), 32), 4), " elements, and both should be contiguous.\n" - "- For Blockwise 1x16 scaling, a and b should be float4 (packed 2x), scales should be float8_e4m3fn, scale_a should have ", round_up(a.size(0), 128) * round_up(ceil_div(a.size(1) * 2, 16), 4), " elements and scale_b should have ", round_up(b.size(1), 128) * round_up(ceil_div(b.size(0) * 2, 16), 4), " elements, and both should be contiguous.\n" - "Got a.dtype()=", a.scalar_type(), ", scale_a.dtype()=", scale_a.scalar_type(), ", scale_a.size()=", scale_a.sizes(), ", scale_a.stride()=", scale_a.strides(), ", ", - "b.dtype()=", b.scalar_type(), ", scale_b.dtype()=", scale_b.scalar_type(), ", scale_b.size()=", scale_b.sizes(), " and scale_b.stride()=", scale_b.strides() - ); + false, + "Invalid scaling configuration. For TensorWise scaling, both scales should be scalar. " + "For RowWise scaling, scale_a should be (", + dim_m, + ", 1) and scale_b should be (1, ", + dim_n, + "). " + "Got scale_a.size()=(", + scale_a.size(0), + ", ", + scale_a.size(1), + ") and ", + "scale_b.size()=(", + scale_b.size(0), + ", ", + scale_b.size(1), + ")"); + + return ScalingType::Error; } } // namespace @@ -1208,8 +1219,8 @@ std::pair get_joint_scaling( // - `mat2`: the second operand of the matrix multiply, can be type `torch.float8_e4m3fn` or `torch.float8_e5m2` // - `bias`: the bias, can be type `torch.float16` or `torch.bfloat16` // - `out_dtype`: the output dtype, can either be a float8 or a higher precision floating point type -// - `scale_a`: a tensor with the inverse scale of `mat1`, whose shape/strides/dtype depend on the scaling scheme -// - `scale_b`: a tensor with the inverse scale of `mat2`, whose shape/strides/dtype depend on the scaling scheme +// - `scale_a`: a scalar or 1-dimensional tensor with the inverse scale of `mat1`, only needed if `mat1` is a float8 type +// - `scale_b`: a scalar or 1-dimensional tensor with the inverse scale of `mat2`, only needed if `mat2` is a float8 type // - `scale_result`: a scalar tensor with the scale of the output, only utilized if the output is a float8 type // - `use_fast_accum`: if true, enables fast float8 accumulation // - `out`: a reference to the output tensor @@ -1232,21 +1243,9 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (", mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")"); - // Check what type of scaling we are doing based on inputs. This list is sorted - // by decreasing priority. We prefer "simpler" schemes as they are supported - // more broadly (more GPU archs, more CUDA versions) and because they are more - // efficient. This tends to matter only for small matmuls (e.g., 1x1x128). - auto [scaling_choice_a, scaling_choice_b] = get_joint_scaling( - { - std::make_pair(ScalingType::TensorWise, ScalingType::TensorWise), - std::make_pair(ScalingType::RowWise, ScalingType::RowWise), - std::make_pair(ScalingType::BlockWise128x128, ScalingType::BlockWise1x128), - std::make_pair(ScalingType::BlockWise1x128, ScalingType::BlockWise128x128), - std::make_pair(ScalingType::BlockWise1x128, ScalingType::BlockWise1x128), - std::make_pair(ScalingType::BlockWise1x32, ScalingType::BlockWise1x32), - std::make_pair(ScalingType::BlockWise1x16, ScalingType::BlockWise1x16) - }, - mat1, mat2, scale_a, scale_b); + // Check what type of scaling we are doing based on inputs + ScalingType scaling_choice = get_scaling_type(scale_a, scale_b, mat1.size(0), mat1.size(1), mat2.size(1)); + TORCH_INTERNAL_ASSERT(scaling_choice != ScalingType::Error, "Scaling type not supported"); TORCH_CHECK(!scale_result || (scale_result->numel() == 1 && scale_result->scalar_type() == kFloat), "scale_result must be a float scalar"); @@ -1317,7 +1316,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, #ifndef USE_ROCM // We are doing row-wise scaling auto dprops = at::cuda::getCurrentDeviceProperties(); - if (scaling_choice_a == ScalingType::RowWise && scaling_choice_b == ScalingType::RowWise + if (scaling_choice == ScalingType::RowWise && (dprops->major < 9 || CUBLAS_VERSION < 120900 || cublasLtGetVersion() < 120900)) { TORCH_CHECK(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling."); at::cuda::detail::f8f8bf16_rowwise( @@ -1331,7 +1330,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, return out; } #else - if (scaling_choice_a == ScalingType::RowWise && scaling_choice_b == ScalingType::RowWise) { + if (scaling_choice == ScalingType::RowWise) { // For ROCm, match behavior of f8f8bf16_rowwise type checking, for unit test purposes. Tensor b = mat2; if (_scaled_mm_is_fnuz()) { @@ -1346,7 +1345,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, } #endif - cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, scale_result, scaling_choice_a, scaling_choice_b); + cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, scale_result); const auto out_dtype_ = args.result->scalar_type(); TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt"); @@ -1423,14 +1422,10 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, params.a_scale_ptr = args.scale_mata_ptr; params.lda = args.lda; params.a_dtype = args.mata->scalar_type(); - params.a_scale_dtype = args.scale_mata_dtype.value(); - params.a_scaling_type = args.scaling_mata_type.value(); params.b = args.matb->data_ptr(); params.b_scale_ptr = args.scale_matb_ptr; params.ldb = args.ldb; params.b_dtype = args.matb->scalar_type(); - params.b_scale_dtype = args.scale_matb_dtype.value(); - params.b_scaling_type = args.scaling_matb_type.value(); params.bias_ptr = bias ? bias->data_ptr(): nullptr; params.bias_dtype = bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_; params.c = args.result->data_ptr(); @@ -1438,6 +1433,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, params.ldc = args.result_ld; params.c_dtype = out_dtype_; params.use_fast_accum = use_fast_accum; + params.use_rowwise = scaling_choice == ScalingType::RowWise; if (transa_ && transb_) { TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::T) } @@ -1471,20 +1467,19 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, args.lda, args.mata->scalar_type(), args.scale_mata_dtype.value(), - args.scaling_mata_type.value(), args.matb->data_ptr(), args.scale_matb_ptr, args.ldb, args.matb->scalar_type(), args.scale_matb_dtype.value(), - args.scaling_matb_type.value(), bias ? bias->data_ptr(): nullptr, bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_, args.result->data_ptr(), args.scale_result_ptr, args.result_ld, out_dtype_, - use_fast_accum); + use_fast_accum, + scaling_choice == ScalingType::RowWise); } return out; diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index 30526c2a84826..31f36681bc3a4 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -785,7 +785,7 @@ def amax_to_scale( if float8_dtype == e4m3_type: res = E4M3_MAX_POS / torch.clamp(amax, min=EPS) elif float8_dtype == e5m2_type: - res = E5M2_MAX_POS / torch.clamp(amax, min=EPS) + res = E4M3_MAX_POS / torch.clamp(amax, min=EPS) else: raise ValueError(f"Unsupported float8_dtype: {float8_dtype}") @@ -806,20 +806,6 @@ def tensor_to_scale(x: torch.Tensor, float8_dtype: torch.dtype, dim=None): return amax_to_scale(amax, float8_dtype, x.dtype) -def tensor_to_scale_block( - x: torch.Tensor, - float8_dtype: torch.dtype, - block_outer: int, - block_inner: int, -) -> tuple[torch.Tensor, torch.Tensor]: - x = x.unflatten(1, (-1, block_inner)).unflatten(0, (-1, block_outer)) - amax = x.abs().amax(dim=[1, 3], keepdim=True).float() - scale = torch.finfo(float8_dtype).max / amax - x = x.mul(scale).to(float8_dtype) - x = x.flatten(2, 3).flatten(0, 1) - scale = scale.flatten(2, 3).flatten(0, 1) - return x, scale - def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor: # naive implementation: dq -> op -> q x_fp32 = x.to(torch.float) / x_scale @@ -828,17 +814,6 @@ def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor: return out_fp32.to(out_dtype) -def mm_float8_emulated_block(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor: - x = x.unflatten(1, (x_scale.shape[1], -1)).unflatten(0, (x_scale.shape[0], -1)) - y = y.unflatten(1, (y_scale.shape[1], -1)).unflatten(0, (y_scale.shape[0], -1)) - x_fp32 = x.to(torch.float) / x_scale[:, None, :, None] - y_fp32 = y.to(torch.float) / y_scale[:, None, :, None] - x_fp32 = x_fp32.flatten(2, 3).flatten(0, 1) - y_fp32 = y_fp32.flatten(2, 3).flatten(0, 1) - out_fp32 = torch.mm(x_fp32, y_fp32) - - return out_fp32.to(out_dtype) - def addmm_float8_unwrapped( a_data: torch.Tensor, a_scale: torch.Tensor, @@ -1262,7 +1237,11 @@ def test_float8_error_messages(self, device) -> None: y_fp8 = y.to(e4m3_type).t() with self.assertRaisesRegex( - RuntimeError, re.escape("Invalid scaling configuration") + RuntimeError, + re.escape( + "For RowWise scaling, scale_a should be (1024, 1) and scale_b " + "should be (1, 2048). Got scale_a.size()=(1, 1) and scale_b.size()=(1, 2)" + ), ): torch._scaled_mm( x_fp8, @@ -1273,7 +1252,11 @@ def test_float8_error_messages(self, device) -> None: ) with self.assertRaisesRegex( - RuntimeError, re.escape("Invalid scaling configuration") + RuntimeError, + re.escape( + " For RowWise scaling, scale_a should be (1024, 1) and scale_b " + "should be (1, 2048). Got scale_a.size()=(1024, 1) and scale_b.size()=(1, 2049)" + ), ): torch._scaled_mm( x_fp8, @@ -1283,18 +1266,22 @@ def test_float8_error_messages(self, device) -> None: out_dtype=torch.bfloat16, ) with self.assertRaisesRegex( - RuntimeError, re.escape("Invalid scaling configuration") + RuntimeError, + re.escape("For non-TensorWise scaling, scale tensors must be 2-dimensional"), ): torch._scaled_mm( x_fp8, y_fp8, scale_a=torch.ones((M), device="cuda"), - scale_b=torch.ones((N, N, 1), device="cuda"), + scale_b=torch.ones((N, N), device="cuda"), out_dtype=torch.bfloat16, ) with self.assertRaisesRegex( - RuntimeError, re.escape("Invalid scaling configuration") + RuntimeError, + re.escape( + "Both scale_a and scale_b must be contiguous for RowWise scaling." + ), ): torch._scaled_mm( x_fp8, @@ -1359,58 +1346,6 @@ def test_scaled_mm_vs_emulated_row_wise(self, base_dtype): torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol) - @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg) - @unittest.skipIf(not SM90OrLater, "cuBLAS blockwise scaling requires sm90+") - @unittest.skipIf( - _get_torch_cuda_version() < (12, 9), - "cuBLAS blockwise scaling added in CUDA 12.9", - ) - @parametrize("output_dtype", [torch.bfloat16, torch.float32]) - @parametrize("lhs_block,rhs_block", [(1, 1), (128, 1), (1, 128)]) - def test_scaled_mm_vs_emulated_block_wise(self, output_dtype, lhs_block, rhs_block): - torch.manual_seed(42) - - x = torch.randn(256, 512, device="cuda", dtype=output_dtype).pow(3) - y = torch.randn(768, 512, device="cuda", dtype=output_dtype).pow(3) - - x_fp8, x_scales = tensor_to_scale_block(x, e4m3_type, lhs_block, 128) - y_fp8, y_scales = tensor_to_scale_block(y, e4m3_type, rhs_block, 128) - - # 1x128 blocks need scales to be outer-dim-major - if lhs_block == 1: - x_scales = x_scales.t().contiguous().t() - if rhs_block == 1: - y_scales = y_scales.t().contiguous().t() - - # Calculate actual F8 mm - out_scaled_mm = mm_float8( - x_fp8, y_fp8.t(), a_scale=x_scales, b_scale=y_scales.t(), output_dtype=output_dtype - ) - - # Calculate emulated F8 mm - out_emulated = mm_float8_emulated_block( - x_fp8, x_scales, y_fp8.t(), y_scales.t(), output_dtype - ) - - cosine_sim = torch.nn.functional.cosine_similarity( - out_scaled_mm.flatten().float(), out_emulated.flatten().float(), dim=0 - ) - self.assertGreaterEqual(float(cosine_sim), 0.999) - - if output_dtype in {torch.bfloat16, torch.float16}: - atol, rtol = 6e-1, 7e-2 - else: - atol, rtol = 7e-1, 2e-3 - - self.assertEqual(out_scaled_mm, out_emulated, atol=atol, rtol=rtol) - - # One last check against the full-precision reference, to ensure we - # didn't mess up the scaling itself and made the test trivial. - cosine_sim = torch.nn.functional.cosine_similarity( - out_scaled_mm.flatten().float(), (x @ y.t()).flatten().float(), dim=0 - ) - self.assertGreaterEqual(float(cosine_sim), 0.999) - @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @parametrize("which_dim_zero", [0, 1, 2]) @parametrize("use_torch_compile", [False, True]) From ead80f3202b23ad16daa3a250754ddb91d64e9f8 Mon Sep 17 00:00:00 2001 From: Aleksei Nikiforov Date: Fri, 18 Jul 2025 09:13:41 +0000 Subject: [PATCH 238/457] =?UTF-8?q?Fix=20s390x=20CI:=20ensure=20that=20all?= =?UTF-8?q?=20python=20dependencies=20are=20installed=20when=20=E2=80=A6?= =?UTF-8?q?=20(#158552)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …building pytorch for tests on s390x Pull Request resolved: https://github.com/pytorch/pytorch/pull/158552 Approved by: https://github.com/huydhn --- .github/workflows/_linux-build.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/workflows/_linux-build.yml b/.github/workflows/_linux-build.yml index f1e2f917f4bc8..bce807018272b 100644 --- a/.github/workflows/_linux-build.yml +++ b/.github/workflows/_linux-build.yml @@ -323,6 +323,11 @@ jobs: "${USED_IMAGE}" \ ${DOCKER_SHELL_CMD} ) + + if [[ ${BUILD_ENVIRONMENT} == *"s390x"* ]]; then + docker exec -t "${container_name}" sh -c "python3 -m pip install -r requirements.txt" + fi + docker exec -t "${container_name}" sh -c '.ci/pytorch/build.sh' END_TIME=$(date +%s) From 7b05bdd925f0f4b49e68662f9761fabaa27f2faf Mon Sep 17 00:00:00 2001 From: Will Constable Date: Thu, 17 Jul 2025 16:33:40 -0700 Subject: [PATCH 239/457] [DTensor] fix copy_ strategy (#158538) The previous strategy directly used 'self' input strategy for 'src' input. The fixed strategy correctly maps the self dim to src dim so that it works even if the src input is broadcast. E.g. for this program, broadcasting will occur on dims 0,1,3 of self. ``` self = torch.ones((2,3,4,5)) src = torch.ones((4,1)) self.copy_(src) ``` These are the correct sharding combinations: | self | src | |-------|------| | Shard(0) | Replicate() | | Shard(1) | Replicate() | | Shard(2) | Shard(0) | | Shard(3) | Shard(1) | Pull Request resolved: https://github.com/pytorch/pytorch/pull/158538 Approved by: https://github.com/zpcore, https://github.com/XilunWu, https://github.com/wanchaol ghstack dependencies: #158495, #158490 --- test/distributed/tensor/test_tensor_ops.py | 44 ++++++++------- torch/distributed/tensor/_ops/_tensor_ops.py | 56 ++++++-------------- 2 files changed, 42 insertions(+), 58 deletions(-) diff --git a/test/distributed/tensor/test_tensor_ops.py b/test/distributed/tensor/test_tensor_ops.py index 9140d2f5aae13..d62da27d43393 100644 --- a/test/distributed/tensor/test_tensor_ops.py +++ b/test/distributed/tensor/test_tensor_ops.py @@ -56,10 +56,11 @@ def test_clone(self): @with_comms def test_copy_(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) - src_specs = [[Replicate()], [Shard(0)]] - src_tensor = torch.randn((12, 12)) + # basic test + src_tensor = torch.randn((12, 12)) dst_tensor = torch.zeros(12, 12) + src_specs = [[Replicate()], [Shard(0)]] dst_specs = [[Replicate()], [Shard(0)]] for dst_spec, src_spec in zip(dst_specs, src_specs): src_dtensor = distribute_tensor(src_tensor, device_mesh, dst_spec) @@ -68,22 +69,29 @@ def test_copy_(self): dst_tensor.copy_(src_tensor) self.assertEqual(dst_dtensor.full_tensor(), dst_tensor) - # @pytest.mark.xfail - # @with_comms - # def test_copy_broadcast(self): - # device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) - # src_specs = [[Replicate()], [Shard(0)]] - # src_tensor = torch.randn((12,)) - - # dst_tensor = torch.zeros(12, 12) - # dst_specs = [[Replicate()], [Shard(1)]] - # for dst_spec, src_spec in zip(dst_specs, src_specs): - # src_dtensor = distribute_tensor(src_tensor, device_mesh, dst_spec) - # dst_dtensor = distribute_tensor(dst_tensor, device_mesh, src_spec) - # # perform a broadcasted copy from Shard(0) to Shard(1) for the worst case - # dst_dtensor.copy_(src_dtensor) - # dst_tensor.copy_(src_tensor) - # self.assertEqual(dst_dtensor.full_tensor(), dst_tensor) + # simple broadcasting + src_tensor = torch.randn((128,)) + dst_tensor = torch.zeros(128, 128) + src_specs = [[Replicate()], [Shard(0)]] + dst_specs = [[Replicate()], [Shard(1)]] + for dst_spec, src_spec in zip(dst_specs, src_specs): + src_dtensor = distribute_tensor(src_tensor, device_mesh, src_spec) + dst_dtensor = distribute_tensor(dst_tensor, device_mesh, dst_spec) + dst_dtensor.copy_(src_dtensor) + dst_tensor.copy_(src_tensor) + self.assertEqual(dst_dtensor.full_tensor(), dst_tensor) + + # The src specs in this case are designed to not be compatible with the dst_specs, redistribute should happen + src_tensor = torch.randn((64, 1)) + dst_tensor = torch.zeros(16, 32, 64, 128) + src_specs = [[Shard(1)], [Shard(1)], [Shard(1)], [Shard(1)]] + dst_specs = [[Replicate()], [Shard(0)], [Shard(1)], [Shard(2)]] + for dst_spec, src_spec in zip(dst_specs, src_specs): + src_dtensor = distribute_tensor(src_tensor, device_mesh, src_spec) + dst_dtensor = distribute_tensor(dst_tensor, device_mesh, dst_spec) + dst_dtensor.copy_(src_dtensor) + dst_tensor.copy_(src_tensor) + self.assertEqual(dst_dtensor.full_tensor(), dst_tensor) @with_comms def test_contiguous(self): diff --git a/torch/distributed/tensor/_ops/_tensor_ops.py b/torch/distributed/tensor/_ops/_tensor_ops.py index fd6621ab75124..e53eef1610162 100644 --- a/torch/distributed/tensor/_ops/_tensor_ops.py +++ b/torch/distributed/tensor/_ops/_tensor_ops.py @@ -35,6 +35,8 @@ Shard, ) +from ._pointwise_ops import pointwise_strategy + aten = torch.ops.aten @@ -91,46 +93,20 @@ def propagate_single_input_strategy(op_schema: OpSchema) -> StrategyType: aten._to_copy.default, schema_info=RuntimeSchemaInfo(static_kwargkey=["dtype"]) )(propagate_single_input_strategy) - -@register_op_strategy(aten.copy_.default) -def copy_strategy(op_schema: OpSchema) -> StrategyType: - # TODO: this strategy is incorrect for copy_ in the case that src tensor - # is smaller rank than self tensor. It is possible to select a strategy from self tensor - # that is invalid for dst tensor. - # It is also problematic to assume that shard(0) on src maps to shard(0) on self, since we - # may broadcast a new dim to the left or right of 0 when copying. - # - # For now, I just keep copy working essentially the way it was before this PR, - # but split it out so it can be handled separately in the future. - num_tensor_args = 2 - first_input_strategy = op_schema.args_schema[0] - assert isinstance(first_input_strategy, OpStrategy) - return OpStrategy( - [ - OpSpec( - output_specs=DTensorSpec( - mesh=first_input_strategy.mesh, - placements=strategy.output_spec.placements, - tensor_meta=strategy.output_spec.tensor_meta, - ), - input_specs=[ - DTensorSpec( - mesh=first_input_strategy.mesh, - placements=strategy.output_spec.placements, - tensor_meta=strategy.output_spec.tensor_meta, - ) - for _ in range(num_tensor_args) - ], - redistribute_cost=[ - generate_redistribute_costs( - first_input_strategy, strategy.output_spec - ) - for _ in range(num_tensor_args) - ], - ) - for strategy in first_input_strategy.strategies - ] - ) +# copy_ is actually a pointwise op with broadcasting, so reuse the pointwise strategy, which takes care of these +# requirements. +# +# Following torch broadcasting semantics (https://docs.pytorch.org/docs/stable/notes/broadcasting.html) +# - self can not change shape as a result of broadcasting since this is an inplace op +# - src can broadcast, but when it does it always does so from the trailing end +# e.g. the last dim of 'src' must match up with the last dim of 'self' +# +# DTensor semantics for inplace ops also dictates that we may NOT redistribute our 'self' input. +# In practice, what this means is +# - our output strategies should map 1:1 to our 'self' input strategies +# - our 'src' input may be redistributed to match up with the 'self' input, with the caveat of adjusting for +# broadcasting dim +register_op_strategy(aten.copy_.default)(pointwise_strategy) @register_op_strategy( From 27af877f8459988496d47b6e22d80d98c1e80581 Mon Sep 17 00:00:00 2001 From: Aidyn-A Date: Fri, 18 Jul 2025 09:59:38 +0000 Subject: [PATCH 240/457] [ATen][CUDA][SDPA] Flash Attention: Refactor sm version checks (#158558) The architecture version checks are unnecessary fine-grained in PyTorch. Considering the fact that PyTorch's Flash Attention works on all `sm_80+` machines, it makes more sense to just check for lower bound. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158558 Approved by: https://github.com/eqy --- .../cuda/flash_attn/flash_api.cpp | 67 +++++-------------- 1 file changed, 18 insertions(+), 49 deletions(-) diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp index 854c33dec7342..68451ba5ffcc8 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp @@ -389,20 +389,14 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head std::optional gen_) { auto dprops = at::cuda::getCurrentDeviceProperties(); - // bool is_sm75 = dprops->major == 7 && dprops->minor == 5; - bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; - bool is_sm90 = dprops->major == 9 && dprops->minor == 0; - bool is_sm10x = dprops->major == 10 && dprops->minor >= 0; - bool is_sm120_or_sm121 = dprops->major == 12 && dprops->minor <= 1; - TORCH_CHECK(is_sm120_or_sm121 || is_sm10x || is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); - // We will support Turing in the near future - // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer."); + bool is_sm80_or_newer = (dprops->major * 10) >= 80; + TORCH_CHECK(is_sm80_or_newer, "FlashAttention only supports Ampere GPUs or newer."); auto q_dtype = q.dtype(); TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16, "FlashAttention only support fp16 and bf16 data type"); if (q_dtype == at::kBFloat16) { - TORCH_CHECK(is_sm120_or_sm121 || is_sm10x || is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer"); + TORCH_CHECK(is_sm80_or_newer, "bfloat16 is only supported on Ampere GPUs or newer"); } TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); @@ -577,20 +571,14 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q std::optional gen_) { auto dprops = at::cuda::getCurrentDeviceProperties(); - // bool is_sm75 = dprops->major == 7 && dprops->minor == 5; - bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; - bool is_sm90 = dprops->major == 9 && dprops->minor == 0; - bool is_sm10x = dprops->major == 10 && dprops->minor >= 0; - bool is_sm120_or_sm121 = dprops->major == 12 && dprops->minor <= 1; - TORCH_CHECK(is_sm120_or_sm121 || is_sm10x || is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); - // We will support Turing in the near future - // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer."); + bool is_sm80_or_newer = (dprops->major * 10) >= 80; + TORCH_CHECK(is_sm80_or_newer, "FlashAttention only supports Ampere GPUs or newer."); auto q_dtype = q.dtype(); TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16, "FlashAttention only support fp16 and bf16 data type"); if (q_dtype == at::kBFloat16) { - TORCH_CHECK(is_sm120_or_sm121 || is_sm10x || is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer"); + TORCH_CHECK(is_sm80_or_newer, "bfloat16 is only supported on Ampere GPUs or newer"); } TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); @@ -838,15 +826,8 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si #endif if (is_causal) { window_size_right = 0; } auto dprops = at::cuda::getCurrentDeviceProperties(); - // bool is_sm75 = dprops->major == 7 && dprops->minor == 5; - bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; - bool is_sm80 = dprops->major == 8 && dprops->minor == 0; - bool is_sm90 = dprops->major == 9 && dprops->minor == 0; - bool is_sm10x = dprops->major == 10 && dprops->minor >= 0; - bool is_sm120_or_sm121 = dprops->major == 12 && dprops->minor <= 1; - TORCH_CHECK(is_sm120_or_sm121 || is_sm10x || is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); - // We will support Turing in the near future - // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer."); + bool is_sm80_or_newer = (dprops->major * 10) >= 80; + TORCH_CHECK(is_sm80_or_newer, "FlashAttention only supports Ampere GPUs or newer."); bool is_dropout = p_dropout > 0.0; auto stream = at::cuda::getCurrentCUDAStream().stream(); @@ -855,7 +836,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16, "FlashAttention only support fp16 and bf16 data type"); if (q_dtype == at::kBFloat16) { - TORCH_CHECK(is_sm120_or_sm121 || is_sm10x || is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer"); + TORCH_CHECK(is_sm80_or_newer, "bfloat16 is only supported on Ampere GPUs or newer"); } TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); @@ -885,7 +866,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si TORCH_CHECK(head_size_og % 8 == 0, "head_size_og should be a multiple of 8, this is ensured by padding!"); TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256"); if (head_size > 192 && (head_size <= 224 || is_dropout)) { - TORCH_CHECK(is_sm80 || is_sm90 || is_sm10x || is_sm120_or_sm121, "FlashAttention backward for head dim 256 with dropout, or head dim 224 with/without dropout requires A100/A800 or H100/H800"); + TORCH_CHECK(is_sm80_or_newer, "FlashAttention backward for head dim 256 with dropout, or head dim 224 with/without dropout requires A100/A800 or H100/H800"); } TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); @@ -1055,15 +1036,9 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size if (is_causal) { window_size_right = 0; } auto dprops = at::cuda::getCurrentDeviceProperties(); - // bool is_sm75 = dprops->major == 7 && dprops->minor == 5; - bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; - bool is_sm80 = dprops->major == 8 && dprops->minor == 0; - bool is_sm90 = dprops->major == 9 && dprops->minor == 0; - bool is_sm10x = dprops->major == 10 && dprops->minor >= 0; - bool is_sm120_or_sm121 = dprops->major == 12 && dprops->minor <= 1; - TORCH_CHECK(is_sm120_or_sm121 || is_sm10x || is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); - // We will support Turing in the near future - // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer."); + bool is_sm80_or_newer = (dprops->major * 10) >= 80; + TORCH_CHECK(is_sm80_or_newer, "FlashAttention only supports Ampere GPUs or newer."); + bool is_dropout = p_dropout > 0.0; auto stream = at::cuda::getCurrentCUDAStream().stream(); @@ -1071,7 +1046,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16, "FlashAttention only support fp16 and bf16 data type"); if (q_dtype == at::kBFloat16) { - TORCH_CHECK(is_sm120_or_sm121 || is_sm10x || is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer"); + TORCH_CHECK(is_sm80_or_newer, "bfloat16 is only supported on Ampere GPUs or newer"); } TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); @@ -1106,7 +1081,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size TORCH_CHECK(head_size_og % 8 == 0, "head_size_og should be a multiple of 8, this is ensured by padding!"); TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256"); if (head_size > 192 && (head_size <= 224 || is_dropout)) { - TORCH_CHECK(is_sm80 || is_sm90 || is_sm10x || is_sm120_or_sm121, "FlashAttention backward for head dim 256 with dropout, or head dim 224 with/without dropout requires A100/A800 or H100/H800"); + TORCH_CHECK(is_sm80_or_newer, "FlashAttention backward for head dim 256 with dropout, or head dim 224 with/without dropout requires A100/A800 or H100/H800"); } TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); @@ -1280,20 +1255,14 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he ) { auto dprops = at::cuda::getCurrentDeviceProperties(); - // bool is_sm75 = dprops->major == 7 && dprops->minor == 5; - bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; - bool is_sm90 = dprops->major == 9 && dprops->minor == 0; - bool is_sm10x = dprops->major == 10 && dprops->minor >= 0; - bool is_sm120_or_sm121 = dprops->major == 12 && dprops->minor <= 1; - TORCH_CHECK(is_sm120_or_sm121 || is_sm10x || is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); - // We will support Turing in the near future - // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer."); + bool is_sm80_or_newer = (dprops->major * 10) >= 80; + TORCH_CHECK(is_sm80_or_newer, "FlashAttention only supports Ampere GPUs or newer."); auto q_dtype = q.dtype(); TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16, "FlashAttention only support fp16 and bf16 data type"); if (q_dtype == at::kBFloat16) { - TORCH_CHECK(is_sm120_or_sm121 || is_sm10x || is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer"); + TORCH_CHECK(is_sm80_or_newer, "bfloat16 is only supported on Ampere GPUs or newer"); } TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype"); TORCH_CHECK(vcache.dtype() == q_dtype, "query and value must have the same dtype"); From a4ec381302f8acd279033707b182bed30ffd2091 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Fri, 18 Jul 2025 13:23:09 +0800 Subject: [PATCH 241/457] [build] pin `setuptools>=77` to enable PEP 639 (#158104) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158104 Approved by: https://github.com/rgommers, https://github.com/Skylion007, https://github.com/atalman --- .ci/docker/manywheel/Dockerfile_2_28 | 2 +- .ci/docker/manywheel/Dockerfile_s390x | 5 ++--- .ci/docker/requirements-ci.txt | 7 ++++--- .ci/pytorch/build.sh | 3 +++ .ci/pytorch/win-test-helpers/build_pytorch.bat | 5 +++++ .ci/pytorch/win-test.sh | 2 +- .ci/pytorch/windows/internal/install_python.bat | 2 +- .ci/pytorch/windows/setup_build.bat | 5 ++++- .ci/wheel/build_wheel.sh | 14 +++++++------- .github/requirements/pip-requirements-macOS.txt | 6 +++--- .github/scripts/lintrunner.sh | 2 +- .github/scripts/windows/build_triton.bat | 2 +- pyproject.toml | 11 +++-------- requirements-build.txt | 4 ++-- test/dynamo/test_exc.py | 16 ++++++++-------- 15 files changed, 46 insertions(+), 40 deletions(-) diff --git a/.ci/docker/manywheel/Dockerfile_2_28 b/.ci/docker/manywheel/Dockerfile_2_28 index b150423e99544..7f279a1c1a735 100644 --- a/.ci/docker/manywheel/Dockerfile_2_28 +++ b/.ci/docker/manywheel/Dockerfile_2_28 @@ -128,7 +128,7 @@ ENV PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH ENV LD_LIBRARY_PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/lib64:/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/lib:$LD_LIBRARY_PATH # Install setuptools and wheel for python 3.12/3.13 RUN for cpython_version in "cp312-cp312" "cp313-cp313" "cp313-cp313t"; do \ - /opt/python/${cpython_version}/bin/python -m pip install setuptools wheel; \ + /opt/python/${cpython_version}/bin/python -m pip install "setuptools>=77.0.0" "packaging>=24.2" wheel; \ done; diff --git a/.ci/docker/manywheel/Dockerfile_s390x b/.ci/docker/manywheel/Dockerfile_s390x index 46ec7f77ae8ba..335488b88f122 100644 --- a/.ci/docker/manywheel/Dockerfile_s390x +++ b/.ci/docker/manywheel/Dockerfile_s390x @@ -124,10 +124,9 @@ RUN python3 -mpip install cmake==3.28.0 # install newest flatbuffers version first: # for some reason old version is getting pulled in otherwise. # packaging package is required for onnxruntime wheel build. -RUN pip3 install flatbuffers && \ - pip3 install cython 'pkgconfig>=1.5.5' 'setuptools>=77' 'numpy<2.3.0' && \ +RUN pip3 install 'setuptools>=77.0' 'packaging>=24.2' && \ + pip3 install flatbuffers cython 'pkgconfig>=1.5.5' 'numpy<2.3.0' && \ pip3 install --no-build-isolation h5py==3.11.0 && \ - pip3 install packaging && \ git clone https://github.com/microsoft/onnxruntime && \ cd onnxruntime && git checkout v1.21.0 && \ git submodule update --init --recursive && \ diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt index 9c8251989477d..650c4e58c8ba6 100644 --- a/.ci/docker/requirements-ci.txt +++ b/.ci/docker/requirements-ci.txt @@ -50,7 +50,7 @@ flatbuffers==24.12.23 hypothesis==5.35.1 # Pin hypothesis to avoid flakiness: https://github.com/pytorch/pytorch/issues/31136 #Description: advanced library for generating parametrized tests -#Pinned versions: 3.44.6, 4.53.2 +#Pinned versions: 5.35.1 #test that import: test_xnnpack_integration.py, test_pruning_op.py, test_nn.py junitparser==2.1.1 @@ -307,7 +307,7 @@ pytest-cpp==2.3.0 #Pinned versions: 2.3.0 #test that import: -z3-solver==4.12.6.0 +z3-solver==4.15.1.0 #Description: The Z3 Theorem Prover Project #Pinned versions: #test that import: @@ -363,9 +363,10 @@ pwlf==2.2.1 # To build PyTorch itself +packaging>=24.2 pyyaml pyzstd -setuptools>=70.1.0 +setuptools>=77.0.0 six scons==4.5.2 ; platform_machine == "aarch64" diff --git a/.ci/pytorch/build.sh b/.ci/pytorch/build.sh index 994bd179e4649..07bf2037f430d 100755 --- a/.ci/pytorch/build.sh +++ b/.ci/pytorch/build.sh @@ -269,6 +269,9 @@ if [[ "$BUILD_ENVIRONMENT" == *-bazel-* ]]; then tools/bazel build --config=no-tty "${BAZEL_MEM_LIMIT}" "${BAZEL_CPU_LIMIT}" //... fi else + # install build-system requirements before running setup.py commands + python -m pip install -r requirements-build.txt + # check that setup.py would fail with bad arguments echo "The next three invocations are expected to fail with invalid command error messages." ( ! get_exit_code python setup.py bad_argument ) diff --git a/.ci/pytorch/win-test-helpers/build_pytorch.bat b/.ci/pytorch/win-test-helpers/build_pytorch.bat index 7ceb425ce2d1a..74c9183f2abb0 100644 --- a/.ci/pytorch/win-test-helpers/build_pytorch.bat +++ b/.ci/pytorch/win-test-helpers/build_pytorch.bat @@ -126,6 +126,11 @@ if "%USE_CUDA%"=="1" ( set CMAKE_CUDA_COMPILER_LAUNCHER=%TMP_DIR%/bin/randomtemp.exe;%TMP_DIR%\bin\sccache.exe ) +:: Install build-system requirements before running setup.py commands +python -m pip install -r requirements-build.txt +if errorlevel 1 goto fail +if not errorlevel 0 goto fail + :: Print all existing environment variable for debugging set diff --git a/.ci/pytorch/win-test.sh b/.ci/pytorch/win-test.sh index b61dd06ef562c..be7f3e4bb35cc 100755 --- a/.ci/pytorch/win-test.sh +++ b/.ci/pytorch/win-test.sh @@ -41,7 +41,7 @@ fi python -m pip install pytest-rerunfailures==10.3 pytest-cpp==2.3.0 tensorboard==2.13.0 protobuf==5.29.4 pytest-subtests==0.13.1 # Install Z3 optional dependency for Windows builds. -python -m pip install z3-solver==4.12.2.0 +python -m pip install z3-solver==4.15.1.0 # Install tlparse for test\dynamo\test_structured_trace.py UTs. python -m pip install tlparse==0.3.30 diff --git a/.ci/pytorch/windows/internal/install_python.bat b/.ci/pytorch/windows/internal/install_python.bat index 73622bd736edd..65405a875b6b8 100644 --- a/.ci/pytorch/windows/internal/install_python.bat +++ b/.ci/pytorch/windows/internal/install_python.bat @@ -18,5 +18,5 @@ start /wait "" python-amd64.exe /quiet InstallAllUsers=1 PrependPath=0 Include_t if errorlevel 1 exit /b 1 set "PATH=%CD%\Python\Scripts;%CD%\Python;%PATH%" -%PYTHON_EXEC% -m pip install --upgrade pip setuptools packaging wheel +%PYTHON_EXEC% -m pip install --upgrade pip "setuptools>=77.0.0" "packaging>=24.2" wheel if errorlevel 1 exit /b 1 diff --git a/.ci/pytorch/windows/setup_build.bat b/.ci/pytorch/windows/setup_build.bat index 9b492eef664d7..df925b4ba90bc 100644 --- a/.ci/pytorch/windows/setup_build.bat +++ b/.ci/pytorch/windows/setup_build.bat @@ -7,6 +7,9 @@ call "internal\install_python.bat" %PYTHON_EXEC% --version set "PATH=%CD%\Python\Lib\site-packages\cmake\data\bin;%CD%\Python\Scripts;%CD%\Python;%PATH%" + +%PYTHON_EXEC% -m pip install "setuptools>=77.0.0" "packaging>=24.2" + if "%DESIRED_PYTHON%" == "3.13t" %PYTHON_EXEC% -m pip install numpy==2.2.1 cmake if "%DESIRED_PYTHON%" == "3.13" %PYTHON_EXEC% -m pip install numpy==2.1.2 cmake if "%DESIRED_PYTHON%" == "3.12" %PYTHON_EXEC% -m pip install numpy==2.0.2 cmake @@ -16,7 +19,7 @@ if "%DESIRED_PYTHON%" == "3.9" %PYTHON_EXEC% -m pip install numpy==2.0.2 cmake %PYTHON_EXEC% -m pip install pyyaml %PYTHON_EXEC% -m pip install mkl-include mkl-static -%PYTHON_EXEC% -m pip install boto3 ninja typing_extensions setuptools==72.1.0 +%PYTHON_EXEC% -m pip install boto3 ninja typing-extensions where cmake.exe diff --git a/.ci/wheel/build_wheel.sh b/.ci/wheel/build_wheel.sh index 878d6595c84c0..dc44f8ccc2922 100755 --- a/.ci/wheel/build_wheel.sh +++ b/.ci/wheel/build_wheel.sh @@ -127,7 +127,7 @@ export INSTALL_TEST=0 # dont install test binaries into site-packages export MACOSX_DEPLOYMENT_TARGET=10.15 export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"} -SETUPTOOLS_PINNED_VERSION="==70.1.0" +SETUPTOOLS_PINNED_VERSION="==77.0.0" PYYAML_PINNED_VERSION="=5.3" EXTRA_CONDA_INSTALL_FLAGS="" CONDA_ENV_CREATE_FLAGS="" @@ -135,7 +135,7 @@ RENAME_WHEEL=true case $desired_python in 3.13t) echo "Using 3.13 deps" - SETUPTOOLS_PINNED_VERSION=">=70.1.0" + SETUPTOOLS_PINNED_VERSION=">=77.0.0" PYYAML_PINNED_VERSION=">=6.0.1" NUMPY_PINNED_VERSION="=2.1.0" CONDA_ENV_CREATE_FLAGS="python-freethreading" @@ -145,31 +145,31 @@ case $desired_python in ;; 3.13) echo "Using 3.13 deps" - SETUPTOOLS_PINNED_VERSION=">=70.1.0" + SETUPTOOLS_PINNED_VERSION=">=77.0.0" PYYAML_PINNED_VERSION=">=6.0.1" NUMPY_PINNED_VERSION="=2.1.0" ;; 3.12) echo "Using 3.12 deps" - SETUPTOOLS_PINNED_VERSION=">=70.1.0" + SETUPTOOLS_PINNED_VERSION=">=77.0.0" PYYAML_PINNED_VERSION=">=6.0.1" NUMPY_PINNED_VERSION="=2.0.2" ;; 3.11) echo "Using 3.11 deps" - SETUPTOOLS_PINNED_VERSION=">=70.1.0" + SETUPTOOLS_PINNED_VERSION=">=77.0.0" PYYAML_PINNED_VERSION=">=5.3" NUMPY_PINNED_VERSION="=2.0.2" ;; 3.10) echo "Using 3.10 deps" - SETUPTOOLS_PINNED_VERSION=">=70.1.0" + SETUPTOOLS_PINNED_VERSION=">=77.0.0" PYYAML_PINNED_VERSION=">=5.3" NUMPY_PINNED_VERSION="=2.0.2" ;; 3.9) echo "Using 3.9 deps" - SETUPTOOLS_PINNED_VERSION=">=70.1.0" + SETUPTOOLS_PINNED_VERSION=">=77.0.0" PYYAML_PINNED_VERSION=">=5.3" NUMPY_PINNED_VERSION="=2.0.2" ;; diff --git a/.github/requirements/pip-requirements-macOS.txt b/.github/requirements/pip-requirements-macOS.txt index e8464f0a55ff5..7eaa962995b79 100644 --- a/.github/requirements/pip-requirements-macOS.txt +++ b/.github/requirements/pip-requirements-macOS.txt @@ -12,7 +12,7 @@ numba==0.59.0 numpy==1.26.4 opt-einsum>=3.3 optree==0.13.0 -packaging==23.1 +packaging==25.0 parameterized==0.8.1 pillow==10.3.0 protobuf==5.29.4 @@ -26,11 +26,11 @@ pytest-xdist==3.3.1 pytest==7.3.2 pyyaml==6.0.2 scipy==1.12.0 -setuptools==72.1.0 +setuptools==80.9.0 sympy==1.13.3 tlparse==0.3.30 tensorboard==2.13.0 typing-extensions==4.12.2 unittest-xml-reporting<=3.2.0,>=2.0.0 xdoctest==1.1.0 -z3-solver==4.12.2.0 +z3-solver==4.15.1.0 diff --git a/.github/scripts/lintrunner.sh b/.github/scripts/lintrunner.sh index ef4741444f942..1411ff0397b53 100755 --- a/.github/scripts/lintrunner.sh +++ b/.github/scripts/lintrunner.sh @@ -2,7 +2,7 @@ set -ex # Use uv to speed up lintrunner init -python3 -m pip install uv==0.1.45 setuptools +python3 -m pip install -U uv setuptools CACHE_DIRECTORY="/tmp/.lintbin" # Try to recover the cached binaries diff --git a/.github/scripts/windows/build_triton.bat b/.github/scripts/windows/build_triton.bat index 97cd535a49889..da2e86b40432a 100644 --- a/.github/scripts/windows/build_triton.bat +++ b/.github/scripts/windows/build_triton.bat @@ -10,7 +10,7 @@ if "%PY_VERS%" == "3.13t" ( call conda create -n %PYTHON_PREFIX% -y -c=conda-forge python=%PY_VERS% ) :: Fix cmake version for issue https://github.com/pytorch/pytorch/issues/150480 -call conda run -n %PYTHON_PREFIX% pip install wheel pybind11 certifi cython cmake==3.31.6 setuptools==72.1.0 ninja +call conda run -n %PYTHON_PREFIX% pip install wheel pybind11 certifi cython cmake==3.31.6 setuptools==78.1.1 ninja dir "%VC_INSTALL_PATH%" diff --git a/pyproject.toml b/pyproject.toml index b41ae87621f0f..133da9289f5c9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,13 +2,12 @@ [build-system] requires = [ - # 70.1.0: min version for integrated bdist_wheel command from wheel package # 77.0.0: min version for SPDX expression support for project.license - "setuptools>=70.1.0,<80.0", + "setuptools>=77.0.0,<80.0", "cmake>=3.27", "ninja", "numpy", - "packaging", + "packaging>=24.2", "pyyaml", "requests", "six", # dependency chain: NNPACK -> PeachPy -> six @@ -21,11 +20,7 @@ name = "torch" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" readme = "README.md" requires-python = ">=3.9,<3.14" -# TODO: change to `license = "BSD-3-Clause"` and enable PEP 639 after pinning setuptools>=77 -# FIXME: As of 2025.06.20, it is hard to ensure the minimum version of setuptools in our CI environment. -# TOML-table-based license deprecated in setuptools>=77, and the deprecation warning will be changed -# to an error on 2026.02.18. See also: https://github.com/pypa/setuptools/issues/4903 -license = { text = "BSD-3-Clause" } +license = "BSD-3-Clause" authors = [{ name = "PyTorch Team", email = "packages@pytorch.org" }] keywords = ["pytorch", "machine learning"] classifiers = [ diff --git a/requirements-build.txt b/requirements-build.txt index be19d987f73db..12332b0e1af01 100644 --- a/requirements-build.txt +++ b/requirements-build.txt @@ -1,9 +1,9 @@ # Build System requirements -setuptools>=70.1.0,<80.0 # setuptools develop deprecated on 80.0 +setuptools>=77.0.0,<80.0 # setuptools develop deprecated on 80.0 cmake>=3.27 ninja numpy -packaging +packaging>=24.2 pyyaml requests six # dependency chain: NNPACK -> PeachPy -> six diff --git a/test/dynamo/test_exc.py b/test/dynamo/test_exc.py index acc3fd55f6fb0..c340a2882d471 100644 --- a/test/dynamo/test_exc.py +++ b/test/dynamo/test_exc.py @@ -251,13 +251,13 @@ def fn(x, shape): Model: ==> L['shape'][0]: 0 - ==> L['shape'][1]: 1 - ==> L['shape'][2]: 1 + ==> L['shape'][1]: 0 + ==> L['shape'][2]: 0 ==> L['x'].size()[0]: 3 ==> L['x'].storage_offset(): 0 ==> L['x'].stride()[0]: 1 - ==> s3: 1 - ==> s52: 1 + ==> s3: 0 + ==> s52: 0 ==> s77: 3 ==> s86: 0 @@ -315,16 +315,16 @@ def fn(x, shape): %split : [num_users=3] = call_method[target=split](args = (%l_x_, (%l_shape_0_, %l_shape_1_, %l_shape_2_)), kwargs = {}) Model: - ==> L['shape'][0]: 1 - ==> L['shape'][1]: 1 + ==> L['shape'][0]: 0 + ==> L['shape'][1]: 0 ==> L['shape'][2]: 0 ==> L['x'].size()[0]: 3 ==> L['x'].storage_offset(): 0 ==> L['x'].stride()[0]: 1 ==> s3: 0 - ==> s52: 1 + ==> s52: 0 ==> s77: 3 - ==> s86: 1 + ==> s86: 0 Assertions: ==> (== 0 L['x'].storage_offset()) From 0eae6b68f424c0fade1e3db0ba179ae8c9f5ad25 Mon Sep 17 00:00:00 2001 From: bobrenjc93 Date: Thu, 17 Jul 2025 08:52:16 -0700 Subject: [PATCH 242/457] Unify torch.tensor and torch.ops.aten.scalar_tensor behavior (#158537) Fixes #158376 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158537 Approved by: https://github.com/atalman --- aten/src/ATen/ScalarOps.cpp | 23 ++++++++++++++++++++++- test/dynamo/test_misc.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/ScalarOps.cpp b/aten/src/ATen/ScalarOps.cpp index 693fb46e639f2..da4f7a35a2f47 100644 --- a/aten/src/ATen/ScalarOps.cpp +++ b/aten/src/ATen/ScalarOps.cpp @@ -8,7 +8,28 @@ namespace at { namespace { template inline void fill_inplace(Tensor& self, const Scalar& value_scalar) { - auto value = value_scalar.to(); + scalar_t value{}; + + if constexpr (std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v) { + // relaxed float cast: allow inf similar to the torch.tensor constructor + // + // without this, we had the following divergence: + // torch.tensor(1123581321.0, dtype=torch.float16) + // => tensor(inf, dtype=torch.float16) + // torch.ops.aten.scalar_tensor.default(1123581321, dtype=torch.float16) + // => RuntimeError: value cannot be converted to type at::Half without overflow + + value = static_cast(value_scalar.to()); + } else { + value = value_scalar.to(); + } + scalar_t* dptr = static_cast(self.data_ptr()); *dptr = value; } diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 632ebdc39278a..b8d759c66e302 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -12962,6 +12962,38 @@ def f(actions, n_act, epsilon=0.1): y = torch.tensor(5) f(x, y) + def test_dynamic_float_scalar_tensor_coersion(self): + # Minified version of https://github.com/pytorch/pytorch/issues/158376#issuecomment-3079591367 + class Foo: + def __init__(self): + self.config = type( + "Config", (), {"pad_val": 1123581321.0, "tolerance": 1e-6} + ) + + @torch.compile(fullgraph=True) + def forward(self, input): + outputs = torch.where( + torch.abs(input - self.config.pad_val) < self.config.tolerance, + torch.tensor( + self.config.pad_val, dtype=input.dtype, device=input.device + ), + torch.tensor( + self.config.pad_val + 1, dtype=input.dtype, device=input.device + ), + ) + return outputs + + foo = Foo() + inputs = torch.randn(3, 4) + result = foo.forward(inputs) + + original_pad_val = foo.config.pad_val + foo.config.pad_val += 1.0 + result2 = foo.forward(inputs) + + # Previously would crash with: + # RuntimeError: value cannot be converted to type at::Half without overflow + devices = ("cuda", "hpu", "xpu") instantiate_device_type_tests( From e882c761dd2bd2f32d1ac5b6f846c9951564e9e7 Mon Sep 17 00:00:00 2001 From: Jane Xu Date: Fri, 18 Jul 2025 07:32:02 -0700 Subject: [PATCH 243/457] Add STD_TORCH_CHECK to headeronly (#158377) Differential Revision: [D78366519](https://our.internmc.facebook.com/intern/diff/D78366519/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158377 Approved by: https://github.com/albanD --- c10/util/Exception.h | 21 +---- test/cpp/aoti_abi_check/test_exception.cpp | 19 +++++ .../libtorch_agnostic/csrc/kernel.cpp | 3 + torch/header_only_apis.txt | 3 + torch/headeronly/CMakeLists.txt | 1 + torch/headeronly/macros/build.bzl | 2 +- torch/headeronly/ovrsource_defs.bzl | 1 + torch/headeronly/util/Exception.h | 83 +++++++++++++++++++ 8 files changed, 112 insertions(+), 21 deletions(-) create mode 100644 test/cpp/aoti_abi_check/test_exception.cpp create mode 100644 torch/headeronly/util/Exception.h diff --git a/c10/util/Exception.h b/c10/util/Exception.h index 3ff3396f5f1b8..8136896d07f88 100644 --- a/c10/util/Exception.h +++ b/c10/util/Exception.h @@ -365,26 +365,7 @@ C10_API std::string GetExceptionString(const std::exception& e); // https://stackoverflow.com/questions/5134523/msvc-doesnt-expand-va-args-correctly #define C10_EXPAND_MSVC_WORKAROUND(x) x -// On nvcc, C10_UNLIKELY thwarts missing return statement analysis. In cases -// where the unlikely expression may be a constant, use this macro to ensure -// return statement analysis keeps working (at the cost of not getting the -// likely/unlikely annotation on nvcc). -// https://github.com/pytorch/pytorch/issues/21418 -// -// Currently, this is only used in the error reporting macros below. If you -// want to use it more generally, move me to Macros.h -// -// TODO: Brian Vaughan observed that we might be able to get this to work on -// nvcc by writing some sort of C++ overload that distinguishes constexpr inputs -// from non-constexpr. Since there isn't any evidence that losing C10_UNLIKELY -// in nvcc is causing us perf problems, this is not yet implemented, but this -// might be an interesting piece of C++ code for an intrepid bootcamper to -// write. -#if defined(__CUDACC__) -#define C10_UNLIKELY_OR_CONST(e) e -#else -#define C10_UNLIKELY_OR_CONST(e) C10_UNLIKELY(e) -#endif +#include // ---------------------------------------------------------------------------- // Error reporting macros diff --git a/test/cpp/aoti_abi_check/test_exception.cpp b/test/cpp/aoti_abi_check/test_exception.cpp new file mode 100644 index 0000000000000..74a9fee5d9863 --- /dev/null +++ b/test/cpp/aoti_abi_check/test_exception.cpp @@ -0,0 +1,19 @@ +#include + +#include + +namespace torch { +namespace aot_inductor { + +TEST(TestExceptions, TestStdTorchCheck) { + EXPECT_NO_THROW(STD_TORCH_CHECK(true, "dummy true message")); + EXPECT_NO_THROW(STD_TORCH_CHECK(true, "dummy ", "true ", "message")); + EXPECT_THROW( + STD_TORCH_CHECK(false, "dummy false message"), std::runtime_error); + EXPECT_THROW( + STD_TORCH_CHECK(false, "dummy ", "false ", "message"), + std::runtime_error); +} + +} // namespace aot_inductor +} // namespace torch diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp index 6125c21f0bedc..63e7821e9dfd4 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp +++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include @@ -33,6 +34,8 @@ Tensor sgd_out_of_place( const float weight_decay, const double lr, const bool maximize) { + STD_TORCH_CHECK(param.dim() == 1, "param must be 1D"); + int64_t *param_sizes; int64_t *param_strides; aoti_torch_get_sizes(param.get(), ¶m_sizes); diff --git a/torch/header_only_apis.txt b/torch/header_only_apis.txt index 49b2784e1df12..32c2d308d9d21 100644 --- a/torch/header_only_apis.txt +++ b/torch/header_only_apis.txt @@ -50,3 +50,6 @@ size # torch/headeronly/macros/Export.h C10_API + +# torch/headeronly/util/Exception.h +STD_TORCH_CHECK diff --git a/torch/headeronly/CMakeLists.txt b/torch/headeronly/CMakeLists.txt index 08ad713ca8452..e42981d8804e7 100644 --- a/torch/headeronly/CMakeLists.txt +++ b/torch/headeronly/CMakeLists.txt @@ -21,6 +21,7 @@ configure_file( file(GLOB HEADERONLY_HEADERS *.h macros/*.h + util/*.h ) add_library(headeronly INTERFACE ${HEADERONLY_HEADERS}) diff --git a/torch/headeronly/macros/build.bzl b/torch/headeronly/macros/build.bzl index 9b136951ad139..00d31a40163cf 100644 --- a/torch/headeronly/macros/build.bzl +++ b/torch/headeronly/macros/build.bzl @@ -4,8 +4,8 @@ def define_targets(rules): srcs = [":cmake_macros_h"], hdrs = [ # Following the example from c10 - "Macros.h", "Export.h", + "Macros.h", ], linkstatic = True, local_defines = ["C10_BUILD_MAIN_LIB"], diff --git a/torch/headeronly/ovrsource_defs.bzl b/torch/headeronly/ovrsource_defs.bzl index 5ba9b593c2974..6d1051fed2e4a 100644 --- a/torch/headeronly/ovrsource_defs.bzl +++ b/torch/headeronly/ovrsource_defs.bzl @@ -30,6 +30,7 @@ def define_torch_headeronly_ovrsource(name, is_mobile): public_preprocessor_flags = pp_flags, public_raw_headers = native.glob([ "macros/*.h", + "util/*.h", ]), reexport_all_header_dependencies = False, visibility = [ diff --git a/torch/headeronly/util/Exception.h b/torch/headeronly/util/Exception.h new file mode 100644 index 0000000000000..c5d05e0fa9557 --- /dev/null +++ b/torch/headeronly/util/Exception.h @@ -0,0 +1,83 @@ +#pragma once + +#include +#include + +#include +#include + +namespace c10 { +// On nvcc, C10_UNLIKELY thwarts missing return statement analysis. In cases +// where the unlikely expression may be a constant, use this macro to ensure +// return statement analysis keeps working (at the cost of not getting the +// likely/unlikely annotation on nvcc). +// https://github.com/pytorch/pytorch/issues/21418 +// +// Currently, this is only used in the error reporting macros below. If you +// want to use it more generally, move me to Macros.h +// +// TODO: Brian Vaughan observed that we might be able to get this to work on +// nvcc by writing some sort of C++ overload that distinguishes constexpr inputs +// from non-constexpr. Since there isn't any evidence that losing C10_UNLIKELY +// in nvcc is causing us perf problems, this is not yet implemented, but this +// might be an interesting piece of C++ code for an intrepid bootcamper to +// write. +#if defined(__CUDACC__) +#define C10_UNLIKELY_OR_CONST(e) e +#else +#define C10_UNLIKELY_OR_CONST(e) C10_UNLIKELY(e) +#endif + +} // namespace c10 + +// STD_TORCH_CHECK throws std::runtime_error instead of c10::Error which is +// useful when certain headers are used in a libtorch-independent way, +// e.g. when Vectorized is used in AOTInductor generated code, or +// for custom ops to have an ABI stable dependency on libtorch. +#ifdef STRIP_ERROR_MESSAGES +#define STD_TORCH_CHECK_MSG(cond, type, ...) \ + (#cond #type " CHECK FAILED at " C10_STRINGIZE(__FILE__)) +#else // so STRIP_ERROR_MESSAGES is not defined +namespace torch::headeronly::detail { +template +std::string stdTorchCheckMsgImpl(const char* /*msg*/, const Args&... args) { + // This is similar to the one in c10/util/Exception.h, but does + // not depend on the more complex c10::str() function. ostringstream + // supports fewer data types than c10::str(), but should be sufficient + // in the headeronly world. + std::ostringstream oss; + ((oss << args), ...); + return oss.str(); +} + +inline const char* stdTorchCheckMsgImpl(const char* msg) { + return msg; +} +// If there is just 1 user-provided C-string argument, use it. +inline const char* stdTorchCheckMsgImpl(const char* /*msg*/, const char* args) { + return args; +} +} // namespace torch::headeronly::detail + +#define STD_TORCH_CHECK_MSG(cond, type, ...) \ + (torch::headeronly::detail::stdTorchCheckMsgImpl( \ + "Expected " #cond \ + " to be true, but got false. " \ + "(Could this error message be improved? If so, " \ + "please report an enhancement request to PyTorch.)", \ + ##__VA_ARGS__)) +#endif // STRIP_ERROR_MESSAGES + +#define STD_TORCH_CHECK(cond, ...) \ + if (C10_UNLIKELY_OR_CONST(!(cond))) { \ + throw std::runtime_error(STD_TORCH_CHECK_MSG( \ + cond, \ + "", \ + __func__, \ + ", ", \ + __FILE__, \ + ":", \ + __LINE__, \ + ", ", \ + ##__VA_ARGS__)); \ + } From 036eb1f65dc6ed5e1e4b88a94e20afe6e3f356fe Mon Sep 17 00:00:00 2001 From: Zhengxu Chen Date: Fri, 18 Jul 2025 14:47:11 +0000 Subject: [PATCH 244/457] [precompile] Filter out ID_MATCH family of guards with caching_precompile. (#158368) Summary: For case like caching_precompile, we almost always want to drop ID_MATCH-type guards since they will block serialization. This diff add this behavior when this global flag is toggled on so that ID_MATCH guards are excluded from compilation and serialization. Test Plan: test_dynamo -- -k test_id_match_with_config Rollback Plan: Differential Revision: D78363609 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158368 Approved by: https://github.com/jamesjwu --- test/dynamo/test_guard_serialization.py | 17 +++++++++++++ torch/_dynamo/guards.py | 34 ++++++++++++++++++++++--- 2 files changed, 48 insertions(+), 3 deletions(-) diff --git a/test/dynamo/test_guard_serialization.py b/test/dynamo/test_guard_serialization.py index 8e5f12894711e..10808c922b3fb 100644 --- a/test/dynamo/test_guard_serialization.py +++ b/test/dynamo/test_guard_serialization.py @@ -878,6 +878,23 @@ def fn(x): ): self._test_serialization("ID_MATCH", fn, torch.randn(3)) + @torch._dynamo.config.patch(caching_precompile=True) + def test_id_match_with_config(self): + def fn(x): + return x + id(x) + + ref, loaded = self._test_serialization("ID_MATCH", fn, torch.randn(3)) + self._test_check_fn(ref, loaded, {"x": torch.randn(3)}, True) + + def fn(x): + # usage of this context manager installs a FUNCTION_MATCH guard + with torch.no_grad(): + y = x * 2 + return y + + ref, loaded = self._test_serialization("FUNCTION_MATCH", fn, torch.randn(3)) + self._test_check_fn(ref, loaded, {"x": torch.randn(3)}, True) + def test_dispatch_key_set_match(self): def fn(x, dks): if dks.has("CPU"): diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 983aa2133874c..8f55a666873ff 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -1591,7 +1591,7 @@ def id_match_unchecked(self, guard: Guard): val = self.get(guard.name) id_val = self.id_ref(val, guard.name) code = f"___check_obj_id({ref}, {id_val})" - self._set_guard_export_info(guard, [code]) + self._set_guard_export_info(guard, [code], provided_func_name="ID_MATCH") self.get_guard_manager(guard).add_id_match_guard( id_val, get_verbose_code_parts(code, guard) @@ -2473,7 +2473,9 @@ def TENSOR_MATCH(self, guard: Guard, value=None): self._set_guard_export_info(guard, code) # A util that in the case of export, adds data onto guards - def _set_guard_export_info(self, guard, code_list, provided_guarded_object=None): + def _set_guard_export_info( + self, guard, code_list, provided_guarded_object=None, provided_func_name=None + ): # WARNING: It is important that cur_frame/caller do NOT stay in # the current frame, because they will keep things live longer # than they should. See TestMisc.test_release_module_memory @@ -2482,7 +2484,7 @@ def _set_guard_export_info(self, guard, code_list, provided_guarded_object=None) caller = cur_frame.f_back del cur_frame assert caller is not None - func_name = caller.f_code.co_name + func_name = provided_func_name or caller.f_code.co_name del caller # We use func_name for export, so might as well get a nice defensive check out of it assert func_name in self.__class__.__dict__, ( @@ -2842,6 +2844,32 @@ def __init__( if not justknobs_check("pytorch/compiler:guard_nn_modules"): log.warning("guard_nn_modules is turned off using justknobs killswitch") + # TODO Be more explicit about the behavior for the users. + if ( + torch._dynamo.config.caching_precompile + and self.guards_serialization_mode != "load" + ): + _guard_filter_fn = guard_filter_fn or (lambda gs: [True for g in gs]) + + def guard_filter_fn(guards): + ret = [] + for keep, g in zip(_guard_filter_fn(guards), guards): + if not keep: + ret.append(False) + elif ( + g.guard_type in ("ID_MATCH", "CLOSURE_MATCH", "WEAKREF_ALIVE") + or "ID_MATCH" in g.derived_guard_types + ): + log.warning( + "%s guard on %s is dropped with caching_precompile=True.", + g.guard_type, + g.orig_guard.name, + ) + ret.append(False) + else: + ret.append(True) + return ret + sorted_guards = sorted(guards or (), key=Guard.sort_key) builder, guard_manager = self.build_guards( sorted_guards, From 193b29ee0c9db3573775ccfd226a4ac55d3ad80e Mon Sep 17 00:00:00 2001 From: Zain Rizvi Date: Fri, 18 Jul 2025 10:34:55 -0500 Subject: [PATCH 245/457] [BE][EZ] Minor doc fixes (#158574) [BE] Minor doc fixes --- README.md | 2 +- functorch/README.md | 6 +++--- functorch/writing_batching_rules.md | 2 +- scripts/release_notes/README.md | 4 ++-- torch/csrc/jit/tensorexpr/ConditionalsInTE.md | 2 +- torch/distributed/CONTRIBUTING.md | 2 +- torch/fx/README.md | 8 ++++---- torch/onnx/README.md | 2 +- torch/utils/benchmark/README.md | 4 ++-- 9 files changed, 16 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index 6d995f130e70b..24323032100d1 100644 --- a/README.md +++ b/README.md @@ -520,7 +520,7 @@ on [our website](https://pytorch.org/get-started/previous-versions). ## Getting Started -Three-pointers to get you started: +Three pointers to get you started: - [Tutorials: get you started with understanding and using PyTorch](https://pytorch.org/tutorials/) - [Examples: easy to understand PyTorch code across all domains](https://github.com/pytorch/examples) - [The API Reference](https://pytorch.org/docs/) diff --git a/functorch/README.md b/functorch/README.md index 5021c8591cff3..5e16966b1daa9 100644 --- a/functorch/README.md +++ b/functorch/README.md @@ -7,7 +7,7 @@ | [**Future Plans**](#future-plans) **This library is currently under heavy development - if you have suggestions -on the API or use-cases you'd like to be covered, please open an github issue +on the API or use-cases you'd like to be covered, please open a GitHub issue or reach out. We'd love to hear about how you're using the library.** `functorch` is [JAX-like](https://github.com/google/jax) composable function @@ -161,7 +161,7 @@ result = vmap(model)(examples) ### grad -`grad(func)(*inputs)` assumes `func` returns a single-element Tensor. It compute +`grad(func)(*inputs)` assumes `func` returns a single-element Tensor. It computes the gradients of the output of func w.r.t. to `inputs[0]`. ```py @@ -192,7 +192,7 @@ def compute_loss(weights, example, target): weights = torch.randn(feature_size, requires_grad=True) examples = torch.randn(batch_size, feature_size) targets = torch.randn(batch_size) -inputs = (weights,examples, targets) +inputs = (weights, examples, targets) grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs) ``` diff --git a/functorch/writing_batching_rules.md b/functorch/writing_batching_rules.md index 8643614acb555..61872c8d52327 100644 --- a/functorch/writing_batching_rules.md +++ b/functorch/writing_batching_rules.md @@ -5,7 +5,7 @@ First off, what are batching rules and why do we need so many of them? Well, to ### How does vmap work? Vmap is a function transform (pioneered by Jax) that allows one to batch functions. That is, given a function `f(x: [N]) -> [N]`, `vmap(f)` now transforms the signature to be `f(x: [B, N]) -> [B, N]`. That is - it adds a batch dimension to both the input and the output of the function. -This guide will gloss over all the cool things you can do this (there are many!), so let's focus on how we actually implement this. +This guide will gloss over all the cool things you can do with this (there are many!), so let's focus on how we actually implement this. One misconception is that this is some magic compiler voodoo, or that it is inherently some function transform. It is not - and there's another framing of it that might make it more clear. diff --git a/scripts/release_notes/README.md b/scripts/release_notes/README.md index 6cd34da87b149..c88533f937e7d 100644 --- a/scripts/release_notes/README.md +++ b/scripts/release_notes/README.md @@ -130,7 +130,7 @@ This part is a little tedious but it seems to work. May want to explore using pa 5. Install the google doc extension [docs to markdown](https://github.com/evbacher/gd2md-html) 6. Start to compile back down these markdown files into a single markdown file. -`TODO`: This is by far the most manual process and is ripe for automation. If the next person up would like to investigate Google Doc APIS there is some room hor improvement here. +`TODO`: This is by far the most manual process and is ripe for automation. If the next person up would like to investigate Google Doc APIS there is some room for improvement here. ### Part 4: Cherry Picks @@ -187,7 +187,7 @@ You will then create a release at [Pytorch Release](https://github.com/pytorch/p #### Tidbits You will probably have a release note that doesn't fit into the character limit of github. I used the following regex: -`\[#(\d+)\]\(https://github.com/pytorch/pytorch/pull/\d+\)` to replace the full lunks to (#). +`\[#(\d+)\]\(https://github.com/pytorch/pytorch/pull/\d+\)` to replace the full links to (#). This will get formatted correctly in the github UI and can be checked when creating a draft release. diff --git a/torch/csrc/jit/tensorexpr/ConditionalsInTE.md b/torch/csrc/jit/tensorexpr/ConditionalsInTE.md index 731f70a5b826d..c7bcea4976483 100644 --- a/torch/csrc/jit/tensorexpr/ConditionalsInTE.md +++ b/torch/csrc/jit/tensorexpr/ConditionalsInTE.md @@ -14,7 +14,7 @@ So far the recommendation was to standardize on fused conditionals. ## Expression Conditionals vs Statement Conditionals -Tensor IR contains both expression conditionals (`CompareSelect` and `IfThenElse`), as well as statement conditionals (`Cond`). Expression conditionals are defined by being functional in nature: there is no side effect from duplicating the conditional, evaluating it twice, etc. They are an important ingredient in expression important operators like ReLU: +Tensor IR contains both expression conditionals (`CompareSelect` and `IfThenElse`), as well as statement conditionals (`Cond`). Expression conditionals are defined by being functional in nature: there is no side effect from duplicating the conditional, evaluating it twice, etc. They are an important ingredient in expressing important operators like ReLU: ``` store (((load A) >= 0.0) ? (load A) : 0.0), B diff --git a/torch/distributed/CONTRIBUTING.md b/torch/distributed/CONTRIBUTING.md index bcfe8df9abd48..913017a6cabfb 100644 --- a/torch/distributed/CONTRIBUTING.md +++ b/torch/distributed/CONTRIBUTING.md @@ -2,7 +2,7 @@ Please go through PyTorch's top level [Contributing Guide](../../CONTRIBUTING.md) before proceeding with this guide. -[PyTorch Distributed Overview](https://pytorch.org/tutorials//beginner/dist_overview.html) is a great starting point with a lot of tutorials, documentation and design docs covering PyTorch Distributed. We would highly recommend going through some of that material before you start working on PyTorch Distributed. +[PyTorch Distributed Overview](https://pytorch.org/tutorials//beginner/dist_overview.html) is a great starting point with a lot of tutorials, documentation and design docs covering PyTorch Distributed. We highly recommend going through some of that material before you start working on PyTorch Distributed. In this document, we mostly focus on some of the code structure for PyTorch distributed and implementation details. diff --git a/torch/fx/README.md b/torch/fx/README.md index 4c799da7bc402..3d42cb9375d43 100644 --- a/torch/fx/README.md +++ b/torch/fx/README.md @@ -70,7 +70,7 @@ Here, we set up a simple Module that exercises different language features: fetc The `fx.Graph` is a core data structure in FX that represents the operations and their dependencies in a structured format. It consists of a List of `fx.Node` representing individual operations and their inputs and outputs. The Graph enables simple manipulation and analysis of the model structure, which is essential for implementing various transformations and optimizations. ## Node -An `fx.Node` is a datastructure that represent individual operations within an `fx.Graph`, it maps to callsites such as operators, methods and modules. Each `fx.Node` keeps track of its inputs, the previous and next nodes, the stacktrace so you can map back the node to a line of code in your python file and some optional metadata stored in a `meta` dict. +An `fx.Node` is a data structure that represents individual operations within an `fx.Graph`, it maps to callsites such as operators, methods and modules. Each `fx.Node` keeps track of its inputs, the previous and next nodes, the stacktrace so you can map back the node to a line of code in your python file and some optional metadata stored in a `meta` dict. ## [GraphModule](https://pytorch.org/docs/main/fx.html#torch.fx.GraphModule) ## The `fx.GraphModule` is a subclass of `nn.Module` that holds the transformed Graph, the original module's parameter attributes and its source code. It serves as the primary output of FX transformations and can be used like any other `nn.Module`. `fx.GraphModule` allows for the execution of the transformed model, as it generates a valid forward method based on the Graph's structure. @@ -115,11 +115,11 @@ Tracing captures an intermediate representation (IR), which is represented as a Node is the data structure that represents individual operations within a Graph. For the most part, Nodes represent callsites to various entities, such as operators, methods, and Modules (some exceptions include Nodes that specify function inputs and outputs). Each Node has a function specified by its `op` property. The Node semantics for each value of `op` are as follows: -- `placeholder` represents a function input. The `name` attribute specifies the name this value will take on. `target` is similarly the name of the argument. `args` holds either: 1) nothing, or 2) a single argument denoting the default parameter of the function input. `kwargs` is don't-care. Placeholders correspond to the function parameters (e.g. `x`) in the graph printout. -- `get_attr` retrieves a parameter from the module hierarchy. `name` is similarly the name the result of the fetch is assigned to. `target` is the fully-qualified name of the parameter's position in the module hierarchy. `args` and `kwargs` are don't-care +- `placeholder` represents a function input. The `name` attribute specifies the name this value will take on. `target` is similarly the name of the argument. `args` holds either: 1) nothing, or 2) a single argument denoting the default parameter of the function input. `kwargs` is ignored. Placeholders correspond to the function parameters (e.g. `x`) in the graph printout. +- `get_attr` retrieves a parameter from the module hierarchy. `name` is similarly the name the result of the fetch is assigned to. `target` is the fully-qualified name of the parameter's position in the module hierarchy. `args` and `kwargs` are ignored - `call_function` applies a free function to some values. `name` is similarly the name of the value to assign to. `target` is the function to be applied. `args` and `kwargs` represent the arguments to the function, following the Python calling convention - `call_module` applies a module in the module hierarchy's `forward()` method to given arguments. `name` is as previous. `target` is the fully-qualified name of the module in the module hierarchy to call. `args` and `kwargs` represent the arguments to invoke the module on, *including the self argument*. -- `call_method` calls a method on a value. `name` is as similar. `target` is the string name of the method to apply to the `self` argument. `args` and `kwargs` represent the arguments to invoke the module on, *including the self argument* +- `call_method` calls a method on a value. `name` is similar. `target` is the string name of the method to apply to the `self` argument. `args` and `kwargs` represent the arguments to invoke the module on, *including the self argument* - `output` contains the output of the traced function in its `args[0]` attribute. This corresponds to the "return" statement in the Graph printout. To facilitate easier analysis of data dependencies, Nodes have read-only properties `input_nodes` and `users`, which specify which Nodes in the Graph are used by this Node and which Nodes use this Node, respectively. Although Nodes are represented as a doubly-linked list, the use-def relationships form an acyclic graph and can be traversed as such. diff --git a/torch/onnx/README.md b/torch/onnx/README.md index c4691ea01802a..7c8596365f270 100644 --- a/torch/onnx/README.md +++ b/torch/onnx/README.md @@ -23,7 +23,7 @@ symbolic_opset9.py. To extend support for updated operators in different opset versions on top of opset 9, simply add the updated symbolic functions in the respective symbolic_opset{version}.py file. -Checkout topk in symbolic_opset10.py, and upsample_nearest2d in symbolic_opset8.py for example. +Check out topk in symbolic_opset10.py, and upsample_nearest2d in symbolic_opset8.py for example. ## Editing Symbolic Files diff --git a/torch/utils/benchmark/README.md b/torch/utils/benchmark/README.md index 4a64b778181f8..6fa025e51d37c 100644 --- a/torch/utils/benchmark/README.md +++ b/torch/utils/benchmark/README.md @@ -25,7 +25,7 @@ into two broad categories: * `Timer` implements the `blocked_autorange` function which is a mixture of `timeit.Timer.repeat` and `timeit.Timer.autorange`. This function - selects and appropriate number and runs for a roughly fixed amount of time + selects an appropriate number and runs for a roughly fixed amount of time (like `autorange`), but is less wasteful than `autorange` which discards ~75% of measurements. It runs many times, similar to `repeat`, and returns a `Measurement` containing all of the run results. @@ -46,7 +46,7 @@ table will be generated per unique label. may be logically equivalent differ in implementation. Assigning separate sub_labels will result in a row per sub_label. If a sublabel is not provided, `stmt` is used instead. Statistics (such as computing the fastest -implementation) are use all sub_labels. +implementation) use all sub_labels. * `description`: This describes the inputs. For instance, `stmt=torch.add(x, y)` can be run over several values of `x` and `y`. Each pair should be given its From 35df895d0564cc53dfcad829732fc6b3a9b7eb86 Mon Sep 17 00:00:00 2001 From: Xu Han Date: Fri, 18 Jul 2025 15:55:24 +0000 Subject: [PATCH 246/457] [AOTI] package loader normalize path separator (#158630) Add `normalize_path_separator` to handle Windows path simplify. This solution is working well on `torch/_inductor/cpp_builder.py`: https://github.com/pytorch/pytorch/blob/a00cd8cf252a0b061f2eef6b5b42ae967acf5f64/torch/_inductor/cpp_builder.py#L406-L409 Let's copy it to package loader. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158630 Approved by: https://github.com/angelayi --- .../aoti_package/model_package_loader.cpp | 43 +++++++++++++------ 1 file changed, 31 insertions(+), 12 deletions(-) diff --git a/torch/csrc/inductor/aoti_package/model_package_loader.cpp b/torch/csrc/inductor/aoti_package/model_package_loader.cpp index 8c674764d9dc4..946342aad4146 100644 --- a/torch/csrc/inductor/aoti_package/model_package_loader.cpp +++ b/torch/csrc/inductor/aoti_package/model_package_loader.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -36,6 +37,31 @@ namespace fs = std::filesystem; #endif namespace { + +const std::string k_separator = "/"; + +std::string normalize_path_separator(const std::string& orig_path) { + /* + On Windows and Linux have different separator: + On Windows use "\", and the path like: C:\Users\Test\file.txt + On Linux use "/", and the path like: /home/user/file.txt + + In order to simplify the path operation, we can use this function to + normalize path separator. It will convert Windows separator to Linux + separator, and reuse the common code to handle both Windows and Linux + path. + On Windows, when we input: "C:\Users\Test\file.txt", the output should be: + "C:/Users/Test/file.txt". And then, we can process the output like on Linux. + */ +#ifdef _WIN32 + std::string normalized_path = orig_path; + std::replace(normalized_path.begin(), normalized_path.end(), '\\', '/'); + return normalized_path; +#else + return orig_path; +#endif +} + bool file_exists(const std::string& path) { #ifdef _WIN32 return fs::exists(path); @@ -67,12 +93,6 @@ std::string create_temp_dir() { return temp_dir; #endif } - -#ifdef _WIN32 -const std::string k_separator = "\\"; -#else -const std::string k_separator = "/"; -#endif } // namespace namespace torch::inductor { @@ -92,11 +112,12 @@ const nlohmann::json& load_json_file(const std::string& json_path) { } std::tuple get_cpp_compile_command( - const std::string& filename, + const std::string& arg_filename, const std::vector& sources, const nlohmann::json& compile_options, const std::string& output_dir = "") { // Construct the cpp command + auto filename = normalize_path_separator(arg_filename); std::string compiler = compile_options["compiler"].get(); bool compile_only = compile_options["compile_only"].get(); @@ -156,7 +177,7 @@ std::tuple get_cpp_compile_command( std::string compile_only_arg = compile_only ? "-c" : ""; - std::string cmd = fmt::format( + std::string cmd = normalize_path_separator(fmt::format( "{} {} {} {} {} {} {} {} {} {} -o {}", compiler, source_args, @@ -168,7 +189,7 @@ std::tuple get_cpp_compile_command( libraries_args, libraries_dirs_args, compile_only_arg, - target_file); + target_file)); return std::make_tuple(cmd, target_file); } @@ -338,8 +359,6 @@ std::unordered_set find_model_names( // Escape the separator if it's backslash (needed for regex) std::string sep = k_separator; - if (sep == "\\") - sep = "\\\\"; std::string pattern = "data" + sep + "aotinductor" + sep + "([^" + sep + "]+)" + sep; @@ -412,7 +431,7 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( &zip_archive, i, filename_str.data(), filename_len)) { throw std::runtime_error("Failed to read filename"); } - found_filenames.push_back(filename_str); + found_filenames.push_back(normalize_path_separator(filename_str)); } if (found_filenames.empty()) { From 50f33a6fca88cd04b79760483e69a73b5eabe25e Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 18 Jul 2025 16:45:32 +0000 Subject: [PATCH 247/457] Revert "[DTensor] fix copy_ strategy (#158538)" This reverts commit 7b05bdd925f0f4b49e68662f9761fabaa27f2faf. Reverted https://github.com/pytorch/pytorch/pull/158538 on behalf of https://github.com/clee2000 due to broke lint? [GH job link](https://github.com/pytorch/pytorch/actions/runs/16361950974/job/46231492581) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/d8b084312b54e97bdbaf6a178fe2fc628a23243b) ([comment](https://github.com/pytorch/pytorch/pull/158490#issuecomment-3090042448)) --- test/distributed/tensor/test_tensor_ops.py | 44 +++++++-------- torch/distributed/tensor/_ops/_tensor_ops.py | 56 ++++++++++++++------ 2 files changed, 58 insertions(+), 42 deletions(-) diff --git a/test/distributed/tensor/test_tensor_ops.py b/test/distributed/tensor/test_tensor_ops.py index d62da27d43393..9140d2f5aae13 100644 --- a/test/distributed/tensor/test_tensor_ops.py +++ b/test/distributed/tensor/test_tensor_ops.py @@ -56,11 +56,10 @@ def test_clone(self): @with_comms def test_copy_(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) - - # basic test + src_specs = [[Replicate()], [Shard(0)]] src_tensor = torch.randn((12, 12)) + dst_tensor = torch.zeros(12, 12) - src_specs = [[Replicate()], [Shard(0)]] dst_specs = [[Replicate()], [Shard(0)]] for dst_spec, src_spec in zip(dst_specs, src_specs): src_dtensor = distribute_tensor(src_tensor, device_mesh, dst_spec) @@ -69,29 +68,22 @@ def test_copy_(self): dst_tensor.copy_(src_tensor) self.assertEqual(dst_dtensor.full_tensor(), dst_tensor) - # simple broadcasting - src_tensor = torch.randn((128,)) - dst_tensor = torch.zeros(128, 128) - src_specs = [[Replicate()], [Shard(0)]] - dst_specs = [[Replicate()], [Shard(1)]] - for dst_spec, src_spec in zip(dst_specs, src_specs): - src_dtensor = distribute_tensor(src_tensor, device_mesh, src_spec) - dst_dtensor = distribute_tensor(dst_tensor, device_mesh, dst_spec) - dst_dtensor.copy_(src_dtensor) - dst_tensor.copy_(src_tensor) - self.assertEqual(dst_dtensor.full_tensor(), dst_tensor) - - # The src specs in this case are designed to not be compatible with the dst_specs, redistribute should happen - src_tensor = torch.randn((64, 1)) - dst_tensor = torch.zeros(16, 32, 64, 128) - src_specs = [[Shard(1)], [Shard(1)], [Shard(1)], [Shard(1)]] - dst_specs = [[Replicate()], [Shard(0)], [Shard(1)], [Shard(2)]] - for dst_spec, src_spec in zip(dst_specs, src_specs): - src_dtensor = distribute_tensor(src_tensor, device_mesh, src_spec) - dst_dtensor = distribute_tensor(dst_tensor, device_mesh, dst_spec) - dst_dtensor.copy_(src_dtensor) - dst_tensor.copy_(src_tensor) - self.assertEqual(dst_dtensor.full_tensor(), dst_tensor) + # @pytest.mark.xfail + # @with_comms + # def test_copy_broadcast(self): + # device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + # src_specs = [[Replicate()], [Shard(0)]] + # src_tensor = torch.randn((12,)) + + # dst_tensor = torch.zeros(12, 12) + # dst_specs = [[Replicate()], [Shard(1)]] + # for dst_spec, src_spec in zip(dst_specs, src_specs): + # src_dtensor = distribute_tensor(src_tensor, device_mesh, dst_spec) + # dst_dtensor = distribute_tensor(dst_tensor, device_mesh, src_spec) + # # perform a broadcasted copy from Shard(0) to Shard(1) for the worst case + # dst_dtensor.copy_(src_dtensor) + # dst_tensor.copy_(src_tensor) + # self.assertEqual(dst_dtensor.full_tensor(), dst_tensor) @with_comms def test_contiguous(self): diff --git a/torch/distributed/tensor/_ops/_tensor_ops.py b/torch/distributed/tensor/_ops/_tensor_ops.py index e53eef1610162..fd6621ab75124 100644 --- a/torch/distributed/tensor/_ops/_tensor_ops.py +++ b/torch/distributed/tensor/_ops/_tensor_ops.py @@ -35,8 +35,6 @@ Shard, ) -from ._pointwise_ops import pointwise_strategy - aten = torch.ops.aten @@ -93,20 +91,46 @@ def propagate_single_input_strategy(op_schema: OpSchema) -> StrategyType: aten._to_copy.default, schema_info=RuntimeSchemaInfo(static_kwargkey=["dtype"]) )(propagate_single_input_strategy) -# copy_ is actually a pointwise op with broadcasting, so reuse the pointwise strategy, which takes care of these -# requirements. -# -# Following torch broadcasting semantics (https://docs.pytorch.org/docs/stable/notes/broadcasting.html) -# - self can not change shape as a result of broadcasting since this is an inplace op -# - src can broadcast, but when it does it always does so from the trailing end -# e.g. the last dim of 'src' must match up with the last dim of 'self' -# -# DTensor semantics for inplace ops also dictates that we may NOT redistribute our 'self' input. -# In practice, what this means is -# - our output strategies should map 1:1 to our 'self' input strategies -# - our 'src' input may be redistributed to match up with the 'self' input, with the caveat of adjusting for -# broadcasting dim -register_op_strategy(aten.copy_.default)(pointwise_strategy) + +@register_op_strategy(aten.copy_.default) +def copy_strategy(op_schema: OpSchema) -> StrategyType: + # TODO: this strategy is incorrect for copy_ in the case that src tensor + # is smaller rank than self tensor. It is possible to select a strategy from self tensor + # that is invalid for dst tensor. + # It is also problematic to assume that shard(0) on src maps to shard(0) on self, since we + # may broadcast a new dim to the left or right of 0 when copying. + # + # For now, I just keep copy working essentially the way it was before this PR, + # but split it out so it can be handled separately in the future. + num_tensor_args = 2 + first_input_strategy = op_schema.args_schema[0] + assert isinstance(first_input_strategy, OpStrategy) + return OpStrategy( + [ + OpSpec( + output_specs=DTensorSpec( + mesh=first_input_strategy.mesh, + placements=strategy.output_spec.placements, + tensor_meta=strategy.output_spec.tensor_meta, + ), + input_specs=[ + DTensorSpec( + mesh=first_input_strategy.mesh, + placements=strategy.output_spec.placements, + tensor_meta=strategy.output_spec.tensor_meta, + ) + for _ in range(num_tensor_args) + ], + redistribute_cost=[ + generate_redistribute_costs( + first_input_strategy, strategy.output_spec + ) + for _ in range(num_tensor_args) + ], + ) + for strategy in first_input_strategy.strategies + ] + ) @register_op_strategy( From bf4aa7827905a2fca96bf266b242a7a16e489af4 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 18 Jul 2025 16:45:32 +0000 Subject: [PATCH 248/457] Revert "[DTensor] Fix default_strategy and rename for clarity (#158490)" This reverts commit d8b084312b54e97bdbaf6a178fe2fc628a23243b. Reverted https://github.com/pytorch/pytorch/pull/158490 on behalf of https://github.com/clee2000 due to broke lint? [GH job link](https://github.com/pytorch/pytorch/actions/runs/16361950974/job/46231492581) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/d8b084312b54e97bdbaf6a178fe2fc628a23243b) ([comment](https://github.com/pytorch/pytorch/pull/158490#issuecomment-3090042448)) --- test/distributed/tensor/test_tensor_ops.py | 32 ------ torch/distributed/tensor/_ops/_tensor_ops.py | 109 ++++++------------- 2 files changed, 33 insertions(+), 108 deletions(-) diff --git a/test/distributed/tensor/test_tensor_ops.py b/test/distributed/tensor/test_tensor_ops.py index 9140d2f5aae13..9be582952f367 100644 --- a/test/distributed/tensor/test_tensor_ops.py +++ b/test/distributed/tensor/test_tensor_ops.py @@ -53,38 +53,6 @@ def test_clone(self): self.assertFalse(cloned_mat is mat) self.assertEqual(cloned_mat.to_local(), mat.to_local()) - @with_comms - def test_copy_(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) - src_specs = [[Replicate()], [Shard(0)]] - src_tensor = torch.randn((12, 12)) - - dst_tensor = torch.zeros(12, 12) - dst_specs = [[Replicate()], [Shard(0)]] - for dst_spec, src_spec in zip(dst_specs, src_specs): - src_dtensor = distribute_tensor(src_tensor, device_mesh, dst_spec) - dst_dtensor = distribute_tensor(dst_tensor, device_mesh, src_spec) - dst_dtensor.copy_(src_dtensor) - dst_tensor.copy_(src_tensor) - self.assertEqual(dst_dtensor.full_tensor(), dst_tensor) - - # @pytest.mark.xfail - # @with_comms - # def test_copy_broadcast(self): - # device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) - # src_specs = [[Replicate()], [Shard(0)]] - # src_tensor = torch.randn((12,)) - - # dst_tensor = torch.zeros(12, 12) - # dst_specs = [[Replicate()], [Shard(1)]] - # for dst_spec, src_spec in zip(dst_specs, src_specs): - # src_dtensor = distribute_tensor(src_tensor, device_mesh, dst_spec) - # dst_dtensor = distribute_tensor(dst_tensor, device_mesh, src_spec) - # # perform a broadcasted copy from Shard(0) to Shard(1) for the worst case - # dst_dtensor.copy_(src_dtensor) - # dst_tensor.copy_(src_tensor) - # self.assertEqual(dst_dtensor.full_tensor(), dst_tensor) - @with_comms def test_contiguous(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) diff --git a/torch/distributed/tensor/_ops/_tensor_ops.py b/torch/distributed/tensor/_ops/_tensor_ops.py index fd6621ab75124..9bdfc90d145d4 100644 --- a/torch/distributed/tensor/_ops/_tensor_ops.py +++ b/torch/distributed/tensor/_ops/_tensor_ops.py @@ -39,98 +39,55 @@ aten = torch.ops.aten -def propagate_single_input_strategy(op_schema: OpSchema) -> StrategyType: - # For ops with a single tensor input, we perform a 1:1 mapping such that - # for each strategy that the input supports, we create a corresponding strategy. - # Note: this may be a complete waste of work, becuase it should be equivalent to - # `return first_input_strategy` (unless creating a deep copy is important for some reason) - assert len([s for s in op_schema.args_schema if isinstance(s, OpStrategy)]) == 1, ( - "propagate_single_input_strategy only works for single-tensor-input ops" - ) - first_input_strategy = op_schema.args_schema[0] - assert isinstance(first_input_strategy, OpStrategy) - return OpStrategy( - [ - OpSpec( - output_specs=DTensorSpec( - mesh=first_input_strategy.mesh, - placements=strategy.output_spec.placements, - tensor_meta=strategy.output_spec.tensor_meta, - ), - input_specs=[ - DTensorSpec( - mesh=first_input_strategy.mesh, - placements=strategy.output_spec.placements, - tensor_meta=strategy.output_spec.tensor_meta, - ) - ], - redistribute_cost=[ - generate_redistribute_costs( - first_input_strategy, strategy.output_spec - ) - ], +def default_strategy(op_schema: OpSchema) -> StrategyType: + # Default strategy by default just propagate the first input strategy + select_strategy = op_schema.args_schema[0] + assert isinstance(select_strategy, OpStrategy) + # we create new DTensorSpecs even for default strategy to assure that + # the tensor metas are distinct between the arguments and outputs + input_specs = [] + redistribute_cost = [] + for i in op_schema.args_schema: + input_specs.append( + DTensorSpec( + mesh=select_strategy.mesh, + placements=select_strategy.strategies[0].output_spec.placements, + tensor_meta=select_strategy.strategies[0].output_spec.tensor_meta, ) - for strategy in first_input_strategy.strategies - ] - ) + ) + redistribute_cost.append([0.0] * len(select_strategy.strategies)) + + default_strategy = [ + OpSpec( + output_specs=DTensorSpec( + mesh=select_strategy.mesh, + placements=strategy.output_spec.placements, + tensor_meta=strategy.output_spec.tensor_meta, + ), + input_specs=input_specs, + redistribute_cost=redistribute_cost, + ) + for strategy in select_strategy.strategies + ] + return OpStrategy(default_strategy) register_op_strategy( [ aten.clone.default, aten.contiguous.default, + aten.copy_.default, aten.detach.default, aten.fill_.Scalar, aten.view.dtype, aten.zero_.default, ] -)(propagate_single_input_strategy) +)(default_strategy) register_op_strategy( aten._to_copy.default, schema_info=RuntimeSchemaInfo(static_kwargkey=["dtype"]) -)(propagate_single_input_strategy) - - -@register_op_strategy(aten.copy_.default) -def copy_strategy(op_schema: OpSchema) -> StrategyType: - # TODO: this strategy is incorrect for copy_ in the case that src tensor - # is smaller rank than self tensor. It is possible to select a strategy from self tensor - # that is invalid for dst tensor. - # It is also problematic to assume that shard(0) on src maps to shard(0) on self, since we - # may broadcast a new dim to the left or right of 0 when copying. - # - # For now, I just keep copy working essentially the way it was before this PR, - # but split it out so it can be handled separately in the future. - num_tensor_args = 2 - first_input_strategy = op_schema.args_schema[0] - assert isinstance(first_input_strategy, OpStrategy) - return OpStrategy( - [ - OpSpec( - output_specs=DTensorSpec( - mesh=first_input_strategy.mesh, - placements=strategy.output_spec.placements, - tensor_meta=strategy.output_spec.tensor_meta, - ), - input_specs=[ - DTensorSpec( - mesh=first_input_strategy.mesh, - placements=strategy.output_spec.placements, - tensor_meta=strategy.output_spec.tensor_meta, - ) - for _ in range(num_tensor_args) - ], - redistribute_cost=[ - generate_redistribute_costs( - first_input_strategy, strategy.output_spec - ) - for _ in range(num_tensor_args) - ], - ) - for strategy in first_input_strategy.strategies - ] - ) +)(default_strategy) @register_op_strategy( From acffd1a297ffacd120855401fd2e01af90cd4c81 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Fri, 18 Jul 2025 16:37:54 +0000 Subject: [PATCH 249/457] [iter] Update some of the tests to not call pickle (#156369) Some tests in test_iter only fail because of pickle. I'm skipping the pickle section as Dynamo doesn't support it. Pull Request resolved: https://github.com/pytorch/pytorch/pull/156369 Approved by: https://github.com/zou3519 --- test/dynamo/cpython/3_13/test_iter.diff | 22 +++++++++++++++---- test/dynamo/cpython/3_13/test_iter.py | 4 ---- ...thon313-test_iter-TestCase.test_iter_basic | 0 ...313-test_iter-TestCase.test_iter_big_range | 0 ...313-test_iter-TestCase.test_iter_class_for | 0 ...13-test_iter-TestCase.test_iter_class_iter | 0 ...ython313-test_iter-TestCase.test_iter_dict | 0 ...thon313-test_iter-TestCase.test_iter_empty | 0 ...n313-test_iter-TestCase.test_iter_for_loop | 0 ...thon313-test_iter-TestCase.test_iter_range | 0 ...hon313-test_iter-TestCase.test_iter_string | 0 ...thon313-test_iter-TestCase.test_iter_tuple | 0 12 files changed, 18 insertions(+), 8 deletions(-) delete mode 100644 test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_basic delete mode 100644 test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_big_range delete mode 100644 test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_class_for delete mode 100644 test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_class_iter delete mode 100644 test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_dict delete mode 100644 test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_empty delete mode 100644 test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_for_loop delete mode 100644 test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_range delete mode 100644 test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_string delete mode 100644 test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_tuple diff --git a/test/dynamo/cpython/3_13/test_iter.diff b/test/dynamo/cpython/3_13/test_iter.diff index fd5545e6b2cf7..ee8a108ed3892 100644 --- a/test/dynamo/cpython/3_13/test_iter.diff +++ b/test/dynamo/cpython/3_13/test_iter.diff @@ -1,5 +1,5 @@ diff --git a/test/dynamo/cpython/3_13/test_iter.py b/test/dynamo/cpython/3_13/test_iter.py -index 1b9f3cf7624..d2fc26ddc72 100644 +index 1b9f3cf7624..bad1ba94300 100644 --- a/test/dynamo/cpython/3_13/test_iter.py +++ b/test/dynamo/cpython/3_13/test_iter.py @@ -1,3 +1,60 @@ @@ -63,7 +63,7 @@ index 1b9f3cf7624..d2fc26ddc72 100644 # Test iterators. import sys -@@ -104,7 +161,7 @@ class EmptyIterClass: +@@ -104,12 +158,10 @@ class EmptyIterClass: # Main test suite @@ -72,7 +72,21 @@ index 1b9f3cf7624..d2fc26ddc72 100644 # Helper to check that an iterator returns a given sequence def check_iterator(self, it, seq, pickle=True): -@@ -635,6 +692,7 @@ class TestCase(unittest.TestCase): +- if pickle: +- self.check_pickle(it, seq) + res = [] + while 1: + try: +@@ -121,8 +173,6 @@ class TestCase(unittest.TestCase): + + # Helper to check that a for loop generates a given sequence + def check_for_loop(self, expr, seq, pickle=True): +- if pickle: +- self.check_pickle(iter(expr), seq) + res = [] + for val in expr: + res.append(val) +@@ -635,6 +685,7 @@ class TestCase(unittest.TestCase): pass # Test zip()'s use of iterators. @@ -80,7 +94,7 @@ index 1b9f3cf7624..d2fc26ddc72 100644 def test_builtin_zip(self): self.assertEqual(list(zip()), []) self.assertEqual(list(zip(*[])), []) -@@ -1187,4 +1245,4 @@ class TestCase(unittest.TestCase): +@@ -1187,4 +1238,4 @@ class TestCase(unittest.TestCase): if __name__ == "__main__": diff --git a/test/dynamo/cpython/3_13/test_iter.py b/test/dynamo/cpython/3_13/test_iter.py index d2fc26ddc7210..e752426cf5c0e 100644 --- a/test/dynamo/cpython/3_13/test_iter.py +++ b/test/dynamo/cpython/3_13/test_iter.py @@ -165,8 +165,6 @@ class TestCase(__TestCase): # Helper to check that an iterator returns a given sequence def check_iterator(self, it, seq, pickle=True): - if pickle: - self.check_pickle(it, seq) res = [] while 1: try: @@ -178,8 +176,6 @@ def check_iterator(self, it, seq, pickle=True): # Helper to check that a for loop generates a given sequence def check_for_loop(self, expr, seq, pickle=True): - if pickle: - self.check_pickle(iter(expr), seq) res = [] for val in expr: res.append(val) diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_basic b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_basic deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_big_range b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_big_range deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_class_for b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_class_for deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_class_iter b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_class_iter deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_dict b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_dict deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_empty b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_empty deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_for_loop b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_for_loop deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_range b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_range deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_string b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_string deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_tuple b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_tuple deleted file mode 100644 index e69de29bb2d1d..0000000000000 From 6f73e067963e31d16840fbc34993a64cee698746 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Fri, 18 Jul 2025 16:37:54 +0000 Subject: [PATCH 250/457] [iter] exhaust `ListIterator` when `unpack_var_sequence` is called (#156370) Pull Request resolved: https://github.com/pytorch/pytorch/pull/156370 Approved by: https://github.com/zou3519 ghstack dependencies: #156369 --- .../CPython313-test_iter-TestCase.test_sinkstate_dict | 0 .../CPython313-test_iter-TestCase.test_sinkstate_list | 0 .../CPython313-test_iter-TestCase.test_sinkstate_range | 0 .../CPython313-test_iter-TestCase.test_sinkstate_string | 0 .../CPython313-test_iter-TestCase.test_sinkstate_tuple | 0 torch/_dynamo/variables/lists.py | 7 ++++++- 6 files changed, 6 insertions(+), 1 deletion(-) delete mode 100644 test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_sinkstate_dict delete mode 100644 test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_sinkstate_list delete mode 100644 test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_sinkstate_range delete mode 100644 test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_sinkstate_string delete mode 100644 test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_sinkstate_tuple diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_sinkstate_dict b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_sinkstate_dict deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_sinkstate_list b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_sinkstate_list deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_sinkstate_range b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_sinkstate_range deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_sinkstate_string b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_sinkstate_string deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_sinkstate_tuple b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_sinkstate_tuple deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/torch/_dynamo/variables/lists.py b/torch/_dynamo/variables/lists.py index 55891ce1de243..3e0a91b5e9224 100644 --- a/torch/_dynamo/variables/lists.py +++ b/torch/_dynamo/variables/lists.py @@ -1279,8 +1279,13 @@ def as_python_constant(self): raise NotImplementedError return iter([x.as_python_constant() for x in self.items]) + def has_unpack_var_sequence(self, tx): + return True + def unpack_var_sequence(self, tx): - return list(self.items[self.index :]) + r = list(self.items[self.index :]) + self.index = len(self.items) + return r def force_unpack_var_sequence(self, tx) -> list[VariableTracker]: return self.unpack_var_sequence(tx) From 8c3f84908b085a26a6d8c7a90ce7c94ab2fe6f0a Mon Sep 17 00:00:00 2001 From: Xu Han Date: Fri, 18 Jul 2025 17:18:10 +0000 Subject: [PATCH 251/457] [aot] fix greater_than_max build fail on Windows. (#158479) Error snapshot: image Reason: `std::numeric_limits::max` is confilct to windef.h:`max(a, b)` Fix code: image Pull Request resolved: https://github.com/pytorch/pytorch/pull/158479 Approved by: https://github.com/desertfire --- c10/util/TypeSafeSignMath.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/c10/util/TypeSafeSignMath.h b/c10/util/TypeSafeSignMath.h index 2853ff48d1831..58c0506783023 100644 --- a/c10/util/TypeSafeSignMath.h +++ b/c10/util/TypeSafeSignMath.h @@ -79,7 +79,7 @@ template inline constexpr bool greater_than_max(const T& x) { constexpr bool can_overflow = std::numeric_limits::digits > std::numeric_limits::digits; - return can_overflow && x > std::numeric_limits::max(); + return can_overflow && x > (std::numeric_limits::max)(); } #ifdef __GNUC__ From 725cdb218ec7b117b88baf5c6f4ac39c863a4b17 Mon Sep 17 00:00:00 2001 From: Deepak Seshadri Date: Fri, 18 Jul 2025 17:33:12 +0000 Subject: [PATCH 252/457] Name threads in caffe2/torch/distributed/checkpoint AsyncCheckpointExecutor (#158612) Differential Revision: D78493333 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158612 Approved by: https://github.com/d4l3k --- torch/distributed/checkpoint/_async_thread_executor.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torch/distributed/checkpoint/_async_thread_executor.py b/torch/distributed/checkpoint/_async_thread_executor.py index 1038c177529d2..3fad17b2dea98 100644 --- a/torch/distributed/checkpoint/_async_thread_executor.py +++ b/torch/distributed/checkpoint/_async_thread_executor.py @@ -37,7 +37,9 @@ def save_wrapper( class _ThreadBasedAsyncCheckpointExecutor(_AsyncCheckpointExecutor): def __init__(self) -> None: - self._executor = ThreadPoolExecutor(max_workers=1) + self._executor = ThreadPoolExecutor( + max_workers=1, thread_name_prefix="AsyncCheckpointExecutor" + ) def execute_save( self, From 86675af3f02e54fed4bbae68d6316274b93b373f Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 18 Jul 2025 17:46:11 +0000 Subject: [PATCH 253/457] Revert "[ROCm][CI] update fbgemm_gpu hash used by inductor tests (#158602)" This reverts commit 9308261a2afb69d807ea06508bb8582b066d9ccd. Reverted https://github.com/pytorch/pytorch/pull/158602 on behalf of https://github.com/ZainRizvi due to The lint job failure was hiding a real lint failure. See here for more details: [GH job link](https://github.com/pytorch/pytorch/actions/runs/16375911199/job/46275682191) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/6f73e067963e31d16840fbc34993a64cee698746) ([comment](https://github.com/pytorch/pytorch/pull/158602#issuecomment-3090209891)) --- .ci/pytorch/common_utils.sh | 27 +------------------------- .github/ci_commit_pins/fbgemm_rocm.txt | 2 +- 2 files changed, 2 insertions(+), 27 deletions(-) diff --git a/.ci/pytorch/common_utils.sh b/.ci/pytorch/common_utils.sh index 3de68991bafce..3dbc2ece9e70b 100644 --- a/.ci/pytorch/common_utils.sh +++ b/.ci/pytorch/common_utils.sh @@ -176,43 +176,18 @@ function install_torchrec_and_fbgemm() { pip_install --no-use-pep517 "git+https://github.com/pytorch/torchrec.git@${torchrec_commit}" pip_uninstall fbgemm-gpu-nightly - # Set ROCM_HOME isn't available, use ROCM_PATH if set or /opt/rocm - ROCM_HOME="${ROCM_HOME:-${ROCM_PATH:-/opt/rocm}}" - - # Find rocm_version.h header file for ROCm version extract - rocm_version_h="${ROCM_HOME}/include/rocm-core/rocm_version.h" - if [ ! -f "$rocm_version_h" ]; then - rocm_version_h="${ROCM_HOME}/include/rocm_version.h" - fi - - # Error out if rocm_version.h not found - if [ ! -f "$rocm_version_h" ]; then - echo "Error: rocm_version.h not found in expected locations." >&2 - exit 1 - fi - - # Extract major, minor and patch ROCm version numbers - MAJOR_VERSION=$(grep 'ROCM_VERSION_MAJOR' "$rocm_version_h" | awk '{print $3}') - MINOR_VERSION=$(grep 'ROCM_VERSION_MINOR' "$rocm_version_h" | awk '{print $3}') - PATCH_VERSION=$(grep 'ROCM_VERSION_PATCH' "$rocm_version_h" | awk '{print $3}') - ROCM_INT=$(($MAJOR_VERSION * 10000 + $MINOR_VERSION * 100 + $PATCH_VERSION)) - echo "ROCm version: $ROCM_INT" - export BUILD_ROCM_VERSION="$MAJOR_VERSION.$MINOR_VERSION" - pip_install tabulate # needed for newer fbgemm pip_install patchelf # needed for rocm fbgemm - pushd /tmp git clone --recursive https://github.com/pytorch/fbgemm pushd fbgemm/fbgemm_gpu git checkout "${fbgemm_commit}" python setup.py install \ - --build-variant=rocm \ + --package_variant=rocm \ -DHIP_ROOT_DIR="${ROCM_PATH}" \ -DCMAKE_C_FLAGS="-DTORCH_USE_HIP_DSA" \ -DCMAKE_CXX_FLAGS="-DTORCH_USE_HIP_DSA" popd rm -rf fbgemm - popd else # See https://github.com/pytorch/pytorch/issues/106971 CUDA_PATH=/usr/local/cuda-12.1 pip_install --no-use-pep517 "git+https://github.com/pytorch/FBGEMM.git@${fbgemm_commit}#egg=fbgemm-gpu&subdirectory=fbgemm_gpu" diff --git a/.github/ci_commit_pins/fbgemm_rocm.txt b/.github/ci_commit_pins/fbgemm_rocm.txt index db140a31f3fa4..fa11e10ca6b8e 100644 --- a/.github/ci_commit_pins/fbgemm_rocm.txt +++ b/.github/ci_commit_pins/fbgemm_rocm.txt @@ -1 +1 @@ -7f1de94a4c2d14f59ad4ca84538c36084ea6b2c8 +5fb5024118e9bb9decf96c2b0b1a8f0010bf56be From b4358c5e8731c1035af8bd0d6260de9d239a3e5d Mon Sep 17 00:00:00 2001 From: yuchengliu1 Date: Fri, 18 Jul 2025 18:00:46 +0000 Subject: [PATCH 254/457] [inductor] Explicitly link c10 in inductor. (#158622) MSVC have error "unresolved external symbol" when compiling inductor. Explicitly link c10 in inductor. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158622 Approved by: https://github.com/desertfire Co-authored-by: Xu Han --- torch/_inductor/cpp_builder.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch/_inductor/cpp_builder.py b/torch/_inductor/cpp_builder.py index 975045c529ded..b8cdc50368e7b 100644 --- a/torch/_inductor/cpp_builder.py +++ b/torch/_inductor/cpp_builder.py @@ -1320,6 +1320,8 @@ def get_cpp_torch_device_options( include_dirs = cpp_extension.include_paths(device_type) libraries_dirs = cpp_extension.library_paths(device_type) + if not config.is_fbcode(): + libraries += ["c10"] if device_type == "cuda": definitions.append(" USE_ROCM" if torch.version.hip else " USE_CUDA") From 6e07d6a0ff386d99d8c2f1d25978b0683988a4cb Mon Sep 17 00:00:00 2001 From: Lucas Kabela Date: Fri, 18 Jul 2025 18:15:51 +0000 Subject: [PATCH 255/457] [Dynamo][Better Engineering] Add typing support for _dynamo/repro and debug_utils (#158504) As part of better engineering week, we would like to improve out type support to improve dev experience in dynamo This PR adds strict typing support to an important set of utilities in dynamo, `repro/` and the base `debug_utils.py` Running ``` mypy torch/_dynamo/repro/ torch/_dynamo/debug_utils.py --linecount-report /tmp/coverage_log ``` | -------- | Lines Unannotated | Lines Total | % lines covered | Funcs Unannotated | Funcs Total | % funcs covered | | -------- | ------- | -------- | ------- | ------- | ------- | ------- | | Main | 905 | 3268 | 27.69% | 22 | 81 | 27.16% | | This PR | 3368 | 3368 | 100.00% | 81 | 81 | 100.00% | | Delta | +2463 | +100 | +72.31% | +59 | 0 | +72.84% | Pull Request resolved: https://github.com/pytorch/pytorch/pull/158504 Approved by: https://github.com/mlazos --- torch/_dynamo/debug_utils.py | 180 +++++++++++++++++----------- torch/_dynamo/repro/after_aot.py | 161 ++++++++++++++----------- torch/_dynamo/repro/after_dynamo.py | 134 +++++++++++++-------- torch/_dynamo/repro/aoti.py | 120 +++++++++++-------- 4 files changed, 349 insertions(+), 246 deletions(-) diff --git a/torch/_dynamo/debug_utils.py b/torch/_dynamo/debug_utils.py index a23b58cedf226..e084222c21715 100644 --- a/torch/_dynamo/debug_utils.py +++ b/torch/_dynamo/debug_utils.py @@ -1,6 +1,3 @@ -# mypy: allow-untyped-defs -# mypy: disable-error-code="method-assign" - """ Debug utilities for TorchDynamo compilation and execution. @@ -34,6 +31,7 @@ import tempfile import textwrap from collections import Counter +from collections.abc import Sequence from importlib import import_module from typing import Any, Callable, Optional, TypeVar @@ -43,7 +41,9 @@ from torch import Tensor from torch._dynamo.testing import rand_strided from torch._prims_common import is_float_dtype +from torch.hub import tqdm from torch.multiprocessing.reductions import StorageWeakRef +from torch.storage import UntypedStorage from torch.utils._content_store import ContentStoreReader, ContentStoreWriter from . import config @@ -64,6 +64,7 @@ extra_deps = [] extra_imports = "" +cur_target = "" if use_buck: extra_deps = [ "//caffe2/torch/fb/sparsenn:sparsenn_operators_gpu", @@ -79,7 +80,7 @@ class BuckTargetWriter: - def __init__(self, filename): + def __init__(self, filename: str) -> None: self.subdir, self.py_file = os.path.split(os.path.abspath(filename)) self.target = self.py_file.replace(".py", "") @@ -93,7 +94,7 @@ def __init__(self, filename): tmp = tmp[tmp.find("fbcode/") :][7:] self.cmd_line_path = f"//{tmp}:{self.target}" - def build(self): + def build(self) -> str: extra_cpp_deps = "\n".join([f' "{x}",' for x in extra_deps]) return textwrap.dedent( f""" @@ -119,7 +120,7 @@ def build(self): """ ) - def write(self, print_msg=True): + def write(self, print_msg: bool = True) -> list[str]: target_file = os.path.join(self.subdir, "TARGETS") with open(target_file, "w") as fd: fd.write(self.build()) @@ -133,7 +134,7 @@ def write(self, print_msg=True): return cmd_split -def minifier_dir(): +def minifier_dir() -> str: path = os.path.join(get_debug_dir(), "minifier") if path is None: path = f"{tempfile.gettempdir()}/minifier_{getpass.getuser()}" @@ -171,7 +172,7 @@ class NNModuleToString: ] @staticmethod - def can_convert_to_string(gm): + def can_convert_to_string(gm: torch.fx.GraphModule) -> bool: cant_convert = set() for _, module in gm.named_children(): if type(module) not in NNModuleToString.safe_reprs: @@ -183,7 +184,7 @@ def can_convert_to_string(gm): return True @staticmethod - def convert(gm): + def convert(gm: torch.fx.GraphModule) -> str: from torch.nn.modules.module import _addindent tab = " " * 4 @@ -248,7 +249,7 @@ def __init__(self) -> None: @functools.cache # subprocess is expensive -def _cuda_system_info_comment(): +def _cuda_system_info_comment() -> str: if not torch.cuda.is_available(): return "# torch.cuda.is_available()==False, no GPU info collected\n" @@ -272,7 +273,7 @@ def _cuda_system_info_comment(): return model_str -def generate_env_vars_string(*, stable_output=False): +def generate_env_vars_string(*, stable_output: bool = False) -> str: """ Generate a string configuration for environment variables related to Dynamo, Inductor, and Triton. """ @@ -282,7 +283,7 @@ def generate_env_vars_string(*, stable_output=False): allow_list = ["TORCH", "DYNAMO", "INDUCTOR", "TRITON"] skip_list = ["TRITON_LIBDEVICE_PATH", "TRITON_PTXAS_PATH", "TRITON_LIBCUDA_PATH"] - def filter(key): + def filter(key: str) -> bool: return any(string in key for string in allow_list) and key not in skip_list config_lines = [ @@ -297,7 +298,7 @@ def filter(key): """ -def generate_config_string(*, stable_output=False): +def generate_config_string(*, stable_output: bool = False) -> str: import torch._functorch.config import torch._inductor.config @@ -317,11 +318,11 @@ def generate_config_string(*, stable_output=False): """ -def get_minifier_repro_path(): +def get_minifier_repro_path() -> str: return os.path.join(minifier_dir(), "minifier_launcher.py") -def helper_for_dump_minify(contents): +def helper_for_dump_minify(contents: str) -> None: minified_repro_path = get_minifier_repro_path() log.warning("Writing minified repro to:\n%s", minified_repro_path) @@ -340,7 +341,7 @@ class AccuracyError(Exception): pass -def clone_inputs_retaining_gradness(example_inputs): +def clone_inputs_retaining_gradness(example_inputs: Sequence[Any]) -> list[Any]: """ This clone inputs is different from utils clone_input. In case of minifier, all the tensors are leaf tensors while creating a new graph. So, we set the @@ -350,10 +351,15 @@ def clone_inputs_retaining_gradness(example_inputs): for idx in range(len(example_inputs)): if isinstance(cloned_inputs[idx], torch.Tensor): cloned_inputs[idx].requires_grad_(example_inputs[idx].requires_grad) - return cloned_inputs + return cloned_inputs # type: ignore[return-value] -def run_fwd_maybe_bwd(gm, args, only_fwd=False, disable_clone=False): +def run_fwd_maybe_bwd( + gm: torch.fx.GraphModule, + args: Sequence[Any], + only_fwd: bool = False, + disable_clone: bool = False, +) -> Any: """ Runs a forward and possibly backward iteration for a given mod and args. @@ -381,14 +387,14 @@ def run_fwd_maybe_bwd(gm, args, only_fwd=False, disable_clone=False): def same_two_models( - gm, - opt_gm, - example_inputs, - only_fwd=False, + gm: torch.fx.GraphModule, + opt_gm: torch.fx.GraphModule, + example_inputs: Sequence[Any], + only_fwd: bool = False, *, - require_fp64=False, - ignore_non_fp=False, -): + require_fp64: bool = False, + ignore_non_fp: bool = False, +) -> bool: """ Check two models have same accuracy. @@ -438,7 +444,7 @@ def same_two_models( return passing -def cast_dtype_args_to_fp64(model): +def cast_dtype_args_to_fp64(model: torch.fx.GraphModule) -> torch.fx.GraphModule: for node in model.graph.nodes: if ( node.op == "call_function" @@ -459,7 +465,9 @@ def cast_dtype_args_to_fp64(model): return model -def cast_to(dtype, model, inputs): +def cast_to( + dtype: torch.dtype, model: torch.fx.GraphModule, inputs: list[Any] +) -> tuple[torch.fx.GraphModule, list[Any]]: from torch.utils._pytree import tree_map model = model.to(dtype) @@ -477,19 +485,21 @@ def cast_to(dtype, model, inputs): return model, inputs -def cast_to_fp64(model, inputs): +def cast_to_fp64( + model: torch.fx.GraphModule, inputs: list[Any] +) -> tuple[torch.fx.GraphModule, list[Any]]: return cast_to(torch.float64, model, inputs) def backend_accuracy_fails( - gm, - example_inputs, - compiler_fn, - only_fwd=False, + gm: torch.fx.GraphModule, + example_inputs: Sequence[Any], + compiler_fn: Callable[[torch.fx.GraphModule, list[Any]], torch.fx.GraphModule], + only_fwd: bool = False, *, - require_fp64=False, - ignore_non_fp=False, -): + require_fp64: bool = False, + ignore_non_fp: bool = False, +) -> bool: try: compiled_gm = compiler_fn( copy.deepcopy(gm), clone_inputs_retaining_gradness(example_inputs) @@ -545,20 +555,27 @@ class NopInputReader: def __init__(self) -> None: self.total = 0 - def storage(self, storage_hash, nbytes, *, device=None, dtype_hint=None): + def storage( + self, + storage_hash: Optional[str], + nbytes: int, + *, + device: Optional["torch._prims_common.DeviceLikeType"] = None, + dtype_hint: Optional[torch.dtype] = None, + ) -> None: self.total += 1 - def tensor(self, *args, **kwargs): + def tensor(self, *args: Any, **kwargs: Any) -> Optional[torch.Tensor]: pass - def symint(self, *args, **kwargs): + def symint(self, *args: Any, **kwargs: Any) -> Optional[int]: pass # TODO: Support bundling the entire repro into a zip file for ease of # transferring around class InputReader: - def __init__(self, save_dir=None, *, pbar=None): + def __init__(self, save_dir: Optional[str] = None, *, pbar: Optional[tqdm] = None): # If None, we will generate random data instead. It's important # to natively support this use case as it will allow people to # share repros without including the real data, if the problem @@ -566,13 +583,20 @@ def __init__(self, save_dir=None, *, pbar=None): if save_dir is None: log.warning("no save_dir specified, will generate random data") self.store = ContentStoreReader(save_dir) if save_dir is not None else None - self.args = [] + self.args: list[Any] = [] self.pbar = pbar - def storage(self, storage_hash, nbytes, *, device=None, dtype_hint=None): + def storage( + self, + storage_hash: Optional[str], + nbytes: int, + *, + device: Optional["torch._prims_common.DeviceLikeType"] = None, + dtype_hint: Optional[torch.dtype] = None, + ) -> UntypedStorage: if self.pbar is not None: self.pbar.update(1) - device = _device_or_default(device) + device = _device_or_default(device) # type: ignore[arg-type] dtype_hint = _dtype_or_default(dtype_hint) if self.store is not None and storage_hash is not None: try: @@ -593,16 +617,16 @@ def storage(self, storage_hash, nbytes, *, device=None, dtype_hint=None): def tensor( self, - storage, - shape, - stride=None, + storage: UntypedStorage, + shape: "torch._prims_common.ShapeType", + stride: Optional["torch._prims_common.StrideType"] = None, *, - storage_offset=None, - dtype=None, - requires_grad=None, - is_leaf=None, - **metadata, - ): + storage_offset: Optional[int] = None, + dtype: Optional[torch.dtype] = None, + requires_grad: Optional[bool] = None, + is_leaf: Optional[bool] = None, + **metadata: Any, + ) -> torch.Tensor: stride = _stride_or_default(stride, shape=shape) storage_offset = _storage_offset_or_default(storage_offset) dtype = _dtype_or_default(dtype) @@ -624,7 +648,7 @@ def tensor( self.args.append(t) return t # for BC - def symint(self, val): + def symint(self, val: Any) -> Any: self.args.append(val) return val # for BC @@ -642,8 +666,8 @@ def symint(self, val): class InputWriter: - def __init__(self, save_dir, *, stable_hash=False): - self._lines = [] + def __init__(self, save_dir: Optional[str], *, stable_hash: bool = False) -> None: + self._lines: list[str] = [] # TODO: consider ensuring tensor and storage counters line up? self.storage_counter = itertools.count() self.save_dir = save_dir @@ -652,9 +676,9 @@ def __init__(self, save_dir, *, stable_hash=False): if save_dir is not None else None ) - self.seen_storages = {} + self.seen_storages: dict[StorageWeakRef, str] = {} - def lines(self): + def lines(self) -> list[str]: r = [ "def load_args(reader):", ] @@ -669,7 +693,13 @@ def lines(self): # of initialization may be appropriate # # If we had a FakeTensor, device_hint tells us what device should be - def storage(self, untyped_storage, *, dtype_hint=None, device_hint=None) -> str: + def storage( + self, + untyped_storage: UntypedStorage, + *, + device_hint: Optional["torch._prims_common.DeviceLikeType"] = None, + dtype_hint: Optional[torch.dtype] = None, + ) -> str: ws = StorageWeakRef(untyped_storage) v = self.seen_storages.get(ws) if v is not None: @@ -684,7 +714,7 @@ def storage(self, untyped_storage, *, dtype_hint=None, device_hint=None) -> str: device = untyped_storage.device if device.type == "meta": assert device_hint is not None - device = device_hint + device = device_hint # type: ignore[assignment] if _device_or_default(None) != device: maybe_device = f", device={device!r}" nbytes = untyped_storage.nbytes() @@ -697,7 +727,7 @@ def storage(self, untyped_storage, *, dtype_hint=None, device_hint=None) -> str: self.seen_storages[ws] = v return v - def tensor(self, name, t) -> None: + def tensor(self, name: str, t: torch.Tensor) -> None: from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq storage = self.storage( @@ -729,7 +759,7 @@ def tensor(self, name, t) -> None: + f") # {name}" ) - def unsupported(self, name, arg): + def unsupported(self, name: str, arg: Any) -> None: # NB: Try hard not to /print/ a tensor, that will be very slow self._lines.append(f"# {name} was unsupported type for dumping: {type(arg)}") # Best effort dump as much useful stuff we can lol, in case you want @@ -747,13 +777,13 @@ def unsupported(self, name, arg): self._lines.append('"""') # write out that the arg was filtered out as it is constant - def const(self, name) -> None: + def const(self, name: str) -> None: self._lines.append( f"reader.const({name!r}) # {name}, filtered out during compilation" ) # TODO: this doesn't actually symint atm - def symint(self, name, val) -> None: + def symint(self, name: str, val: Any) -> None: if isinstance(val, torch.SymInt): val = val.node.hint self._lines.append(f"reader.symint({val!r}) # {name}") @@ -782,8 +812,10 @@ def forward(self, primals_1: "f32[1001, 6]", primals_2: "f32[s0]", primals_3: "S from torch.utils._dtype_abbrs import dtype_abbrs - dtype_map = {value: key for key, value in dtype_abbrs.items()} - dtype_pattern = "|".join(dtype_abbrs.values()) + dtype_map: dict[str, torch.dtype] = { + value: key for key, value in dtype_abbrs.items() + } + dtype_pattern: str = "|".join(dtype_abbrs.values()) # Extracting the source code from the function source = inspect.getsource(func) @@ -799,21 +831,23 @@ class TensorContainer: # Dictionary for tensors from annotations kwargs: dict[str, Any] = {} - sym_shapes = sym_shapes or {} + sym_shapes_dict: dict[str, int] = sym_shapes or {} - def get_sym_int(symint): + def get_sym_int(symint: str) -> int: torch._check( - symint in sym_shapes or default_sym_shape is not None, + symint in sym_shapes_dict or default_sym_shape is not None, lambda: f"{symint} not in symbolic_shapes and default sym shape not passed in", ) - return sym_shapes.get(symint, default_sym_shape) + return sym_shapes_dict.get(symint, default_sym_shape) # type: ignore[return-value] - def gen_tensor(shape, dtype) -> Tensor: + def gen_tensor( + shape: "torch._prims_common.ShapeType", dtype: torch.dtype + ) -> Tensor: # Resolve symbolic shapes to concrete values resolved_shape = [] dynamic_dims = [] for i, dim in enumerate(shape): - dim = dim.strip() + dim = dim.strip() # type: ignore[attr-defined] if "s" in dim: s = get_sym_int(dim) resolved_shape.append(s) @@ -868,9 +902,9 @@ def profile_to_file(filename: str) -> Callable[[T], T]: prof = cProfile.Profile() filename = os.path.abspath(os.path.expanduser(filename)) - def decorator(fn): + def decorator(fn: Any) -> Any: @functools.wraps(fn) - def wrapper(*args, **kwargs): + def wrapper(*args: Any, **kwargs: Any) -> Any: prof.enable() try: return fn(*args, **kwargs) @@ -879,7 +913,7 @@ def wrapper(*args, **kwargs): return wrapper - def save_it(): + def save_it() -> None: prof.dump_stats(filename) sys.stderr.write( textwrap.dedent( diff --git a/torch/_dynamo/repro/after_aot.py b/torch/_dynamo/repro/after_aot.py index b1327473c9094..cdbc1fcda0371 100644 --- a/torch/_dynamo/repro/after_aot.py +++ b/torch/_dynamo/repro/after_aot.py @@ -1,5 +1,3 @@ -# mypy: allow-untyped-defs - """ Utilities for reproducing and debugging issues in PyTorch's Dynamo AOT compilation. @@ -33,7 +31,7 @@ from collections.abc import Sequence from importlib import import_module from tempfile import TemporaryFile -from typing import Any, Callable, TYPE_CHECKING, Union +from typing import Any, Callable, IO, Optional, TYPE_CHECKING, Union from typing_extensions import Unpack import torch @@ -157,7 +155,7 @@ def deferred_for_real_inputs( with config.patch(repro_after=None): return inner_debug_fn(real_inputs) - def inner_debug_fn(real_inputs): + def inner_debug_fn(real_inputs: Sequence["InputType"]) -> Any: """ Aot Autograd fw_compiler and bw_compiler can have fake tensors. So, example_inputs can be fake tensors. We can call compiler_fn (which is @@ -186,7 +184,7 @@ def inner_debug_fn(real_inputs): ) failed = not same_two_models( gm, - inner_compiled_fn, + inner_compiled_fn, # type: ignore[arg-type] real_inputs, only_fwd=True, ignore_non_fp=config.repro_ignore_non_fp, @@ -250,7 +248,7 @@ def inner_debug_fn(real_inputs): # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # -def maybe_fbcode_instructions(): +def maybe_fbcode_instructions() -> str: if is_fbcode(): extra_deps_formatted = "\n".join([f' "{dep}",' for dep in extra_deps]) if len(extra_deps_formatted) > 0: @@ -283,14 +281,14 @@ def maybe_fbcode_instructions(): def generate_compiler_repro_string( - gm, - args, + gm: torch.fx.GraphModule, + args: Sequence[Any], *, - stable_output=False, - save_dir=None, - stable_hash=False, - has_distributed_ops=False, -): + stable_output: bool = False, + save_dir: Optional[str] = None, + stable_hash: bool = False, + has_distributed_ops: bool = False, +) -> str: # Add distributed imports if needed distributed_imports = "" if has_distributed_ops: @@ -377,19 +375,19 @@ def generate_compiler_repro_string( def save_graph_repro( - fd, - gm, - args, - compiler_name, + fd: IO[Any], + gm: torch.fx.GraphModule, + args: Sequence[Any], + compiler_name: str, *, - stable_output=False, - save_dir=None, - command="run", - accuracy=None, - tracing_mode=None, - check_str=None, - stable_hash=False, -): + stable_output: bool = False, + save_dir: Optional[str] = None, + command: str = "run", + accuracy: Optional[Union[str, bool]] = None, + tracing_mode: Optional[str] = None, + check_str: Optional[str] = None, + stable_hash: bool = False, +) -> None: if any( isinstance(arg, torch.fx.experimental._backward_state.BackwardState) for arg in args @@ -456,7 +454,13 @@ def save_graph_repro( fd.write("\n dist.destroy_process_group()\n") -def dump_compiler_graph_state(gm, args, compiler_name, *, accuracy=None): +def dump_compiler_graph_state( + gm: torch.fx.GraphModule, + args: Sequence[Any], + compiler_name: str, + *, + accuracy: Optional[Union[str, bool]] = None, +) -> None: subdir = os.path.join(minifier_dir(), "checkpoints") if not os.path.exists(subdir): os.makedirs(subdir, exist_ok=True) @@ -484,7 +488,9 @@ def dump_compiler_graph_state(gm, args, compiler_name, *, accuracy=None): # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # -def dump_to_minify(gm, args, compiler_name: str): +def dump_to_minify( + gm: torch.fx.GraphModule, args: Sequence[Any], compiler_name: str +) -> None: out = io.StringIO() # TODO: factor this out subdir = os.path.join(minifier_dir(), "checkpoints") @@ -495,15 +501,15 @@ def dump_to_minify(gm, args, compiler_name: str): def isolate_fails( - fx_g, - args, + fx_g: torch.fx.GraphModule, + args: Sequence[Any], compiler_name: str, - env=None, - save_dir=None, - accuracy=None, - tracing_mode=None, - check_str=None, -): + env: Optional[dict[str, Any]] = None, + save_dir: Optional[str] = None, + accuracy: Optional[Union[bool, str]] = None, + tracing_mode: Optional[str] = None, + check_str: Optional[str] = None, +) -> bool: if env is None: env = {} subdir = os.path.join(os.getcwd(), "isolate") @@ -559,14 +565,16 @@ def isolate_fails( # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # -def inductor_fails(fx_g, args, check_str=None): +def inductor_fails( + fx_g: torch.fx.GraphModule, args: Sequence[Any], check_str: Optional[str] = None +) -> bool: has_cuda = False for arg in args: if isinstance(arg, torch.Tensor) and arg.is_cuda: has_cuda = True break - def sync(): + def sync() -> None: if has_cuda: # Ensures that segfaults are surfaced torch.cuda.synchronize() @@ -596,14 +604,19 @@ def sync(): def inductor_accuracy_fails( - fx_g, args, check_str=None, *, require_fp64=False, ignore_non_fp=False -): + fx_g: torch.fx.GraphModule, + args: Sequence[Any], + check_str: Optional[str] = None, + *, + require_fp64: bool = False, + ignore_non_fp: bool = False, +) -> bool: from torch._inductor.compile_fx import compile_fx_inner return backend_aot_accuracy_fails( fx_g, - args, - compile_fx_inner, + args, # type: ignore[arg-type] + compile_fx_inner, # type: ignore[arg-type] require_fp64=require_fp64, ignore_non_fp=ignore_non_fp, ) @@ -617,7 +630,9 @@ def inductor_accuracy_fails( # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # -def repro_common(options, mod, load_args): +def repro_common( + options: Any, mod: nn.Module, load_args: Any +) -> tuple[torch.fx.GraphModule, Sequence[Any]]: # Invariant for graphs we generate with the repro script assert not any(mod.named_parameters()) for n, b in mod.named_buffers(): @@ -660,7 +675,7 @@ def repro_common(options, mod, load_args): return mod, args -ACCURACY_FAILS: dict[str, Callable[[nn.Module, Any], bool]] = { +ACCURACY_FAILS: dict[str, Callable[[torch.fx.GraphModule, Any], bool]] = { "": inductor_fails, # This might look inverted but it's not. strict_accuracy means "we will # minify any time we see anything that diverges", whereas accuracy is more @@ -673,7 +688,7 @@ def repro_common(options, mod, load_args): } -def repro_minifier_query(options, mod, load_args): +def repro_minifier_query(options: Any, mod: nn.Module, load_args: Any) -> None: mod, args = repro_common(options, mod, load_args) fail_fn = functools.partial( ACCURACY_FAILS[options.accuracy], @@ -685,7 +700,7 @@ def repro_minifier_query(options, mod, load_args): sys.exit(0) -def repro_minify(options, mod, load_args): +def repro_minify(options: Any, mod: nn.Module, load_args: Any) -> None: from functorch.compile import minifier mod, args = repro_common(options, mod, load_args) @@ -722,7 +737,7 @@ def repro_minify(options, mod, load_args): ) -def repro_analyze(options, mod, load_args): +def repro_analyze(options: Any, mod: nn.Module, load_args: Any) -> None: from torch._inductor.compile_fx import compile_fx_inner from torch._inductor.hooks import intermediate_hook @@ -740,7 +755,7 @@ def repro_analyze(options, mod, load_args): known_names = set() - def save_hook(name, val): + def save_hook(name: str, val: Any) -> None: known_names.add(name) if not options.skip_saving_inductor_intermediates: writer.write_tensor(os.path.join("inductor", name), val) @@ -757,10 +772,10 @@ def save_hook(name, val): tqdm(desc="Saving inductor intermediates", total=total) as pbar, ): assert not isinstance(compiled, str) - compiled(new_args) + compiled(new_args) # type: ignore[arg-type] assert not new_args - def compare_tuples(tuple1, tuple2): + def compare_tuples(tuple1: tuple[Any], tuple2: tuple[Any]) -> Optional[str]: diff_indices = [i for i in range(len(tuple1)) if tuple1[i] != tuple2[i]] diff_values = [(tuple1[i], tuple2[i]) for i in diff_indices] @@ -769,7 +784,7 @@ def compare_tuples(tuple1, tuple2): else: return " and ".join(f"{a} != {b}" for a, b in diff_values) - def check_hook(name, val): + def check_hook(name: str, val: Any) -> None: meta = writer.compute_tensor_metadata(val) meta2 = reader.read_tensor_metadata(os.path.join("inductor", name)) reason = compare_tuples(meta, meta2) @@ -783,15 +798,15 @@ def check_hook(name, val): intermediate_hook(check_hook), tqdm(desc="Checking inductor determinism", total=total) as pbar, ): - compiled(new_args) + compiled(new_args) # type: ignore[arg-type] assert not new_args class WriterInterp(fx.Interpreter): - def __init__(self, mod, subdir) -> None: + def __init__(self, mod: torch.nn.Module, subdir: str) -> None: super().__init__(mod) self.subdir = subdir - def run_node(self, n): + def run_node(self, n: torch.fx.Node) -> Any: r = super().run_node(n) name = n.name if name in known_names: @@ -802,13 +817,13 @@ def run_node(self, n): # NB: the module cast doesn't actually do anything, since there are no # parameters/buffers on the module if not options.skip_saving_float64_intermediates: - new_mod, new_args = cast_to_fp64(copy.deepcopy(mod), clone_inputs(args)) + new_mod, new_args = cast_to_fp64(copy.deepcopy(mod), clone_inputs(args)) # type: ignore[arg-type] with tqdm(desc="Saving float64 intermediates", total=total) as pbar: WriterInterp(new_mod, "float64").boxed_run(new_args) assert not new_args class ExactReaderInterp(fx.Interpreter): - def run_node(self, n): + def run_node(self, n: torch.fx.Node) -> Any: r = super().run_node(n) name = n.name if name in known_names: @@ -823,7 +838,7 @@ def run_node(self, n): # TODO: check eager determinism if not options.skip_check_deterministic: - new_mod, new_args = cast_to_fp64(copy.deepcopy(mod), clone_inputs(args)) + new_mod, new_args = cast_to_fp64(copy.deepcopy(mod), clone_inputs(args)) # type: ignore[arg-type] with tqdm(desc="Checking float64 determinism", total=total) as pbar: ExactReaderInterp(new_mod).boxed_run(new_args) assert not new_args @@ -831,7 +846,7 @@ def run_node(self, n): # Now that we've saved everything, interp through the eager graph # and do comparisons class ReaderInterp(fx.Interpreter): - def run_node(self, n): + def run_node(self, n: torch.fx.Node) -> Any: r = super().run_node(n) name = n.name if name in known_names: @@ -839,7 +854,7 @@ def run_node(self, n): float64 = reader.read_tensor(os.path.join("float64", name)) logged = False - def log_error(msg, *args): + def log_error(msg: str, *args: Any) -> None: nonlocal logged logged = True pbar.write(f"DIVERGED at {name}: {msg % args}") @@ -861,12 +876,14 @@ def log_error(msg, *args): assert not args -def repro_get_args(options, mod, load_args): +def repro_get_args( + options: Any, mod: nn.Module, load_args: Any +) -> tuple[torch.fx.GraphModule, list[Any]]: mod, args = repro_common(options, mod, load_args) - return mod, args + return mod, args # type: ignore[return-value] -def repro_run(options, mod, load_args): +def repro_run(options: Any, mod: nn.Module, load_args: Any) -> None: from torch._inductor.compile_fx import compile_fx_inner mod, args = repro_common(options, mod, load_args) @@ -881,7 +898,7 @@ def repro_run(options, mod, load_args): # seems counterintuitive if not same_two_models( mod, - compiled, + compiled, # type: ignore[arg-type] args, only_fwd=True, ignore_non_fp=config.repro_ignore_non_fp, @@ -903,17 +920,17 @@ def repro_run(options, mod, load_args): # TODO: lazily load the inputs or something, rather than cloning them def run_repro( - mod, - load_args, + mod: nn.Module, + load_args: Any, *, - command="run", + command: str = "run", accuracy: Union[bool, str] = "", - save_dir=None, - tracing_mode=None, - patch_code=None, - check_str=None, - **kwargs, -): + save_dir: Optional[str] = None, + tracing_mode: Optional[str] = None, + patch_code: Optional[str] = None, + check_str: Optional[str] = None, + **kwargs: Any, +) -> Any: for k in kwargs: log.warning( "Unrecognized kwarg %s; perhaps this repro was made on a newer version of PyTorch", @@ -946,7 +963,7 @@ def run_repro( formatter_class=argparse.RawTextHelpFormatter, ) - def common_flags(parser): + def common_flags(parser: argparse.ArgumentParser) -> None: accuracy_group = parser.add_mutually_exclusive_group() accuracy_group.add_argument( "--no-accuracy", diff --git a/torch/_dynamo/repro/after_dynamo.py b/torch/_dynamo/repro/after_dynamo.py index 86a33677eb14d..898946d6f89f5 100644 --- a/torch/_dynamo/repro/after_dynamo.py +++ b/torch/_dynamo/repro/after_dynamo.py @@ -1,5 +1,3 @@ -# mypy: allow-untyped-defs - """ Utilities for reproducing and debugging issues in Dynamo after graph capture. @@ -26,12 +24,12 @@ import shutil import sys import textwrap +from collections.abc import Sequence from importlib import import_module -from typing import Optional, Union +from typing import Any, Callable, Optional, Union import torch import torch.fx as fx -from torch._dynamo.backends.registry import CompiledFn from torch._dynamo.debug_utils import ( AccuracyError, backend_accuracy_fails, @@ -53,7 +51,7 @@ from torch.hub import tqdm from .. import config -from ..backends.registry import lookup_backend, register_debug_backend +from ..backends.registry import CompilerFn, lookup_backend, register_debug_backend from ..debug_utils import clone_inputs_retaining_gradness @@ -68,7 +66,11 @@ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # -def _accuracy_fails(gm, example_inputs, compiler_fn): +def _accuracy_fails( + gm: torch.fx.GraphModule, + example_inputs: Sequence[Any], + compiler_fn: Callable[[torch.fx.GraphModule, list[Any]], torch.fx.GraphModule], +) -> bool: return backend_accuracy_fails( gm, example_inputs, @@ -79,29 +81,33 @@ def _accuracy_fails(gm, example_inputs, compiler_fn): class WrapBackendDebug: - def __init__(self, unconfigured_compiler_fn, compiler_name: Optional[str]) -> None: + def __init__( + self, unconfigured_compiler_fn: CompilerFn, compiler_name: Optional[str] + ) -> None: functools.wraps(unconfigured_compiler_fn)(self) - self._torchdynamo_orig_backend = unconfigured_compiler_fn # type: ignore[attr-defined] + self._torchdynamo_orig_backend = unconfigured_compiler_fn self._compiler_name = compiler_name if hasattr(unconfigured_compiler_fn, "__name__"): self.__name__ = unconfigured_compiler_fn.__name__ if hasattr(unconfigured_compiler_fn, "compiler_name"): - self.__name__ = unconfigured_compiler_fn.compiler_name + self.__name__ = unconfigured_compiler_fn.compiler_name # type: ignore[attr-defined] if hasattr(unconfigured_compiler_fn, "get_compiler_config"): self.get_compiler_config = unconfigured_compiler_fn.get_compiler_config # type: ignore[attr-defined] - def __call__(self, gm, example_inputs, **kwargs): + def __call__( + self, gm: torch.fx.GraphModule, example_inputs: list[Any], **kwargs: Any + ) -> torch.fx.GraphModule: compiler_fn = functools.partial(self._torchdynamo_orig_backend, **kwargs) assert config.repro_after in ("dynamo", "aot", None) if config.repro_after == "dynamo": - def add_paths(exc): - exc.minifier_path = os.path.join(minifier_dir(), "minifier_launcher.py") + def add_paths(exc: Exception) -> None: + exc.minifier_path = os.path.join(minifier_dir(), "minifier_launcher.py") # type: ignore[attr-defined] if use_buck: - exc.buck_command = " ".join( + exc.buck_command = " ".join( # type: ignore[attr-defined] BUCK_CMD_PREFIX - + [BuckTargetWriter(exc.minifier_path).cmd_line_path] + + [BuckTargetWriter(exc.minifier_path).cmd_line_path] # type: ignore[attr-defined] ) if config.repro_level == 3: @@ -111,7 +117,7 @@ def add_paths(exc): if config.repro_level == 4: # Check Accuracy compiled_gm = compiler_fn(copy.deepcopy(gm), example_inputs) - if _accuracy_fails(gm, example_inputs, compiler_fn): + if _accuracy_fails(gm, example_inputs, compiler_fn): # type: ignore[arg-type] log.warning( "Accuracy failed for the TorchDynamo produced graph. Creating script to minify the error." ) @@ -126,7 +132,7 @@ def add_paths(exc): else: try: compiled_gm = compiler_fn(copy.deepcopy(gm), example_inputs) - run_fwd_maybe_bwd(compiled_gm, example_inputs) + run_fwd_maybe_bwd(compiled_gm, example_inputs) # type: ignore[arg-type] except Exception as exc: log.warning( "Compiled Fx GraphModule failed. Creating script to minify the error." @@ -149,10 +155,12 @@ def add_paths(exc): else: compiled_gm = compiler_fn(gm, example_inputs) - return compiled_gm + return compiled_gm # type: ignore[return-value] -def wrap_backend_debug(unconfigured_compiler_fn, compiler_name: Optional[str]): +def wrap_backend_debug( + unconfigured_compiler_fn: CompilerFn, compiler_name: Optional[str] +) -> WrapBackendDebug: """ A minifier decorator that wraps the TorchDynamo produced Fx graph modules. As opposed to wrap_compiler_debug, this wrapper intercepts at the @@ -170,15 +178,15 @@ def wrap_backend_debug(unconfigured_compiler_fn, compiler_name: Optional[str]): def generate_dynamo_fx_repro_string( - gm, - args, - compiler_name, - check_accuracy=False, + gm: torch.fx.GraphModule, + args: Sequence[Any], + compiler_name: Optional[str], + check_accuracy: bool = False, *, - stable_output=False, - save_dir=None, - command="run", -): + stable_output: bool = False, + save_dir: Optional[str] = None, + command: str = "run", +) -> str: """ Generate a repro string for backend-agnostic minified version. """ @@ -225,7 +233,12 @@ def generate_dynamo_fx_repro_string( ) -def dump_backend_repro_as_file(gm, args, compiler_name, check_accuracy=False): +def dump_backend_repro_as_file( + gm: torch.fx.GraphModule, + args: Sequence[Any], + compiler_name: Optional[str], + check_accuracy: bool = False, +) -> None: """ Saves the repro to a repro.py file """ @@ -253,7 +266,12 @@ def dump_backend_repro_as_file(gm, args, compiler_name, check_accuracy=False): shutil.copyfile(file_name, latest_repro) -def dump_backend_state(gm, args, compiler_name, check_accuracy=False): +def dump_backend_state( + gm: torch.fx.GraphModule, + args: Sequence[Any], + compiler_name: Optional[str], + check_accuracy: bool = False, +) -> None: """ Dumps the dynamo graph to repro the issue. 1) It tries to convert Fx GraphModule to a string. If we can, it writes to a @@ -271,7 +289,9 @@ def dump_backend_state(gm, args, compiler_name, check_accuracy=False): # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # -def dump_to_minify_after_dynamo(gm, args, compiler_name): +def dump_to_minify_after_dynamo( + gm: torch.fx.GraphModule, args: Sequence[Any], compiler_name: Optional[str] +) -> None: # TODO: factor this out subdir = os.path.join(minifier_dir(), "checkpoints") if not os.path.exists(subdir): @@ -295,8 +315,8 @@ def dump_to_minify_after_dynamo(gm, args, compiler_name): @register_debug_backend # type: ignore[arg-type] def dynamo_minifier_backend( - gm: fx.GraphModule, example_inputs, compiler_name: CompiledFn -): + gm: fx.GraphModule, example_inputs: Sequence[Any], compiler_name: Optional[str] +) -> fx.GraphModule: from functorch.compile import minifier compiler_fn = lookup_backend(compiler_name) @@ -336,7 +356,9 @@ def dynamo_minifier_backend( @register_debug_backend # type: ignore[arg-type] -def dynamo_accuracy_minifier_backend(gm, example_inputs, compiler_name): +def dynamo_accuracy_minifier_backend( + gm: fx.GraphModule, example_inputs: Sequence[Any], compiler_name: Optional[str] +) -> fx.GraphModule: from functorch.compile import minifier compiler_fn = lookup_backend(compiler_name) @@ -366,7 +388,12 @@ def dynamo_accuracy_minifier_backend(gm, example_inputs, compiler_name): return gm -def backend_fails(gm, example_inputs, compiler_fn, orig_failure): +def backend_fails( + gm: fx.GraphModule, + example_inputs: Sequence[Any], + compiler_fn: CompilerFn, + orig_failure: Sequence[Any], +) -> bool: """ Minifier uses this function to identify if the minified graph module fails with the same error. @@ -383,8 +410,8 @@ def backend_fails(gm, example_inputs, compiler_fn, orig_failure): try: # Run the original gm to check eager validity run_fwd_maybe_bwd(gm, clone_inputs_retaining_gradness(example_inputs)) - compiled_gm = compiler_fn(gm, example_inputs) - run_fwd_maybe_bwd(compiled_gm, clone_inputs_retaining_gradness(example_inputs)) + compiled_gm = compiler_fn(gm, example_inputs) # type: ignore[arg-type] + run_fwd_maybe_bwd(compiled_gm, clone_inputs_retaining_gradness(example_inputs)) # type: ignore[arg-type] except Exception as e: new_failure = str(e) if SequenceMatcher(None, orig_failure, new_failure).ratio() > 0.5: @@ -397,7 +424,7 @@ def backend_fails(gm, example_inputs, compiler_fn, orig_failure): # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # -def run_load_args(options, mod, load_args): +def run_load_args(options: Any, mod: torch.nn.Module, load_args: Any) -> list[Any]: if not hasattr(load_args, "_version"): log.warning( "load_args does not have a _version attribute, please file a bug to PyTorch " @@ -423,7 +450,7 @@ def run_load_args(options, mod, load_args): return args -def repro_minify(options, mod, load_args): +def repro_minify(options: Any, mod: torch.nn.Module, load_args: Any) -> None: args = run_load_args(options, mod, load_args) # Setup debug minifier compiler @@ -450,7 +477,7 @@ def repro_minify(options, mod, load_args): opt_mod(*args) -def repro_run(options, mod, load_args): +def repro_run(options: Any, mod: torch.nn.Module, load_args: Any) -> None: opt_mod = torch._dynamo.optimize(options.backend)(mod) if options.accuracy != "": @@ -460,10 +487,10 @@ def repro_run(options, mod, load_args): with torch.amp.autocast("cuda", enabled=options.autocast): # TODO: disable clone args = run_load_args(options, mod, load_args) - assert same_two_models(mod, mod, args), "Eager itself failed" + assert same_two_models(mod, mod, args), "Eager itself failed" # type: ignore[arg-type] if not same_two_models( - mod, - opt_mod, + mod, # type: ignore[arg-type] + opt_mod, # type: ignore[arg-type] args, only_fwd=config.repro_forward_only, ignore_non_fp=config.repro_ignore_non_fp, @@ -472,26 +499,29 @@ def repro_run(options, mod, load_args): else: with torch.amp.autocast("cuda", enabled=options.autocast): args = run_load_args(options, mod, load_args) - run_fwd_maybe_bwd(mod, args, only_fwd=options.only_fwd, disable_clone=True) + run_fwd_maybe_bwd(mod, args, only_fwd=options.only_fwd, disable_clone=True) # type: ignore[arg-type] del args args = run_load_args(options, mod, load_args) run_fwd_maybe_bwd( - opt_mod, args, only_fwd=options.only_fwd, disable_clone=True + opt_mod, # type: ignore[arg-type] + args, + only_fwd=options.only_fwd, + disable_clone=True, # type: ignore[arg-type] ) def run_repro( - mod, - load_args, + mod: torch.nn.Module, + load_args: Any, *, - command="run", + command: str = "run", accuracy: Union[bool, str] = "", - save_dir=None, - autocast=False, - backend="inductor", - **kwargs, -): + save_dir: Optional[str] = None, + autocast: bool = False, + backend: str = "inductor", + **kwargs: Any, +) -> None: for k in kwargs: log.warning( "Unrecognized kwarg %s; perhaps this repro was made on a newer version of PyTorch", @@ -517,7 +547,7 @@ def run_repro( formatter_class=argparse.RawTextHelpFormatter, ) - def common_flags(parser): + def common_flags(parser: argparse.ArgumentParser) -> None: accuracy_group = parser.add_mutually_exclusive_group() accuracy_group.add_argument( "--no-accuracy", diff --git a/torch/_dynamo/repro/aoti.py b/torch/_dynamo/repro/aoti.py index c3fab6bd086a1..808383e68e51a 100644 --- a/torch/_dynamo/repro/aoti.py +++ b/torch/_dynamo/repro/aoti.py @@ -1,5 +1,3 @@ -# mypy: allow-untyped-defs - """ Utilities for debugging and reproducing issues in Ahead of Time with Inductor (AOTI) compilation. @@ -26,8 +24,9 @@ import shutil import sys import textwrap +from collections.abc import Sequence from importlib import import_module -from typing import Any, Optional, Union +from typing import Any, IO, Optional, Union import torch from torch._dynamo.debug_utils import ( @@ -54,7 +53,7 @@ class AOTIMinifierError(Exception): - def __init__(self, original_exception): + def __init__(self, original_exception: Union[str, Exception]) -> None: additional_message = "This error is caused by a bug in the AOTI minifier, please report a bug to PyTorch" full_message = f"{additional_message}: {str(original_exception)}" super().__init__(full_message) @@ -66,7 +65,7 @@ def dump_to_minify( compiler_name: str, command: str = "minify", options: Optional[dict[str, Any]] = None, -): +) -> None: """ If command is "minify": Dump exported_program to `debug_dir/minifier/minifier_launcher.py`, with minify command. @@ -111,8 +110,8 @@ def dump_to_minify( log.warning("No write permissions for %s", file_name) -def get_module_string(gm): - def _convert_to_comment(s_): +def get_module_string(gm: torch.fx.GraphModule) -> str: + def _convert_to_comment(s_: str) -> str: s = s_.split("\n") if len(s) == 1: return "# " + s_ @@ -132,21 +131,21 @@ def _convert_to_comment(s_): def save_graph_repro_ep( - fd, - compiler_name, + fd: IO[Any], + compiler_name: str, *, exported_program: Optional[ExportedProgram] = None, gm: Optional[torch.nn.Module] = None, args: Optional[tuple[Any]] = None, config_patches: Optional[dict[str, str]] = None, - stable_output=False, - save_dir=None, - command="run", - accuracy=None, - check_str=None, - module_in_comment=False, - strict=False, -): + stable_output: bool = False, + save_dir: Optional[str] = None, + command: str = "run", + accuracy: Optional[Union[str, bool]] = None, + check_str: Optional[str] = None, + module_in_comment: bool = False, + strict: bool = False, +) -> None: # Save graph for reproducing the error. # Either exported_program or gm will be saved, depending on which one is defined. # Only one of exported_program and gm should be defined. @@ -166,7 +165,7 @@ def save_graph_repro_ep( gm = exported_program.module() # save a graph preview using gm - module_string = get_module_string(gm) + module_string = get_module_string(gm) # type: ignore[arg-type] fd.write(module_string) # save a graph repro using exported_program @@ -190,14 +189,14 @@ def save_graph_repro_ep( def dump_compiler_graph_state( - gm, - args, - compiler_name, + gm: torch.fx.GraphModule, + args: Sequence[Any], + compiler_name: str, *, - config_patches=None, - accuracy=None, - strict=False, -): + config_patches: Optional[dict[str, str]] = None, + accuracy: Optional[Union[str, bool]] = None, + strict: bool = False, +) -> None: subdir = os.path.join(minifier_dir(), "checkpoints") if not os.path.exists(subdir): os.makedirs(subdir, exist_ok=True) @@ -234,12 +233,12 @@ def dump_compiler_graph_state( def generate_compiler_repro_exported_program( - exported_program, + exported_program: ExportedProgram, *, options: Optional[dict[str, str]] = None, - stable_output=False, - save_dir=None, -): + stable_output: bool = False, + save_dir: Optional[str] = None, +) -> str: model_str = textwrap.dedent( f""" {generate_env_vars_string(stable_output=stable_output)} @@ -261,8 +260,10 @@ def generate_compiler_repro_exported_program( if hasattr(torch.version, "git_version"): model_str += f"# torch git version: {torch.version.git_version}\n\n\n" model_str += _cuda_system_info_comment() - - ep_path = os.path.join(save_dir, "exported_program.pt2") + if save_dir: + ep_path = os.path.join(save_dir, "exported_program.pt2") + else: + ep_path = "exported_program.pt2" torch.export.save(exported_program, ep_path) model_str += f"exported_program = torch.export.load('{ep_path}')\n" @@ -271,7 +272,7 @@ def generate_compiler_repro_exported_program( return model_str -def repro_load_args(load_args, save_dir): +def repro_load_args(load_args: Any, save_dir: Optional[str]) -> tuple[Any]: if not hasattr(load_args, "_version"): log.warning( "load_args does not have a _version attribute, please file a bug to PyTorch " @@ -297,19 +298,29 @@ def repro_load_args(load_args, save_dir): return tuple(args) -def repro_common(options, exported_program): +def repro_common( + options: Any, exported_program: ExportedProgram +) -> tuple[torch.fx.GraphModule, Any, Any]: torch._inductor.config.generate_intermediate_hooks = True mod = exported_program.module() args, kwargs = exported_program.example_inputs - return mod, args, kwargs + return mod, args, kwargs # type: ignore[return-value] -def repro_get_args(options, exported_program, config_patches): +def repro_get_args( + options: Any, + exported_program: ExportedProgram, + config_patches: Optional[dict[str, Any]], +) -> tuple[torch.fx.GraphModule, Any, Any]: mod, args, kwargs = repro_common(options, exported_program) return mod, args, kwargs -def repro_run(options, exported_program, config_patches): +def repro_run( + options: Any, + exported_program: ExportedProgram, + config_patches: Optional[dict[str, Any]], +) -> None: from torch._inductor import _aoti_compile_and_package_inner gm, args, kwargs = repro_common(options, exported_program) @@ -337,7 +348,10 @@ def repro_run(options, exported_program, config_patches): def export_for_aoti_minifier( - gm, tuple_inputs, strict=False, skip_export_error=True + gm: torch.nn.Module, + tuple_inputs: tuple[Any], + strict: bool = False, + skip_export_error: bool = True, ) -> Optional[torch.nn.Module]: # Some graphs cannot be used for AOTI/export (illegal graphs), these should be # considered as graphs that don't fail in the minifier, so the minifier keeps searching. @@ -372,7 +386,11 @@ def export_for_aoti_minifier( return None -def repro_minify(options, exported_program, config_patches): +def repro_minify( + options: Any, + exported_program: ExportedProgram, + config_patches: Optional[dict[str, Any]], +) -> None: from functorch.compile import minifier from torch._inductor import _aoti_compile_and_package_inner from torch._inductor.compile_fx import _aoti_flatten_inputs @@ -397,7 +415,11 @@ def repro_minify(options, exported_program, config_patches): need_sync = True break - def module_fails(gm, flat_example_inputs, check_str=None): + def module_fails( + gm: torch.fx.GraphModule, + flat_example_inputs: list[Any], + check_str: Optional[str] = None, + ) -> bool: # Need to export first so the in_spec and out_spec are populated tuple_inputs = tuple(flat_example_inputs) gm = export_for_aoti_minifier( @@ -447,18 +469,18 @@ def module_fails(gm, flat_example_inputs, check_str=None): def run_repro( - exported_program, + exported_program: ExportedProgram, *, config_patches: Optional[dict[str, str]] = None, - command="run", + command: str = "run", accuracy: Union[bool, str] = "", - save_dir=None, - tracing_mode=None, - check_str=None, - minifier_export_mode="python", - skip_export_error=True, - **more_kwargs, -): + save_dir: Optional[str] = None, + tracing_mode: Optional[str] = None, + check_str: Optional[str] = None, + minifier_export_mode: str = "python", + skip_export_error: bool = True, + **more_kwargs: Any, +) -> Any: for k in more_kwargs: log.warning( "Unrecognized kwarg %s; perhaps this repro was made on a newer version of PyTorch", @@ -486,7 +508,7 @@ def run_repro( formatter_class=argparse.RawTextHelpFormatter, ) - def common_flags(parser): + def common_flags(parser: argparse.ArgumentParser) -> None: accuracy_group = parser.add_mutually_exclusive_group() accuracy_group.add_argument( "--no-accuracy", From 656885b6147e7e77db38de2898ef27f389e06461 Mon Sep 17 00:00:00 2001 From: Lucas Kabela Date: Fri, 18 Jul 2025 18:22:01 +0000 Subject: [PATCH 256/457] [Dynamo][Better Engineering] Type devices, resume_execution and testing utils (#158593) As part of better engineering week, we would like to improve out type support to improve dev experience in dynamo This PR adds strict typing support to a set of utilities in dynamo, `device_interface.py`, `resume_execution.py`, `tensor_version_ops.py`, `test_case.py`, and `test_minifier_common.py` Running ``` mypy torch/_dynamo/device_interface.py torch/_dynamo/resume_execution.py torch/_dynamo/tensor_version_op.py torch/_dynamo/test_case.py torch/_dynamo/test_minifier_common.py --linecount-report /tmp/coverage_log ``` | -------- | Lines Unannotated | Lines Total | % lines covered | Funcs Unannotated | Funcs Total | % funcs covered | | -------- | ------- | -------- | ------- | ------- | ------- | ------- | | Main | 976 | 1672 | 58.37% | 76 | 112 | 67.86% | | This PR | 1719 | 1719 | 100.00% | 112 | 112 | 100.00% | | Delta | +743 | +47 | +41.63% | +36 | 0 | +32.14% | Pull Request resolved: https://github.com/pytorch/pytorch/pull/158593 Approved by: https://github.com/mlazos --- torch/_dynamo/device_interface.py | 84 +++++++++++++-------------- torch/_dynamo/resume_execution.py | 67 +++++++++++++-------- torch/_dynamo/tensor_version_op.py | 23 +++++--- torch/_dynamo/test_case.py | 13 +++-- torch/_dynamo/test_minifier_common.py | 50 ++++++++++------ 5 files changed, 142 insertions(+), 95 deletions(-) diff --git a/torch/_dynamo/device_interface.py b/torch/_dynamo/device_interface.py index 9c6e4f6bf5f8b..eb315fc731907 100644 --- a/torch/_dynamo/device_interface.py +++ b/torch/_dynamo/device_interface.py @@ -1,5 +1,3 @@ -# mypy: allow-untyped-defs - """ Device abstraction layer for TorchDynamo and Inductor backends. @@ -21,7 +19,7 @@ import time from collections.abc import Iterable from dataclasses import dataclass -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Literal, Optional, Union import torch @@ -44,17 +42,17 @@ class DeviceInterface: """ class device: - def __new__(cls, device: torch.types.Device): + def __new__(cls, device: torch.types.Device) -> Any: raise NotImplementedError class Event: - def __new__(cls, *args, **kwargs): + def __new__(cls, *args: Any, **kwargs: Any) -> Any: raise NotImplementedError( "Event should be inherited from torch.Event, otherwise, it couldn't be captured by dynamo." ) class Stream: - def __new__(cls, *args, **kwargs): + def __new__(cls, *args: Any, **kwargs: Any) -> Any: raise NotImplementedError( "Stream should be inherited from torch.Stream, otherwise, it couldn't be captured by dynamo." ) @@ -68,7 +66,7 @@ class Worker: """ @staticmethod - def set_device(device: int): + def set_device(device: int) -> None: raise NotImplementedError @staticmethod @@ -76,15 +74,15 @@ def current_device() -> int: raise NotImplementedError @staticmethod - def get_device_properties(device: torch.types.Device = None): + def get_device_properties(device: torch.types.Device = None) -> Any: raise NotImplementedError @staticmethod - def current_device(): + def current_device() -> int: raise NotImplementedError @staticmethod - def set_device(device: torch.types.Device): + def set_device(device: torch.types.Device) -> None: raise NotImplementedError @staticmethod @@ -96,7 +94,7 @@ def exchange_device(device: int) -> int: raise NotImplementedError @staticmethod - def device_count(): + def device_count() -> int: raise NotImplementedError @staticmethod @@ -104,19 +102,19 @@ def is_available() -> bool: raise NotImplementedError @staticmethod - def stream(stream: torch.Stream): + def stream(stream: torch.Stream) -> Any: raise NotImplementedError @staticmethod - def current_stream(): + def current_stream() -> torch.Stream: raise NotImplementedError @staticmethod - def set_stream(stream: torch.Stream): + def set_stream(stream: torch.Stream) -> None: raise NotImplementedError @staticmethod - def _set_stream_by_id(stream_id: int, device_index: int, device_type: int): + def _set_stream_by_id(stream_id: int, device_index: int, device_type: int) -> None: raise NotImplementedError @staticmethod @@ -124,19 +122,19 @@ def get_raw_stream(device_idx: int) -> int: raise NotImplementedError @staticmethod - def synchronize(device: torch.types.Device = None): + def synchronize(device: torch.types.Device = None) -> None: raise NotImplementedError @classmethod - def get_device_properties(cls, device: torch.types.Device = None): + def get_device_properties(cls, device: torch.types.Device = None) -> Any: return cls.Worker.get_device_properties(device) @staticmethod - def get_compute_capability(device: torch.types.Device = None): + def get_compute_capability(device: torch.types.Device = None) -> Any: raise NotImplementedError @staticmethod - def is_bf16_supported(including_emulation: bool = False): + def is_bf16_supported(including_emulation: bool = False) -> bool: raise NotImplementedError @classmethod @@ -188,11 +186,11 @@ def __init__( self.idx = index self.prev_idx = -1 - def __enter__(self): + def __enter__(self) -> None: if self.idx is not None: self.prev_idx = self.device_interface.exchange_device(self.idx) - def __exit__(self, type: Any, value: Any, traceback: Any): + def __exit__(self, type: Any, value: Any, traceback: Any) -> Literal[False]: if self.idx is not None: self.idx = self.device_interface.maybe_exchange_device(self.prev_idx) return False @@ -208,7 +206,7 @@ class CudaInterface(DeviceInterface): class Worker: @staticmethod - def set_device(device: int): + def set_device(device: int) -> None: caching_worker_current_devices["cuda"] = device @staticmethod @@ -218,7 +216,7 @@ def current_device() -> int: return torch.cuda.current_device() @staticmethod - def get_device_properties(device: torch.types.Device = None): + def get_device_properties(device: torch.types.Device = None) -> Any: if device is not None: if isinstance(device, str): device = torch.device(device) @@ -258,7 +256,7 @@ def is_available() -> bool: return torch.cuda.is_available() @staticmethod - def get_compute_capability(device: torch.types.Device = None): + def get_compute_capability(device: torch.types.Device = None) -> Union[int, str]: if torch.version.hip is None: major, min = torch.cuda.get_device_capability(device) return major * 10 + min @@ -303,7 +301,7 @@ class XpuInterface(DeviceInterface): class Worker: @staticmethod - def set_device(device: int): + def set_device(device: int) -> None: caching_worker_current_devices["xpu"] = device @staticmethod @@ -313,7 +311,7 @@ def current_device() -> int: return torch.xpu.current_device() @staticmethod - def get_device_properties(device: torch.types.Device = None): + def get_device_properties(device: torch.types.Device = None) -> Any: if device is not None: if isinstance(device, str): device = torch.device(device) @@ -352,7 +350,7 @@ def is_available() -> bool: return torch.xpu.is_available() @staticmethod - def get_compute_capability(device: torch.types.Device = None): + def get_compute_capability(device: torch.types.Device = None) -> Any: cc = torch.xpu.get_device_capability(device) return cc @@ -365,7 +363,7 @@ def is_triton_capable(device: torch.types.Device = None) -> bool: return True @staticmethod - def raise_if_triton_unavailable(evice: torch.types.Device = None) -> None: + def raise_if_triton_unavailable(device: torch.types.Device = None) -> None: import triton.backends if "intel" not in triton.backends.backends: @@ -379,18 +377,20 @@ class CpuDeviceProperties: class CpuInterface(DeviceInterface): class Event(torch.Event): - def __init__(self, enable_timing=True): + def __init__(self, enable_timing: bool = True) -> None: self.time = 0.0 - def elapsed_time(self, end_event) -> float: + def elapsed_time(self, end_event: Any) -> float: return (end_event.time - self.time) * 1000 - def record(self, stream=None): + def record(self, stream: Any = None) -> None: self.time = time.perf_counter() class Worker: @staticmethod - def get_device_properties(device: torch.types.Device = None): + def get_device_properties( + device: torch.types.Device = None, + ) -> CpuDeviceProperties: import multiprocessing cpu_count = multiprocessing.cpu_count() @@ -401,7 +401,7 @@ def is_available() -> bool: return True @staticmethod - def is_bf16_supported(including_emulation: bool = False): + def is_bf16_supported(including_emulation: bool = False) -> bool: return True @staticmethod @@ -409,15 +409,15 @@ def get_compute_capability(device: torch.types.Device = None) -> str: return "" @staticmethod - def get_raw_stream(device_idx) -> int: + def get_raw_stream(device_idx: Any) -> int: return 0 @staticmethod - def current_device(): + def current_device() -> int: return 0 @staticmethod - def synchronize(device: torch.types.Device = None): + def synchronize(device: torch.types.Device = None) -> None: pass @staticmethod @@ -450,7 +450,7 @@ def is_available() -> bool: return torch.backends.mps.is_available() @staticmethod - def current_device(): + def current_device() -> int: return 0 @staticmethod @@ -458,16 +458,16 @@ def get_compute_capability(device: torch.types.Device = None) -> str: return "" @staticmethod - def synchronize(device: torch.types.Device = None): + def synchronize(device: torch.types.Device = None) -> None: torch.mps.synchronize() class Worker: @staticmethod - def get_device_properties(device: torch.types.Device = None): + def get_device_properties(device: torch.types.Device = None) -> dict[str, Any]: return {} @staticmethod - def current_device(): + def current_device() -> int: return 0 @@ -477,7 +477,7 @@ def current_device(): def register_interface_for_device( device: Union[str, torch.device], device_interface: type[DeviceInterface] -): +) -> None: if isinstance(device, torch.device): device = device.type device_interfaces[device] = device_interface @@ -499,7 +499,7 @@ def get_registered_device_interfaces() -> Iterable[tuple[str, type[DeviceInterfa return device_interfaces.items() -def init_device_reg(): +def init_device_reg() -> None: global _device_initialized register_interface_for_device("cuda", CudaInterface) for i in range(torch.cuda.device_count()): diff --git a/torch/_dynamo/resume_execution.py b/torch/_dynamo/resume_execution.py index 28f63c715fe52..0bd0a1b0ab2a0 100644 --- a/torch/_dynamo/resume_execution.py +++ b/torch/_dynamo/resume_execution.py @@ -1,5 +1,3 @@ -# mypy: allow-untyped-defs - """ This module provides functionality for resuming Python execution at specific points in code, primarily used by PyTorch Dynamo for control flow handling and optimization. It implements @@ -19,7 +17,9 @@ import dataclasses import sys import types -from typing import Any, cast, Optional +from collections.abc import Iterable +from contextlib import AbstractContextManager +from typing import Any, Callable, cast, Optional from .bytecode_transformation import ( bytecode_from_template, @@ -52,7 +52,7 @@ IS_TRACING_RESUME_PROLOGUE_VARNAME = "__is_tracing_resume_prologue" -def _initial_push_null(insts): +def _initial_push_null(insts: list[Instruction]) -> None: if sys.version_info >= (3, 11): insts.append(create_instruction("PUSH_NULL")) if sys.version_info < (3, 13): @@ -60,7 +60,11 @@ def _initial_push_null(insts): # Generates bytecode from template and splits the code where LOAD_FAST dummy is present. -def _bytecode_from_template_with_split(template, stack_index, varname_map=None): +def _bytecode_from_template_with_split( + template: Callable[..., Any], + stack_index: int, + varname_map: Optional[dict[str, Any]] = None, +) -> tuple[list[Instruction], list[Instruction]]: template_code = bytecode_from_template(template, varname_map=varname_map) template_code.append(create_instruction("POP_TOP")) @@ -90,7 +94,7 @@ def _bytecode_from_template_with_split(template, stack_index, varname_map=None): return template_code[: dummy_idx + 1], template_code[dummy_idx + 1 :] -def _try_except_tf_mode_template(dummy, stack_var_name): +def _try_except_tf_mode_template(dummy: Any, stack_var_name: Any) -> None: # NOTE: Make sure this name matches what is generated by symbolic_convert:import_source # on torch._dynamo.utils. global __import_torch_dot__dynamo_dot_utils @@ -108,7 +112,9 @@ class ReenterWith: stack_index: int target_values: Optional[tuple[Any, ...]] = None - def try_except_torch_function_mode(self, code_options, cleanup: list[Instruction]): + def try_except_torch_function_mode( + self, code_options: dict[str, Any], cleanup: list[Instruction] + ) -> list[Instruction]: """ Codegen based off of: try: @@ -130,7 +136,9 @@ def try_except_torch_function_mode(self, code_options, cleanup: list[Instruction # If we do not want to destroy the stack, we can do the same thing as a # `SETUP_WITH` block, only that we store the context manager in a local_symbol - def try_finally(self, code_options, cleanup: list[Instruction]): + def try_finally( + self, code_options: dict[str, Any], cleanup: list[Instruction] + ) -> list[Instruction]: """ Codegen based off of: load args @@ -161,7 +169,7 @@ def try_finally(self, code_options, cleanup: list[Instruction]): ] ) - def _template(ctx, dummy): + def _template(ctx: AbstractContextManager[Any], dummy: Any) -> None: ctx.__enter__() try: dummy @@ -174,7 +182,9 @@ def _template(ctx, dummy): cleanup[:] = epilogue + cleanup return create_ctx + setup_try_finally - def __call__(self, code_options, cleanup): + def __call__( + self, code_options: dict[str, Any], cleanup: list[Instruction] + ) -> tuple[list[Instruction], Optional[Instruction]]: """ Codegen based off of: with ctx(args): @@ -194,7 +204,7 @@ def __call__(self, code_options, cleanup): ] ) - def _template(ctx, dummy): + def _template(ctx: AbstractContextManager[Any], dummy: Any) -> None: with ctx: dummy @@ -242,7 +252,11 @@ class ResumeFunctionMetadata: block_target_offset_remap: Optional[dict[int, int]] = None -def _filter_iter(l1, l2, cond): +def _filter_iter( + l1: Iterable[Any], + l2: Iterable[Any], + cond: Callable[[Any, Any], bool], +) -> list[Any]: """ Two-pointer conditional filter. e.g. _filter_iter(insts, sorted_offsets, lambda i, o: i.offset == o) @@ -261,7 +275,7 @@ def _filter_iter(l1, l2, cond): return res -def _load_tuple_and_call(tup): +def _load_tuple_and_call(tup: tuple[Any, ...]) -> list[Instruction]: insts: list[Instruction] = [] _initial_push_null(insts) insts.extend(create_load_const(val) for val in tup) @@ -274,7 +288,7 @@ class ContinueExecutionCache: generated_code_metadata = ExactWeakKeyDictionary() @classmethod - def lookup(cls, code, lineno, *key): + def lookup(cls, code: types.CodeType, lineno: int, *key: Any) -> types.CodeType: if code not in cls.cache: cls.cache[code] = {} key = tuple(key) @@ -285,8 +299,8 @@ def lookup(cls, code, lineno, *key): @classmethod def generate( cls, - code, - lineno, + code: types.CodeType, + lineno: int, offset: int, setup_fn_target_offsets: tuple[int, ...], # only used in Python 3.11+ nstack: int, @@ -321,7 +335,9 @@ def generate( is_py311_plus = sys.version_info >= (3, 11) meta = ResumeFunctionMetadata(code) - def update(instructions: list[Instruction], code_options: dict[str, Any]): + def update( + instructions: list[Instruction], code_options: dict[str, Any] + ) -> None: meta.instructions = copy.deepcopy(instructions) args = [f"___stack{i}" for i in range(nstack)] @@ -479,7 +495,7 @@ def update(instructions: list[Instruction], code_options: dict[str, Any]): inst.exn_tab_entry and inst.exn_tab_entry.target in old_hook_target_remap ): - inst.exn_tab_entry.target = old_hook_target_remap[ + inst.exn_tab_entry.target = old_hook_target_remap[ # type: ignore[assignment] inst.exn_tab_entry.target ] @@ -491,7 +507,7 @@ def update(instructions: list[Instruction], code_options: dict[str, Any]): return new_code @staticmethod - def unreachable_codes(code_options) -> list[Instruction]: + def unreachable_codes(code_options: dict[str, Any]) -> list[Instruction]: """Codegen a `raise None` to make analysis work for unreachable code""" return [ create_load_const(None), @@ -500,8 +516,13 @@ def unreachable_codes(code_options) -> list[Instruction]: @classmethod def generate_based_on_original_code_object( - cls, code, lineno, offset: int, setup_fn_target_offsets: tuple[int, ...], *args - ): + cls, + code: types.CodeType, + lineno: int, + offset: int, + setup_fn_target_offsets: tuple[int, ...], + *args: Any, + ) -> types.CodeType: """ This handles the case of generating a resume into code generated to resume something else. We want to always generate starting @@ -517,7 +538,7 @@ def generate_based_on_original_code_object( def find_new_offset( instructions: list[Instruction], code_options: dict[str, Any] - ): + ) -> None: nonlocal new_offset (target,) = (i for i in instructions if i.offset == offset) # match the functions starting at the last instruction as we have added a prefix @@ -541,7 +562,7 @@ def find_new_offset( def remap_block_offsets( instructions: list[Instruction], code_options: dict[str, Any] - ): + ) -> None: # NOTE: each prefix block generates exactly one PUSH_EXC_INFO, # so we can tell which block a prefix PUSH_EXC_INFO belongs to, # by counting. Then we can use meta.prefix_block-target_offset_remap diff --git a/torch/_dynamo/tensor_version_op.py b/torch/_dynamo/tensor_version_op.py index c1a6fd03ba060..8709c5618d859 100644 --- a/torch/_dynamo/tensor_version_op.py +++ b/torch/_dynamo/tensor_version_op.py @@ -1,5 +1,3 @@ -# mypy: allow-untyped-defs - """This module implements tensor version operations for Dynamo tracing. It provides primitives for handling tensor versioning during tracing, particularly in the @@ -18,7 +16,11 @@ Note this is similar to how no_grad is handled. """ +from contextlib import AbstractContextManager +from typing import Any + import torch +from torch import SymInt from torch._prims import _make_prim, RETURN_TYPE from torch._subclasses import FakeTensorMode from torch._subclasses.functional_tensor import FunctionalTensorMode @@ -33,13 +35,14 @@ ) -@_tensor_version.py_impl(FakeTensorMode) -def _tensor_version_fake(fake_mode, self_tensor): +@_tensor_version.py_impl(FakeTensorMode) # type: ignore[misc] +def _tensor_version_fake(fake_mode: FakeTensorMode, self_tensor: Any) -> SymInt: """ The initial dynamo capture of _tensor_version + _unsafe_set_version_counter turns the `._version` into an unbacked SymInt so that we don't need to specialize on the `._version` of input tensors to the graph. """ + assert fake_mode.shape_env is not None return fake_mode.shape_env.create_unbacked_symint() @@ -53,11 +56,15 @@ def _tensor_version_fake(fake_mode, self_tensor): torch.fx.node.has_side_effect(_unsafe_set_version_counter) -@_tensor_version.py_impl(FunctionalTensorMode) -def _tensor_version_functional(mode, self): +@_tensor_version.py_impl(FunctionalTensorMode) # type: ignore[misc] +def _tensor_version_functional(mode: FunctionalTensorMode, self: Any) -> int: return self._version -@_unsafe_set_version_counter.py_impl(FunctionalTensorMode) -def _unsafe_set_version_counter_functional(ctx, tensors, versions): +@_unsafe_set_version_counter.py_impl(FunctionalTensorMode) # type: ignore[misc] +def _unsafe_set_version_counter_functional( + ctx: AbstractContextManager[Any], + tensors: tuple[torch.Tensor, ...], + versions: tuple[int, ...], +) -> None: torch._C._autograd._unsafe_set_version_counter(tensors, versions) diff --git a/torch/_dynamo/test_case.py b/torch/_dynamo/test_case.py index dc7a446840519..230aac4794f25 100644 --- a/torch/_dynamo/test_case.py +++ b/torch/_dynamo/test_case.py @@ -1,5 +1,3 @@ -# mypy: allow-untyped-defs - """Testing utilities for Dynamo, providing a specialized TestCase class and test running functionality. This module extends PyTorch's testing framework with Dynamo-specific testing capabilities. @@ -18,7 +16,7 @@ import re import sys import unittest -from typing import Union +from typing import Any, Callable, Union import torch import torch.testing @@ -151,7 +149,12 @@ class CPythonTestCase(TestCase): fail = unittest.TestCase.fail failureException = unittest.TestCase.failureException - def compile_fn(self, fn, backend, nopython): + def compile_fn( + self, + fn: Callable[..., Any], + backend: Union[str, Callable[..., Any]], + nopython: bool, + ) -> Callable[..., Any]: # We want to compile only the test function, excluding any setup code # from unittest method = getattr(self, self._testMethodName) @@ -159,7 +162,7 @@ def compile_fn(self, fn, backend, nopython): setattr(self, self._testMethodName, method) return fn - def _dynamo_test_key(self): + def _dynamo_test_key(self) -> str: suffix = super()._dynamo_test_key() test_cls = self.__class__ test_file = inspect.getfile(test_cls).split(os.sep)[-1].split(".")[0] diff --git a/torch/_dynamo/test_minifier_common.py b/torch/_dynamo/test_minifier_common.py index 32d10b53da99d..4e4135666d56d 100644 --- a/torch/_dynamo/test_minifier_common.py +++ b/torch/_dynamo/test_minifier_common.py @@ -1,5 +1,3 @@ -# mypy: allow-untyped-defs - """Common utilities for testing Dynamo's minifier functionality. This module provides the base infrastructure for running minification tests in Dynamo. @@ -25,7 +23,8 @@ import sys import tempfile import traceback -from typing import Optional +from collections.abc import Sequence +from typing import Any, Optional, Union from unittest.mock import patch import torch @@ -40,7 +39,7 @@ class MinifierTestResult: minifier_code: str repro_code: str - def _get_module(self, t): + def _get_module(self, t: str) -> str: match = re.search(r"class Repro\(torch\.nn\.Module\):\s+([ ].*\n| *\n)+", t) assert match is not None, "failed to find module" r = match.group(0) @@ -48,7 +47,7 @@ def _get_module(self, t): r = re.sub(r"\n{3,}", "\n\n", r) return r.strip() - def get_exported_program_path(self): + def get_exported_program_path(self) -> Optional[str]: # Extract the exported program file path from AOTI minifier's repro.py # Regular expression pattern to match the file path pattern = r'torch\.export\.load\(\s*["\'](.*?)["\']\s*\)' @@ -60,10 +59,10 @@ def get_exported_program_path(self): return file_path return None - def minifier_module(self): + def minifier_module(self) -> str: return self._get_module(self.minifier_code) - def repro_module(self): + def repro_module(self) -> str: return self._get_module(self.repro_code) @@ -71,7 +70,7 @@ class MinifierTestBase(torch._dynamo.test_case.TestCase): DEBUG_DIR = tempfile.mkdtemp() @classmethod - def setUpClass(cls): + def setUpClass(cls) -> None: super().setUpClass() if not os.path.exists(cls.DEBUG_DIR): cls.DEBUG_DIR = tempfile.mkdtemp() @@ -94,14 +93,14 @@ def setUpClass(cls): ) @classmethod - def tearDownClass(cls): + def tearDownClass(cls) -> None: if os.getenv("PYTORCH_KEEP_TMPDIR", "0") != "1": shutil.rmtree(cls.DEBUG_DIR) else: print(f"test_minifier_common tmpdir kept at: {cls.DEBUG_DIR}") cls._exit_stack.close() # type: ignore[attr-defined] - def _gen_codegen_fn_patch_code(self, device, bug_type): + def _gen_codegen_fn_patch_code(self, device: str, bug_type: str) -> str: assert bug_type in ("compile_error", "runtime_error", "accuracy") return f"""\ {torch._dynamo.config.codegen_config()} @@ -109,7 +108,9 @@ def _gen_codegen_fn_patch_code(self, device, bug_type): torch._inductor.config.{"cpp" if device == "cpu" else "triton"}.inject_relu_bug_TESTING_ONLY = {bug_type!r} """ - def _maybe_subprocess_run(self, args, *, isolate, cwd=None): + def _maybe_subprocess_run( + self, args: Sequence[Any], *, isolate: bool, cwd: Optional[str] = None + ) -> subprocess.CompletedProcess[bytes]: if not isolate: assert len(args) >= 2, args assert args[0] == "python3", args @@ -174,7 +175,9 @@ def _maybe_subprocess_run(self, args, *, isolate, cwd=None): # Run `code` in a separate python process. # Returns the completed process state and the directory containing the # minifier launcher script, if `code` outputted it. - def _run_test_code(self, code, *, isolate): + def _run_test_code( + self, code: str, *, isolate: bool + ) -> tuple[subprocess.CompletedProcess[bytes], Union[str, Any]]: proc = self._maybe_subprocess_run( ["python3", "-c", code], isolate=isolate, cwd=self.DEBUG_DIR ) @@ -190,8 +193,13 @@ def _run_test_code(self, code, *, isolate): # Runs the minifier launcher script in `repro_dir` def _run_minifier_launcher( - self, repro_dir, isolate, *, minifier_args=(), repro_after=None - ): + self, + repro_dir: str, + isolate: bool, + *, + minifier_args: Sequence[Any] = (), + repro_after: Optional[str] = None, + ) -> tuple[subprocess.CompletedProcess[bytes], str]: self.assertIsNotNone(repro_dir) launch_file = _as_posix_path(os.path.join(repro_dir, "minifier_launcher.py")) with open(launch_file) as f: @@ -212,7 +220,9 @@ def _run_minifier_launcher( return launch_proc, launch_code # Runs the repro script in `repro_dir` - def _run_repro(self, repro_dir, *, isolate=True): + def _run_repro( + self, repro_dir: str, *, isolate: bool = True + ) -> tuple[subprocess.CompletedProcess[bytes], str]: self.assertIsNotNone(repro_dir) repro_file = _as_posix_path(os.path.join(repro_dir, "repro.py")) with open(repro_file) as f: @@ -230,7 +240,7 @@ def _run_repro(self, repro_dir, *, isolate=True): # `run_code` is the code to run for the test case. # `patch_code` is the code to be patched in every generated file; usually # just use this to turn on bugs via the config - def _gen_test_code(self, run_code, repro_after, repro_level): + def _gen_test_code(self, run_code: str, repro_after: str, repro_level: int) -> str: repro_after_line = "" if repro_after == "aot_inductor": repro_after_line = ( @@ -263,7 +273,13 @@ def _gen_test_code(self, run_code, repro_after, repro_level): # isolate=True only if the bug you're testing would otherwise # crash the process def _run_full_test( - self, run_code, repro_after, expected_error, *, isolate, minifier_args=() + self, + run_code: str, + repro_after: str, + expected_error: Optional[str], + *, + isolate: bool, + minifier_args: Sequence[Any] = (), ) -> Optional[MinifierTestResult]: if isolate: repro_level = 3 From b87e50db5e2712608e0b912a8063f0336554bfc3 Mon Sep 17 00:00:00 2001 From: Sam Larsen Date: Wed, 16 Jul 2025 14:16:43 -0700 Subject: [PATCH 257/457] [BE][testing] Fix internal test failures in test/dynamo/test_unspec (#158485) Summary: These tests failing internally because the number of underlying calls to the rng differ by virtue of various library initializations that get sucked in with an internal build. Test Plan: ``` buck test '@fbcode//mode/opt' fbcode//caffe2/test/dynamo:test_dynamo -- --exact 'caffe2/test/dynamo:test_dynamo - test_unspec.py::UnspecTests::test_random_object' --run-disabled buck test '@fbcode//mode/opt' fbcode//caffe2/test/dynamo:test_dynamo -- --exact 'caffe2/test/dynamo:test_dynamo - test_unspec.py::UnspecTests::test_random_values_with_graph_break' --run-disabled buck test '@fbcode//mode/opt' fbcode//caffe2/test/dynamo:test_dynamo -- --exact 'caffe2/test/dynamo:test_dynamo - test_unspec.py::UnspecTests::test_feed_random_values_into_graph_only' --run-disabled ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/158485 Approved by: https://github.com/williamwen42 --- test/dynamo/test_unspec.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/test/dynamo/test_unspec.py b/test/dynamo/test_unspec.py index 8ae4c9e58343c..70ba2a8bd1bd3 100644 --- a/test/dynamo/test_unspec.py +++ b/test/dynamo/test_unspec.py @@ -132,6 +132,9 @@ def fn(shape): res1 = fn(shape) cnts = torch._dynamo.testing.CompileCounter() opt_fn = torch.compile(fn, backend=cnts) + # Especially for internal: before resetting the seed, first shake out any rng + # calls that occur on compile, e.g., as a result of some module initializations. + opt_fn(shape) random.seed(1) res2 = opt_fn(shape) @@ -151,6 +154,9 @@ def fn(x): res1 = fn(x) cnts = torch._dynamo.testing.CompileCounter() opt_fn = torch.compile(fn, backend=cnts) + # Especially for internal: before resetting the seed, first shake out any rng + # calls that occur on compile, e.g., as a result of some module initializations. + opt_fn(x) random.seed(1) res2 = opt_fn(x) self.assertTrue(same(res1, res2)) @@ -176,6 +182,9 @@ def fn(x): res1 = fn(x) cnts = torch._dynamo.testing.CompileCounter() opt_fn = torch.compile(fn, backend=cnts) + # Especially for internal: before resetting the seed, first shake out any rng + # calls that occur on compile, e.g., as a result of some module initializations. + opt_fn(x) random.seed(1) res2 = opt_fn(x) self.assertTrue(same(res1, res2)) @@ -206,6 +215,9 @@ def fn(x): random.seed(1) res1 = fn(x) opt_fn = torch.compile(fn, backend="eager") + # Especially for internal: before resetting the seed, first shake out any rng + # calls that occur on compile, e.g., as a result of some module initializations. + opt_fn(x) random.seed(1) res2 = opt_fn(x) self.assertTrue(same(res1, res2)) @@ -232,6 +244,9 @@ def fn(x, rand2): random.seed(0) y_1, rand2_1, rand3_1 = fn(inp, random.Random(12)) state_1 = random.getstate() + # Especially for internal: before resetting the seed, first shake out any rng + # calls that occur on compile, e.g., as a result of some module initializations. + opt_fn(inp, random.Random(12)) random.seed(0) y_2, rand2_2, rand3_2 = opt_fn(inp, random.Random(12)) state_2 = random.getstate() From 79e49efaddf3a049adbe2de839cc65d73a1edd42 Mon Sep 17 00:00:00 2001 From: Svetlana Karslioglu Date: Fri, 18 Jul 2025 18:46:47 +0000 Subject: [PATCH 258/457] Pull latest Sphinx theme (#158595) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158595 Approved by: https://github.com/albanD --- .ci/docker/requirements-docs.txt | 4 ++-- docs/source/conf.py | 4 ++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/.ci/docker/requirements-docs.txt b/.ci/docker/requirements-docs.txt index 73ec471c88464..1dd883a55a41b 100644 --- a/.ci/docker/requirements-docs.txt +++ b/.ci/docker/requirements-docs.txt @@ -4,8 +4,8 @@ sphinx==5.3.0 -e git+https://github.com/pytorch/pytorch_sphinx_theme.git@pytorch_sphinx_theme2#egg=pytorch_sphinx_theme2 # TODO: sphinxcontrib.katex 0.9.0 adds a local KaTeX server to speed up pre-rendering -# but it doesn't seem to work and hangs around idly. The initial thought is probably -# something related to Docker setup. We can investigate this later. +# but it doesn't seem to work and hangs around idly. The initial thought it is probably +# something related to Docker setup. We can investigate this later sphinxcontrib.katex==0.8.6 #Description: This is used to generate PyTorch docs diff --git a/docs/source/conf.py b/docs/source/conf.py index d19d9ec21ef8e..2113411cd8afb 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -82,6 +82,10 @@ ] sitemap_url_scheme = "{link}" +html_additional_pages = { + "404": "404.html", +} + # build the templated autosummary files autosummary_generate = True numpydoc_show_class_members = False From 75e2628782c1ccb626603a0ae852478dbf11b1d0 Mon Sep 17 00:00:00 2001 From: dsashidh Date: Fri, 18 Jul 2025 19:42:09 +0000 Subject: [PATCH 259/457] Add lower bounds for fsspec and networkx dependencies (#158565) Fixes #156587 This sets lower bounds for fsspec and networkx in both setup.py and requirements,txt. - fsspec>= 0.8.5 (released December 15, 2020) - netowrkx>= 2.5.1 (released April 3, 2021) These are the first stable versions released after Python 3.9 came out on October 5, 2020. Since Python 3.8 is no longer maintained, setting these minimums helps ensure PyTorch won't be installed alongside unexpectedly old versions of these packages. Tested with these versions locally to make sure they don't break anything. Adding CI for lower-bound testing could be a follow up later if need. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158565 Approved by: https://github.com/janeyx99 --- requirements.txt | 4 ++-- setup.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/requirements.txt b/requirements.txt index 2f585def9f19f..2affc4d2215a9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,11 +7,11 @@ build[uv] # for building sdist and wheel expecttest>=0.3.0 filelock -fsspec +fsspec>=0.8.5 hypothesis jinja2 lintrunner ; platform_machine != "s390x" -networkx +networkx>=2.5.1 optree>=0.13.0 psutil sympy>=1.13.3 diff --git a/setup.py b/setup.py index 232cf214d32e3..076d3d1e7ec06 100644 --- a/setup.py +++ b/setup.py @@ -1233,9 +1233,9 @@ def main() -> None: "typing-extensions>=4.10.0", 'setuptools ; python_version >= "3.12"', "sympy>=1.13.3", - "networkx", + "networkx>=2.5.1", "jinja2", - "fsspec", + "fsspec>=0.8.5", ] if BUILD_PYTHON_ONLY: install_requires += [f"{LIBTORCH_PKG_NAME}=={TORCH_VERSION}"] From 1b5fdb23b95f48526212da66b85572450a97355f Mon Sep 17 00:00:00 2001 From: Zain Rizvi Date: Fri, 18 Jul 2025 19:55:35 +0000 Subject: [PATCH 260/457] [BE] Add pre-push hook for lintrunner to the PyTorch repo (#158389) Adds a pre-commit hook (technically a pre-push hook) to the PyTorch repo. **This is currently an opt-in feature**, which one can opt into by running `python scripts/setup_hooks.py` locally. ### Features - **Run Lintrunner Before Push**: Before every `git push`, automatically runs lintrunner on your changes. - Really need to skip the checks? Run `git push --no-verify` - **Consistent, Isolated, Lintrunner Environment**: During pre-push, Lintrunner runs in it's own virtual en environment that contain all lintrunner dependencies in a consistent, isolated environment. No more lintrunner failures because you created a new .venv. (Did you know you needed to run `lintrunner init` every time you make a new .venv?) - **Dependencies Automatically Updated**: If .lintrunner.toml is updated, this will automatically re-run `lintrunner init` to ensure you install the latest dependencies specified ### Installation - Run `python scripts/setup_hooks.py`. Now every `git push` will first run lintrunner. ### Additional details - The lintrunner used by the pre-push hook runs in a special per-repo virtual environment managed by the commit-hook tool located under `$USER/.cache/pre-commit` - Does not affect your regularly used lintrunner - Manual invocations of lintrunner will continue to depend on your local environment instead of the special pre-push one. If there's enough interest, we could explore consolidating them. - Does not run `lintrunner -a` for you. - You still need to manually run that (can be changed later though!) - Have staged/unstaged changes? No worries - This runs `git stash` before running the pre-commit hooks and pops back your changes afterwards, so only the changes actaully being pushed will be tested ### Downsides - No streaming UI updates - While you still get the same output from lintrunner that you're used to, the commit-hook framework doesn't show any output while lintrunner is actually running. Instead, it shows the entire output after linter has completed execution, which could be a few minutes (especially if it has to run `lintrunner init` first) - `uv` installation is required to run the setup script. The setup script will ask users to install uv if it's not available. - This is required to be able to install the pre-commit package in a safe way that's available no matter what .venv you are running in. ### Opting out - Disable hook for a single push: Run `git push --no-verify` - Disable hook permanently: If something goes wrong and you need to wipe your setup: - Delete the `$USER/.cache/pre-commit` folder and the `.git/hooks/pre-push` file in your local repo. - You can now rerun `python scripts/setup_hooks.py` to setup your git push hook again if you want. ### Potential Future Changes Things that could be done to make this even better if folks like these ideas: - Automatic setup - Our `CONTRIBUTING.md` file tells devs to run `make setup-env`. That could be a good entry point to hook the installation into - Fix the console output streaming - Make every lintrunner invocation (including manual ones) use the same repo-specific venv that the commit-hook uses. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158389 Approved by: https://github.com/seemethere --- .pre-commit-config.yaml | 12 ++++ scripts/run_lintrunner.py | 110 ++++++++++++++++++++++++++++++ scripts/setup_hooks.py | 139 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 261 insertions(+) create mode 100644 .pre-commit-config.yaml create mode 100644 scripts/run_lintrunner.py create mode 100644 scripts/setup_hooks.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000000000..2c67fb1981b71 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,12 @@ +repos: + - repo: local + hooks: + - id: lintrunner + name: Run Lintrunner in an isolated venv before every push. The first run may be slow... + entry: python scripts/run_lintrunner.py # wrapper below + language: python # pre‑commit manages venv for the wrapper + additional_dependencies: [] # wrapper handles lintrunner install + always_run: true + stages: [pre-push] # fire only on pre‑push + pass_filenames: false # Lintrunner gets no per‑file args + verbose: true # stream output as it is produced...allegedly anyways diff --git a/scripts/run_lintrunner.py b/scripts/run_lintrunner.py new file mode 100644 index 0000000000000..60d5b545cf917 --- /dev/null +++ b/scripts/run_lintrunner.py @@ -0,0 +1,110 @@ +#!/usr/bin/env python3 +""" +Pre‑push hook wrapper for Lintrunner. + +✓ Stores a hash of .lintrunner.toml in the venv +✓ Re-runs `lintrunner init` if that file's hash changes +""" + +from __future__ import annotations + +import hashlib +import os +import shutil +import subprocess +import sys +from pathlib import Path + + +REPO_ROOT = Path(__file__).resolve().parents[1] +LINTRUNNER_TOML_PATH = REPO_ROOT / ".lintrunner.toml" + +# This is the path to the pre-commit-managed venv +VENV_ROOT = Path(sys.executable).parent.parent +# Stores the hash of .lintrunner.toml from the last time we ran `lintrunner init` +INITIALIZED_LINTRUNNER_TOML_HASH_PATH = VENV_ROOT / ".lintrunner_plugins_hash" + + +def ensure_lintrunner() -> None: + """Fail if Lintrunner is not on PATH.""" + if shutil.which("lintrunner"): + print("✅ lintrunner is already installed") + return + sys.exit( + "❌ lintrunner is required but was not found on your PATH. Please run the `python scripts/setup_hooks.py` to install to configure lintrunner before using this script. If `git push` still fails, you may need to open an new terminal" + ) + + +def ensure_virtual_environment() -> None: + """Fail if not running within a virtual environment.""" + in_venv = ( + os.environ.get("VIRTUAL_ENV") is not None + or hasattr(sys, "real_prefix") + or (hasattr(sys, "base_prefix") and sys.base_prefix != sys.prefix) + ) + + if not in_venv: + sys.exit( + "❌ This script must be run from within a virtual environment. " + "Please activate your virtual environment before running this script." + ) + + +def compute_file_hash(path: Path) -> str: + """Returns SHA256 hash of a file's contents.""" + hasher = hashlib.sha256() + with path.open("rb") as f: + while chunk := f.read(8192): + hasher.update(chunk) + return hasher.hexdigest() + + +def read_stored_hash(path: Path) -> str | None: + if not path.exists(): + return None + try: + return path.read_text().strip() + except Exception: + return None + + +def initialize_lintrunner_if_needed() -> None: + """Runs lintrunner init if .lintrunner.toml changed since last run.""" + if not LINTRUNNER_TOML_PATH.exists(): + print("⚠️ No .lintrunner.toml found. Skipping init.") + return + + print( + f"INITIALIZED_LINTRUNNER_TOML_HASH_PATH = {INITIALIZED_LINTRUNNER_TOML_HASH_PATH}" + ) + current_hash = compute_file_hash(LINTRUNNER_TOML_PATH) + stored_hash = read_stored_hash(INITIALIZED_LINTRUNNER_TOML_HASH_PATH) + + if current_hash == stored_hash: + print("✅ Lintrunner plugins already initialized and up to date.") + return + + print("🔁 Running `lintrunner init` …", file=sys.stderr) + subprocess.check_call(["lintrunner", "init"]) + INITIALIZED_LINTRUNNER_TOML_HASH_PATH.write_text(current_hash) + + +def main() -> None: + # 0. Ensure we're running in a virtual environment + ensure_virtual_environment() + print(f"🐍 Virtual env being used: {VENV_ROOT}", file=sys.stderr) + + # 1. Ensure lintrunner binary is available + ensure_lintrunner() + + # 2. Check for plugin updates and re-init if needed + initialize_lintrunner_if_needed() + + # 3. Run lintrunner with any passed arguments and propagate its exit code + args = sys.argv[1:] # Forward all arguments to lintrunner + result = subprocess.call(["lintrunner"] + args) + sys.exit(result) + + +if __name__ == "__main__": + main() diff --git a/scripts/setup_hooks.py b/scripts/setup_hooks.py new file mode 100644 index 0000000000000..7c467befd9498 --- /dev/null +++ b/scripts/setup_hooks.py @@ -0,0 +1,139 @@ +#!/usr/bin/env python3 +""" +Bootstrap Git pre‑push hook. + +✓ Requires uv to be installed (fails if not available) +✓ Installs/updates pre‑commit with uv (global, venv‑proof) +✓ Registers the repo's pre‑push hook and freezes hook versions + +Run this from the repo root (inside or outside any project venv): + + python scripts/setup_hooks.py +""" + +from __future__ import annotations + +import shutil +import subprocess +import sys +from pathlib import Path + + +# ─────────────────────────────────────────── +# Helper utilities +# ─────────────────────────────────────────── +def run(cmd: list[str]) -> None: + print(f"$ {' '.join(cmd)}") + subprocess.check_call(cmd) + + +def which(cmd: str) -> bool: + return shutil.which(cmd) is not None + + +def ensure_uv() -> None: + if which("uv"): + # Ensure the path uv installs binaries to is part of the system path + print("$ uv tool update-shell") + result = subprocess.run( + ["uv", "tool", "update-shell"], capture_output=True, text=True + ) + if result.returncode == 0: + # Check if the output indicates changes were made + if ( + "Updated" in result.stdout + or "Added" in result.stdout + or "Modified" in result.stdout + ): + print( + "⚠️ Shell configuration updated. You may need to restart your terminal for changes to take effect." + ) + elif result.stdout.strip(): + print(result.stdout) + return + else: + sys.exit( + f"❌ Warning: uv tool update-shell failed: {result.stderr}. uv installed tools may not be available." + ) + + sys.exit( + "\n❌ uv is required but was not found on your PATH.\n" + " Please install uv first using the instructions at:\n" + " https://docs.astral.sh/uv/getting-started/installation/\n" + " Then rerun python scripts/setup_hooks.py\n" + ) + + +def ensure_tool_installed(tool: str, force_update: bool = False) -> None: + """ + Checks to see if the tool is available and if not (or if force update requested) then + it reinstalls it. + + Returns: Whether or not the tool is available on PATH. If it's not, a new terminal + needs to be opened before git pushes work as expected. + """ + if force_update or not which(tool): + print(f"Ensuring latest {tool} via uv …") + run(["uv", "tool", "install", "--force", tool]) + if not which(tool): + print( + f"\n⚠️ {tool} installation succeed, but it's not on PATH. Launch a new terminal if your git pushes don't work.\n" + ) + + +if sys.platform.startswith("win"): + print( + "\n⚠️ Lintrunner is not supported on Windows, so there are no pre-push hooks to add. Exiting setup.\n" + ) + sys.exit(0) + +# ─────────────────────────────────────────── +# 1. Install dependencies +# ─────────────────────────────────────────── + +ensure_uv() + +# Ensure pre-commit is installed globally via uv +ensure_tool_installed("pre-commit", force_update=True) + +# Don't force a lintrunner update because it might break folks +# who already have it installed in a different way +ensure_tool_installed("lintrunner") + +# ─────────────────────────────────────────── +# 2. Activate (or refresh) the pre‑push hook +# ─────────────────────────────────────────── + +# ── Activate (or refresh) the repo’s pre‑push hook ────────────────────────── +# Creates/overwrites .git/hooks/pre‑push with a tiny shim that will call +# `pre-commit run --hook-stage pre-push` on every `git push`. +# This is why we need to install pre-commit globally. +# +# The --allow-missing-config flag lets pre-commit succeed if someone changes to +# a branch that doesn't have pre-commit installed +run( + [ + "uv", + "tool", + "run", + "pre-commit", + "install", + "--hook-type", + "pre-push", + "--allow-missing-config", + ] +) + +# ── Pin remote‑hook versions for reproducibility ──────────────────────────── +# (Note: we don't have remote hooks right now, but it future-proofs this script) +# 1. `autoupdate` bumps every remote hook’s `rev:` in .pre-commit-config.yaml +# to the latest commit on its default branch. +# 2. `--freeze` immediately rewrites each `rev:` to the exact commit SHA, +# ensuring all contributors and CI run identical hook code. +run(["uv", "tool", "run", "pre-commit", "autoupdate", "--freeze"]) + + +print( + "\n✅ pre‑commit is installed globally via uv and the pre‑push hook is active.\n" + " Lintrunner will now run automatically on every `git push`.\n" +) From 04ac258cf6a60423a01d30cbe0886e741f5ea97d Mon Sep 17 00:00:00 2001 From: Sam Larsen Date: Mon, 14 Jul 2025 11:00:34 -0700 Subject: [PATCH 261/457] [BE][testing] Fix test_cudacodecache.py (#158259) Summary: According to internal test failures, looks like we're missing a check for cuda: https://fburl.com/testinfra/eznzkyha Test Plan:c`buck test` Pull Request resolved: https://github.com/pytorch/pytorch/pull/158259 Approved by: https://github.com/exclamaforte, https://github.com/BoyuanFeng --- test/inductor/test_cudacodecache.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/inductor/test_cudacodecache.py b/test/inductor/test_cudacodecache.py index 970fe64a758d8..36f73b2004763 100644 --- a/test/inductor/test_cudacodecache.py +++ b/test/inductor/test_cudacodecache.py @@ -1,6 +1,7 @@ # Owner(s): ["module: inductor"] import ctypes +import unittest import torch from torch._inductor.async_compile import AsyncCompile @@ -9,6 +10,10 @@ from torch._inductor.exc import CUDACompileError from torch._inductor.test_case import TestCase as InductorTestCase from torch._inductor.utils import fresh_cache +from torch.testing._internal.inductor_utils import HAS_CUDA + + +requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") _SOURCE_CODE = r""" @@ -36,6 +41,7 @@ class TestCUDACodeCache(InductorTestCase): + @requires_cuda def test_cuda_load(self): with fresh_cache(): # Test both .o and .so compilation. @@ -63,12 +69,14 @@ def test_cuda_load(self): ) torch.testing.assert_close(y, expected_y) + @requires_cuda def test_compilation_error(self): with fresh_cache(): error_source_code = _SOURCE_CODE.replace("saxpy_device", "saxpy_wrong", 1) with self.assertRaises(CUDACompileError): CUDACodeCache.compile(error_source_code, "o") + @requires_cuda def test_async_compile(self): with fresh_cache(): async_compile = AsyncCompile() From 599f94e7b9ffd7481cc51a37093a2d17a5116889 Mon Sep 17 00:00:00 2001 From: Xu Han Date: Fri, 18 Jul 2025 19:57:08 +0000 Subject: [PATCH 262/457] [AOTI] add Windows file ext to package loader. (#158578) Add `object` and `extension` file type for Windows Pull Request resolved: https://github.com/pytorch/pytorch/pull/158578 Approved by: https://github.com/angelayi --- .../aoti_package/model_package_loader.cpp | 20 +++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/torch/csrc/inductor/aoti_package/model_package_loader.cpp b/torch/csrc/inductor/aoti_package/model_package_loader.cpp index 946342aad4146..629dc8cb2ae80 100644 --- a/torch/csrc/inductor/aoti_package/model_package_loader.cpp +++ b/torch/csrc/inductor/aoti_package/model_package_loader.cpp @@ -93,6 +93,22 @@ std::string create_temp_dir() { return temp_dir; #endif } + +const char* object_file_ext() { +#ifdef _WIN32 + return ".obj"; +#else + return ".o"; +#endif +} + +const char* extension_file_ext() { +#ifdef _WIN32 + return ".pyd"; +#else + return ".so"; +#endif +} } // namespace namespace torch::inductor { @@ -515,9 +531,9 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( std::string filename_extension = output_path_str.substr(extension_idx); if (filename_extension == ".cpp") { cpp_filename = output_path_str; - } else if (filename_extension == ".o") { + } else if (filename_extension == object_file_ext()) { obj_filenames.push_back(output_path_str); - } else if (filename_extension == ".so") { + } else if (filename_extension == extension_file_ext()) { so_filename = output_path_str; } } From ec0b5389619eec7d62ae8321407ce436b2593673 Mon Sep 17 00:00:00 2001 From: Apostolos Kokolis Date: Fri, 18 Jul 2025 20:07:55 +0000 Subject: [PATCH 263/457] [inductor] Make times and repeat parameters command line args (#158590) Summary: Small change to make the `times` and `repeat` variables controllable as command line args. Test Plan: Execute: ``` buck2 run :inductor_benchmark -- --times=1 --repeat=1 ``` Only runs once, and without passing the args it runs with default values of 10. Rollback Plan: Reviewed By: malfet Differential Revision: D78458680 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158590 Approved by: https://github.com/FindHao, https://github.com/malfet --- torch/_inductor/wrapper_benchmark.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/torch/_inductor/wrapper_benchmark.py b/torch/_inductor/wrapper_benchmark.py index 815c3a9d1a37b..9a527471c8cc0 100644 --- a/torch/_inductor/wrapper_benchmark.py +++ b/torch/_inductor/wrapper_benchmark.py @@ -468,13 +468,26 @@ def compiled_module_main( "If None, NCU will use '--set full'." ), ) + parser.add_argument( + "--times", + type=int, + default=10, + help="Number of times to run each benchmark iteration", + ) + parser.add_argument( + "--repeat", + type=int, + default=10, + help="Number of repetitions of each benchmark run", + ) + args = parser.parse_args() if args.benchmark_kernels: benchmark_all_kernels(benchmark_name, args.benchmark_all_configs) else: - times = 10 - repeat = 10 + times = args.times + repeat = args.repeat if torch.cuda.is_available(): torch.cuda.reset_peak_memory_stats() From 8b2a6505728b5a12d170175d65f17a00aec50632 Mon Sep 17 00:00:00 2001 From: Colin L Reliability Rice Date: Fri, 18 Jul 2025 20:28:22 +0000 Subject: [PATCH 264/457] pt2_remote_cache: Log sample for failures, and log the explicit reason we're faling. (#156874) Summary: This allows us to start alerting on cache failures, based on scuba data Test Plan: Added new tests explicitly for the Remote Cache API. Note that we have existing tests for memcache, but not for manifold AFAICT. There are two potential wrinkles. One we're adding a new field (and everything uses ScubaData AFAICT, so this should just work). The other one is the implicit api contract that if the sample is None, then it will be ignored (and not crash). I believe the second one is implemented correctly (and tested). The first one is a little more nebulous, but I think won't cause any breakages. Also manually ran a compile and made sure it didn't break - P1851504490 as well as forcing it to break and checking we didn't screw up the exception handling - P1851504243 Rollback Plan: Differential Revision: D77054339 Pull Request resolved: https://github.com/pytorch/pytorch/pull/156874 Approved by: https://github.com/oulgen, https://github.com/masnesral --- test/inductor/test_remote_cache.py | 76 ++++++++++++++++++++++++++++++ torch/_inductor/remote_cache.py | 14 ++++-- 2 files changed, 86 insertions(+), 4 deletions(-) create mode 100644 test/inductor/test_remote_cache.py diff --git a/test/inductor/test_remote_cache.py b/test/inductor/test_remote_cache.py new file mode 100644 index 0000000000000..591713403bb88 --- /dev/null +++ b/test/inductor/test_remote_cache.py @@ -0,0 +1,76 @@ +# Owner(s): ["module: inductor"] +from dataclasses import dataclass + +from torch._inductor.remote_cache import ( + RemoteCache, + RemoteCacheBackend, + RemoteCachePassthroughSerde, +) +from torch.testing._internal.common_utils import TestCase + + +class FailingBackend(RemoteCacheBackend): + def _get(self, key): + raise AssertionError("testget") + + def _put(self, key, data): + raise AssertionError("testput") + + +class NoopBackend(RemoteCacheBackend): + def _get(self, key): + return None + + def _put(self, key, data): + return None + + +@dataclass +class TestSample: + fail: str = None + + +class FakeCache(RemoteCache): + def __init__(self): + super().__init__(FailingBackend(), RemoteCachePassthroughSerde()) + + def _create_sample(self): + return TestSample() + + def _log_sample(self, sample): + self.sample = sample + + +class TestRemoteCache(TestCase): + def test_normal_logging( + self, + ) -> None: + c = RemoteCache(NoopBackend(), RemoteCachePassthroughSerde()) + c.put("test", "value") + c.get("test") + + def test_failure_no_sample( + self, + ) -> None: + c = RemoteCache(FailingBackend(), RemoteCachePassthroughSerde()) + with self.assertRaises(AssertionError): + c.put("test", "value") + with self.assertRaises(AssertionError): + c.get("test") + + def test_failure_logging( + self, + ) -> None: + c = FakeCache() + with self.assertRaises(AssertionError): + c.put("test", "value") + self.assertEqual(c.sample.fail_reason, "testput") + with self.assertRaises(AssertionError): + c.get("test") + self.assertEqual(c.sample.fail_reason, "testget") + + +if __name__ == "__main__": + from torch._inductor.test_case import run_tests + + run_tests() diff --git a/torch/_inductor/remote_cache.py b/torch/_inductor/remote_cache.py index aaa266b60e00b..1304ce79b86ed 100644 --- a/torch/_inductor/remote_cache.py +++ b/torch/_inductor/remote_cache.py @@ -170,10 +170,13 @@ def get(self, key: str) -> Optional[_T]: try: result = self._get(key, sample) cache_stats.get(type(self).__name__, result) - except Exception: + except Exception as e: cache_stats.exception(type(self).__name__) + if sample: + sample.fail_reason = str(e) raise - self._log_sample(sample) + finally: + self._log_sample(sample) return result # Add `value` to the cache with the key `key`. Note that `None` is not a @@ -186,10 +189,13 @@ def put(self, key: str, value: _T) -> None: try: self._put(key, value, sample) cache_stats.put(type(self).__name__) - except Exception: + except Exception as e: cache_stats.exception(type(self).__name__) + if sample: + sample.fail_reason = str(e) raise - self._log_sample(sample) + finally: + self._log_sample(sample) # Used to convert data from the cache into structured data. def _decode(self, data: _U, sample: Optional[Sample]) -> _T: # type: ignore[override] From 1ab1ab38a04e8ee852ff27eb8ae4989662511965 Mon Sep 17 00:00:00 2001 From: Huy Do Date: Fri, 18 Jul 2025 09:02:38 -0700 Subject: [PATCH 265/457] Use linux.12xlarge.memory to build for H100/sm_90 (#158598) Use a bigger runner here because CUDA_ARCH 9.0 is only built for H100 or newer GPUs, so it doesn't benefit much from existing compiler cache from trunk. Also use a memory-intensive runner here because memory is usually the bottleneck Signed-off-by: Huy Do Pull Request resolved: https://github.com/pytorch/pytorch/pull/158598 Approved by: https://github.com/ZainRizvi, https://github.com/malfet --- .github/workflows/inductor-perf-test-nightly-h100.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/workflows/inductor-perf-test-nightly-h100.yml b/.github/workflows/inductor-perf-test-nightly-h100.yml index c94996f58002b..4807f4a29b08a 100644 --- a/.github/workflows/inductor-perf-test-nightly-h100.yml +++ b/.github/workflows/inductor-perf-test-nightly-h100.yml @@ -86,6 +86,11 @@ jobs: needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + # Use a bigger runner here because CUDA_ARCH 9.0 is only built for H100 + # or newer GPUs, so it doesn't benefit much from existing compiler cache + # from trunk. Also use a memory-intensive runner here because memory is + # usually the bottleneck + runner: linux.12xlarge.memory build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm90 docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '9.0' From e3351b3ddff06c90b2786b23312f80fda2ddb4a6 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 18 Jul 2025 20:54:19 +0000 Subject: [PATCH 266/457] Revert "[DCP][HF] [ez]Change where sharded tensors are saved (#158069)" This reverts commit 627ba411366bcc15019c49756d3f22fd3914bd50. Reverted https://github.com/pytorch/pytorch/pull/158069 on behalf of https://github.com/jithunnair-amd due to Didn't remove reference to `consolidated_output_path` in test_hf_safetensor_e2e.py; CUDA runs do not surface issue because safetensors is not installed and the test silently passes ([comment](https://github.com/pytorch/pytorch/pull/158069#issuecomment-3090692336)) --- torch/distributed/checkpoint/_hf_utils.py | 2 -- torch/distributed/checkpoint/hf_storage.py | 15 ++++++++------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/torch/distributed/checkpoint/_hf_utils.py b/torch/distributed/checkpoint/_hf_utils.py index 1a3f627fd69b5..84d4affe6c569 100644 --- a/torch/distributed/checkpoint/_hf_utils.py +++ b/torch/distributed/checkpoint/_hf_utils.py @@ -43,8 +43,6 @@ NUM_BYTES_FOR_HEADER_LEN = 8 -SHARDED_DIR_NAME = "sharded" - @dataclass class _HFStorageInfo: diff --git a/torch/distributed/checkpoint/hf_storage.py b/torch/distributed/checkpoint/hf_storage.py index 0da3fc089f885..4e97a3e02e328 100644 --- a/torch/distributed/checkpoint/hf_storage.py +++ b/torch/distributed/checkpoint/hf_storage.py @@ -23,7 +23,6 @@ DTYPE_KEY, SAVED_OFFSETS_KEY, SHAPE_KEY, - SHARDED_DIR_NAME, SUFFIX, ) from torch.distributed.checkpoint.filesystem import SerializationFormat @@ -67,6 +66,7 @@ def __init__( token: Optional[str] = None, save_distributed: bool = False, enable_consolidation: bool = False, + consolidated_output_path: Optional[str] = None, thread_count_consolidation: int = 1, ) -> None: """ @@ -85,8 +85,10 @@ def __init__( token: The token to use to authenticate with huggingface hub. save_distributed: If True, save the checkpoint using distributed APIs where every rank saves its own shard. Default is False which assumes rank-0 checkpointing of the full state_dict. - enable_consolidation: If True, consolidate the sharded checkpoint after saving. The sharded tensors will be - saved to path/sharded and the full tensors will be saved to path. Default to False. + enable_consolidation: If True, consolidate the sharded checkpoint after saving. Default to False. + consolidated_output_path: If provided, the output path where the consolidated files will be written in the finish step. + If enable_consolidation is True and this is not provided the consolidated files + will be written to `path`. thread_count_consolidation: Number of threads to use for parallel processing of saving data to consolidated output files. Default to 1. """ @@ -107,10 +109,9 @@ def __init__( self.fqn_to_index_mapping: Optional[dict[str, int]] = fqn_to_index_mapping self.save_distributed: bool = save_distributed self.enable_consolidation: bool = enable_consolidation - self.consolidated_output_path: Optional[str] = None - if self.enable_consolidation: - self.consolidated_output_path = str(self.path) - self.path = self.fs.concat_path(self.path, SHARDED_DIR_NAME) + self.consolidated_output_path: str = ( + consolidated_output_path if consolidated_output_path is not None else path + ) self.thread_count_consolidation = thread_count_consolidation def prepare_global_plan(self, plans: list[SavePlan]) -> list[SavePlan]: From 3bb729df97ed632e4629b706eb18a30dffebc310 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 18 Jul 2025 21:00:12 +0000 Subject: [PATCH 267/457] Revert "Fix test consolidate hf safetensors (#157386)" This reverts commit fa1c20ae9285f7994a73d2d06025065f96b67a57. Reverted https://github.com/pytorch/pytorch/pull/157386 on behalf of https://github.com/jithunnair-amd due to Need to revert this so we can revert PR 156705, which introduced errors on ROCm CI. These errors were not seen on CUDA CI because CUDA CI docker images do not have safetensors installed and the test silently passes ([comment](https://github.com/pytorch/pytorch/pull/157386#issuecomment-3090706074)) --- test/distributed/checkpoint/test_consolidate_hf_safetensors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/distributed/checkpoint/test_consolidate_hf_safetensors.py b/test/distributed/checkpoint/test_consolidate_hf_safetensors.py index ba07c62728d71..c1686142fd8e8 100644 --- a/test/distributed/checkpoint/test_consolidate_hf_safetensors.py +++ b/test/distributed/checkpoint/test_consolidate_hf_safetensors.py @@ -62,7 +62,7 @@ def _create_d_tensors(self) -> None: dist_cp.save( state_dict=state_dict_to_save, storage_writer=dist_cp.HuggingFaceStorageWriter( - path=self.temp_dir, save_distributed=True + path=self.temp_dir, save_sharded=True ), ) dist.barrier() From 89850bbc073c4e27ca51b0b205742e1d316e7097 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Fri, 18 Jul 2025 21:51:37 +0000 Subject: [PATCH 268/457] [Dynamo] Use proper sources for constructing dataclass defaults (#157993) Partially fixes https://github.com/pytorch/pytorch/issues/154009 Pull Request resolved: https://github.com/pytorch/pytorch/pull/157993 Approved by: https://github.com/williamwen42, https://github.com/anijain2305 --- test/dynamo/test_misc.py | 20 ++++++++++++++++++++ torch/_dynamo/guards.py | 11 +++++++++++ torch/_dynamo/source.py | 16 ++++++++++++++++ torch/_dynamo/utils.py | 4 ++++ torch/_dynamo/variables/user_defined.py | 13 +++++++++++-- 5 files changed, 62 insertions(+), 2 deletions(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index b8d759c66e302..6d761305d0a8e 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -10422,6 +10422,26 @@ def fn(x, y): actual = fn_opt(*inps) expected = fn(*inps) + def test_nested_dataclass_reconstruct(self): + @dataclasses.dataclass(frozen=True) + class NestedDataClass: + x: int = 2 + + @dataclasses.dataclass(frozen=True) + class TestDataClass: + y: torch.Tensor + ndc: NestedDataClass = NestedDataClass() + + def fn(y): + dc = TestDataClass(y) + z = dc.y + dc.ndc.x + return z, dc + + fn_opt = torch.compile()(fn) + inps = (torch.ones(2, 2),) + actual = fn_opt(*inps) + expected = fn(*inps) + def test_frozen_dataclass_default_value(self): @dataclasses.dataclass(frozen=True) class TestDataClass: diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 8f55a666873ff..d7fe7cc300455 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -104,6 +104,7 @@ ChainedSource, ConstantSource, ConstDictKeySource, + DataclassFieldsSource, DefaultsSource, DictGetItemSource, DictSubclassGetItemSource, @@ -146,6 +147,7 @@ from .utils import ( builtin_dict_keys, common_constant_types, + dataclass_fields, dict_keys, get_custom_getattr, get_torch_function_mode_stack, @@ -451,6 +453,7 @@ def _get_closure_vars(): "___tuple_iterator_len": tuple_iterator_len, "___normalize_range_iter": normalize_range_iter, "___tuple_iterator_getitem": tuple_iterator_getitem, + "___dataclass_fields": dataclass_fields, "___get_torch_function_mode_stack_at": get_torch_function_mode_stack_at, "__math_isnan": math.isnan, "__numpy_isnan": None if np is None else np.isnan, @@ -1328,6 +1331,14 @@ def get_guard_manager_from_source(self, source): example_value=example_value, guard_manager_enum=guard_manager_enum, ) + elif istype(source, DataclassFieldsSource): + assert base_guard_manager + out = base_guard_manager.lambda_manager( + python_lambda=lambda x: dataclass_fields(x), + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) else: raise AssertionError( f"missing guard manager builder {source} - {source.name()}" diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py index e2ee525ed644b..7d700b2539c9e 100644 --- a/torch/_dynamo/source.py +++ b/torch/_dynamo/source.py @@ -751,6 +751,22 @@ def name(self): return f"___tuple_iterator_getitem({self.base.name()}, {self.index!r})" +@dataclasses.dataclass(frozen=True) +class DataclassFieldsSource(ChainedSource): + def reconstruct(self, codegen: "PyCodegen"): + codegen.add_push_null( + lambda: codegen.load_import_from(utils.__name__, "dataclass_fields") + ) + codegen(self.base) + codegen.extend_output(create_call_function(1, False)) + + def guard_source(self): + return self.base.guard_source() + + def name(self): + return f"___dataclass_fields({self.base.name()})" + + @dataclasses.dataclass(frozen=True) class TypeSource(ChainedSource): def __post_init__(self): diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 0bd6b7f5e4a0d..725d46c06ae9b 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -2572,6 +2572,10 @@ def tuple_iterator_getitem(it, index): return obj[start + index] +def dataclass_fields(cls): + return torch._dynamo.disable(dataclasses.fields)(cls) + + iter_next = next diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 6c7e24ef16f01..c08f8099664f3 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -31,6 +31,7 @@ import enum import functools import inspect +import itertools import random import sys import threading @@ -58,6 +59,7 @@ from ..source import ( AttrSource, CallFunctionNoArgsSource, + DataclassFieldsSource, GetItemSource, RandomValueSource, TypeSource, @@ -624,11 +626,12 @@ def call_function( return SizeVariable(tup.items) elif is_frozen_dataclass(self.value) and self.is_standard_new(): fields = dataclasses.fields(self.value) + fields_source = DataclassFieldsSource(self.source) items = list(args) items.extend([None] * (len(fields) - len(items))) default_kwargs = {} - for field, var_tracker in zip(fields, items): + for ind, field, var_tracker in zip(itertools.count(), fields, items): if var_tracker is None: if field.name in kwargs: var_tracker = kwargs[field.name] @@ -637,7 +640,13 @@ def call_function( continue if field.default is not dataclasses.MISSING: - var_tracker = VariableTracker.build(tx, field.default) + var_tracker = VariableTracker.build( + tx, + field.default, + source=AttrSource( + GetItemSource(fields_source, ind), "default" + ), + ) elif field.default_factory is not dataclasses.MISSING: factory_fn = VariableTracker.build( tx, field.default_factory From 07c4c2a792dc4503b32fa2679d436e4aa77352de Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Thu, 17 Jul 2025 07:24:50 -0700 Subject: [PATCH 269/457] [dynamo][be] hide warnings without invalidating warnings cache (#158520) I feel uneasy about touching `__warningregistry__` since it is undocumented and private surface. The only public API hook that doesn't increment warnings version seems to be https://docs.python.org/3/library/warnings.html#warnings.showwarning. So we could wack a mole all the warnings muters in compile to just not display warnings, and we wouldn't invalidate warnings cache. This PR adds it for torch/_dynamo, and I didn't find any warnings versioning mutation from torch/_inductor. There is a behavior change if someone calls a compiled graph with simplefilter("error"): ```python # e.g. test/dynamo_expected_failures/TestAutogradFallback.test_no_autograd_kernel_inplace_mode_nothing with warnings.catch_warnings(): warnings.simplefilter("error") # turns all warnings into errors compiled_fn() # will throw if any of the muted warnings fire ``` FIXES https://github.com/pytorch/pytorch/issues/128427 A note for the future: The warnings module doesn't offer a thread safe way of using it. Even regular filters have this problem, directly editing `__warningregistry__` would be very bad, and this PR would mute all threads. Someone will need to build a thread safe warnings interface. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158520 Approved by: https://github.com/anijain2305, https://github.com/zou3519 --- test/dynamo/test_repros.py | 56 +++++++++++++++++++ ...t_no_autograd_kernel_inplace_mode_nothing} | 0 ...test_no_autograd_kernel_inplace_mode_warn} | 0 ...st_slogdet_errors_and_warnings_cpu_float32 | 0 ...st_slogdet_errors_and_warnings_cpu_float64 | 0 torch/_dynamo/eval_frame.py | 8 +-- torch/_dynamo/exc.py | 5 -- torch/_dynamo/variables/builder.py | 4 +- torch/_logging/__init__.py | 1 + torch/_logging/_internal.py | 41 ++++++++++++++ torch/_subclasses/meta_utils.py | 4 +- 11 files changed, 104 insertions(+), 15 deletions(-) rename test/dynamo_expected_failures/{TestLinalgCPU.test_slogdet_errors_and_warnings_cpu_complex128 => TestAutogradFallback.test_no_autograd_kernel_inplace_mode_nothing} (100%) rename test/dynamo_expected_failures/{TestLinalgCPU.test_slogdet_errors_and_warnings_cpu_complex64 => TestAutogradFallback.test_no_autograd_kernel_inplace_mode_warn} (100%) delete mode 100644 test/dynamo_expected_failures/TestLinalgCPU.test_slogdet_errors_and_warnings_cpu_float32 delete mode 100644 test/dynamo_expected_failures/TestLinalgCPU.test_slogdet_errors_and_warnings_cpu_float64 diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index bdb297789dc6b..cc702ad542cee 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -7593,6 +7593,62 @@ def forward(self, x): out2 = torch.compile(model, backend="eager")(input.clone()) self.assertEqual(out1, out2) + def test_filter_warnings(self): + x = torch.ones(2, 2, requires_grad=True) + + def call_foobar(x): + warnings.warn("foobar") + + @torch.compile(backend="eager") + def f(x): + call_foobar(x) + call_foobar(x) + call_foobar(x) + call_foobar(x) + return call_foobar(x) + + with warnings.catch_warnings(record=True) as w: + f(x) + self.assertEqual(len(w), 1) + self.assertEqual(str(w[0].message), "foobar") + + def test_filter_safe_grad_warning(self): + x = torch.ones(2, 2, requires_grad=True) + y = x * 5 # non-leaf, .grad should warn + torch._subclasses.meta_utils.safe_grad(y) # filters out warning + + def unsafe_grad(y): + return y.grad + + with warnings.catch_warnings(record=True) as w: + unsafe_grad(y) # should still warn, different callsite + self.assertEqual(len(w), 1) + self.assertTrue("The .grad attribute of a Tensor" in str(w[0].message)) + + unsafe_grad(y) # should not warn + self.assertEqual(len(w), 1) + + def test_filter_user_warnings(self): + x = torch.ones(2, 2, requires_grad=True) + y = x * 5 # non-leaf, .grad should warn + + @torch._dynamo.eval_frame.TorchPatcher.suppress_torch_distributed_warnings + def mute_warn(y): + return y.grad + + mute_warn(y) # filters out warning + + def unsafe_grad(y): + return y.grad + + with warnings.catch_warnings(record=True) as w: + unsafe_grad(y) # should still warn, different callsite + self.assertEqual(len(w), 1) + self.assertTrue("The .grad attribute of a Tensor" in str(w[0].message)) + + unsafe_grad(y) # should not warn + self.assertEqual(len(w), 1) + instantiate_parametrized_tests(ReproTests) diff --git a/test/dynamo_expected_failures/TestLinalgCPU.test_slogdet_errors_and_warnings_cpu_complex128 b/test/dynamo_expected_failures/TestAutogradFallback.test_no_autograd_kernel_inplace_mode_nothing similarity index 100% rename from test/dynamo_expected_failures/TestLinalgCPU.test_slogdet_errors_and_warnings_cpu_complex128 rename to test/dynamo_expected_failures/TestAutogradFallback.test_no_autograd_kernel_inplace_mode_nothing diff --git a/test/dynamo_expected_failures/TestLinalgCPU.test_slogdet_errors_and_warnings_cpu_complex64 b/test/dynamo_expected_failures/TestAutogradFallback.test_no_autograd_kernel_inplace_mode_warn similarity index 100% rename from test/dynamo_expected_failures/TestLinalgCPU.test_slogdet_errors_and_warnings_cpu_complex64 rename to test/dynamo_expected_failures/TestAutogradFallback.test_no_autograd_kernel_inplace_mode_warn diff --git a/test/dynamo_expected_failures/TestLinalgCPU.test_slogdet_errors_and_warnings_cpu_float32 b/test/dynamo_expected_failures/TestLinalgCPU.test_slogdet_errors_and_warnings_cpu_float32 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestLinalgCPU.test_slogdet_errors_and_warnings_cpu_float64 b/test/dynamo_expected_failures/TestLinalgCPU.test_slogdet_errors_and_warnings_cpu_float64 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 2eaafdc436550..f47ca4185bed0 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -2273,10 +2273,10 @@ def suppress_torch_distributed_warnings( fn: Callable[..., Any], ) -> Callable[..., Any]: def inner_fn(*args: Any, **kwargs: Any) -> Any: - warnings.filterwarnings( - "ignore", category=UserWarning, module="torch.distributed" - ) - return fn(*args, **kwargs) + with torch._logging.hide_warnings( + torch._logging._internal.user_warning_filter + ): + return fn(*args, **kwargs) return inner_fn diff --git a/torch/_dynamo/exc.py b/torch/_dynamo/exc.py index 8a0c6bfc2b4be..76bf400245c5d 100644 --- a/torch/_dynamo/exc.py +++ b/torch/_dynamo/exc.py @@ -532,11 +532,6 @@ def unimplemented_v2( raise Unsupported(msg) -def warning(msg: str) -> None: - counters["warnings"][msg] += 1 - assert msg != os.environ.get("BREAK", False) - - # KeyError has special handling for its args # see https://github.com/python/cpython/blob/3.11/Objects/exceptions.c#L2534 for details class KeyErrorMsg: diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 52f2bef5677a4..c6b061f42b1bf 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -36,7 +36,6 @@ import sys import traceback import types -import warnings import weakref from collections.abc import MutableMapping from typing import Any, Callable, NamedTuple, Optional, TYPE_CHECKING, Union @@ -304,8 +303,7 @@ def safe_has_grad(t): - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", "The .grad attribute of a Tensor") + with torch._logging.hide_warnings(torch._logging._internal.safe_grad_filter): return hasattr(t, "grad") diff --git a/torch/_logging/__init__.py b/torch/_logging/__init__.py index 6e28319cddc18..d0fdebb23bde9 100644 --- a/torch/_logging/__init__.py +++ b/torch/_logging/__init__.py @@ -12,6 +12,7 @@ dtrace_structured, get_structured_logging_overhead, getArtifactLogger, + hide_warnings, LazyString, set_logs, trace_structured, diff --git a/torch/_logging/_internal.py b/torch/_logging/_internal.py index 11185e334f5dc..ffd3160b47ee8 100644 --- a/torch/_logging/_internal.py +++ b/torch/_logging/_internal.py @@ -1,4 +1,5 @@ # mypy: allow-untyped-defs +import contextlib import functools import hashlib import importlib.util @@ -12,6 +13,7 @@ import sys import tempfile import time +import warnings from collections import defaultdict from dataclasses import dataclass, field from typing import Any, Callable, Generic, Optional, Union @@ -1156,6 +1158,45 @@ def warning_once(logger_obj, *args, **kwargs) -> None: logger_obj.warning(*args, **kwargs) +def safe_grad_filter(message, category, filename, lineno, file=None, line=None) -> bool: + return "The .grad attribute of a Tensor" not in str(message) + + +def user_warning_filter( + message, category, filename, lineno, file=None, line=None +) -> bool: + return not category == UserWarning + + +@contextlib.contextmanager +def hide_warnings(filter_fn=lambda *args, **kwargs: True): + """ + A context manager that temporarily suppresses warnings, + using public API: https://docs.python.org/3/library/warnings.html#warnings.showwarning. + + Useful to hide warnings without mutating warnings module state, see: + https://github.com/pytorch/pytorch/issues/128427#issuecomment-2161496162. + + NOTE: Warnings issued under this context will still be cached in the __warningregistry__ + and count towards the once/default rule. So you should NEVER use this on a user-land function. + + Filter must implement the showwarning API: + def filter_fn(message, category, filename, lineno, file=None, line=None) -> bool: + return True # show this warning entry + """ + prior = warnings.showwarning + + def _showwarning(*args, **kwargs): + if filter_fn(*args, **kwargs): + prior(*args, **kwargs) + + try: + warnings.showwarning = _showwarning + yield + finally: + warnings.showwarning = prior + + class LazyString(Generic[_P]): def __init__( self, func: Callable[_P, str], *args: _P.args, **kwargs: _P.kwargs diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py index 5d24eb42090d1..03a3fd91831b4 100644 --- a/torch/_subclasses/meta_utils.py +++ b/torch/_subclasses/meta_utils.py @@ -5,7 +5,6 @@ import functools import threading import typing -import warnings import weakref from abc import abstractmethod from contextlib import AbstractContextManager, contextmanager @@ -81,8 +80,7 @@ def safe_is_leaf(t: Union[MetaTensorDesc, torch.Tensor]) -> bool: def safe_grad(t: _TensorLikeT) -> Optional[_TensorLikeT]: - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", "The .grad attribute of a Tensor") + with torch._logging.hide_warnings(torch._logging._internal.safe_grad_filter): return t.grad From bc7b1f5252a667e72ce3c6c13e18af46dd0a7d99 Mon Sep 17 00:00:00 2001 From: Huamin Li Date: Fri, 18 Jul 2025 22:27:10 +0000 Subject: [PATCH 270/457] [AOTI] Use libstdc++ only for fbcode cpu case (#158659) Differential Revision: D78567218 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158659 Approved by: https://github.com/kflu, https://github.com/zoranzhao --- torch/_inductor/cpp_builder.py | 33 +++++++++++++++++++-------------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/torch/_inductor/cpp_builder.py b/torch/_inductor/cpp_builder.py index b8cdc50368e7b..47820d3d77250 100644 --- a/torch/_inductor/cpp_builder.py +++ b/torch/_inductor/cpp_builder.py @@ -1074,9 +1074,9 @@ def _get_openmp_args( return cflags, ldflags, include_dir_paths, lib_dir_paths, libs, passthrough_args -def _get_libstdcxx_args(cpp_compiler: str) -> tuple[list[str], list[str]]: +def _get_libstdcxx_args() -> tuple[list[str], list[str]]: """ - For fbcode, we should link stdc++ instead assuming the binary where dlopen is executed is built with dynamic stdc++. + For fbcode cpu case, we should link stdc++ instead assuming the binary where dlopen is executed is built with dynamic stdc++. """ lib_dir_paths: list[str] = [] libs: list[str] = [] @@ -1147,11 +1147,6 @@ def get_cpp_torch_options( omp_passthrough_args, ) = _get_openmp_args(cpp_compiler) - ( - stdcxx_lib_dir_paths, - stdcxx_libs, - ) = _get_libstdcxx_args(cpp_compiler) - fb_macro_passthrough_args = _use_fb_internal_macros() mmap_self_macros = get_mmap_self_macro(use_mmap_weights) @@ -1171,13 +1166,8 @@ def get_cpp_torch_options( ) cflags = sys_libs_cflags + omp_cflags ldflags = omp_ldflags - libraries_dirs = ( - python_libraries_dirs - + torch_libraries_dirs - + omp_lib_dir_paths - + stdcxx_lib_dir_paths - ) - libraries = torch_libraries + omp_lib + stdcxx_libs + libraries_dirs = python_libraries_dirs + torch_libraries_dirs + omp_lib_dir_paths + libraries = torch_libraries + omp_lib passthrough_args = ( sys_libs_passthrough_args + isa_ps_args_build_flags + omp_passthrough_args ) @@ -1301,6 +1291,13 @@ def get_cpp_torch_device_options( aot_mode: bool = False, compile_only: bool = False, ) -> tuple[list[str], list[str], list[str], list[str], list[str], list[str], list[str]]: + """ + This function is used to get the build args of device related build options. + 1. Device include_directories, libraries, libraries_directories. + 2. Device MACROs. + 3. MISC + 4. Return the build args + """ definitions: list[str] = [] include_dirs: list[str] = [] cflags: list[str] = [] @@ -1361,6 +1358,14 @@ def get_cpp_torch_device_options( # Only add link args, when compile_only is false. passthrough_args = ["-Wl,-Bstatic -lcudart_static -Wl,-Bdynamic"] + if device_type == "cpu": + ( + stdcxx_lib_dir_paths, + stdcxx_libs, + ) = _get_libstdcxx_args() + libraries_dirs += stdcxx_lib_dir_paths + libraries += stdcxx_libs + if config.aot_inductor.custom_op_libs: libraries += config.aot_inductor.custom_op_libs From be483a54817fbfbf184af363bf9469d01bfa15ef Mon Sep 17 00:00:00 2001 From: Yang Wang Date: Fri, 18 Jul 2025 22:30:17 +0000 Subject: [PATCH 271/457] setup pinned commit for vllm in pytorch ci (#158591) Set up pinned commit for vllm in nightly Pull Request resolved: https://github.com/pytorch/pytorch/pull/158591 Approved by: https://github.com/seemethere, https://github.com/huydhn --- .github/merge_rules.yaml | 1 + .github/workflows/nightly.yml | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/.github/merge_rules.yaml b/.github/merge_rules.yaml index 5786c2aa1652c..00b7cb618401a 100644 --- a/.github/merge_rules.yaml +++ b/.github/merge_rules.yaml @@ -77,6 +77,7 @@ - .github/ci_commit_pins/vision.txt - .github/ci_commit_pins/torchdynamo.txt - .ci/docker/ci_commit_pins/triton.txt + - .ci/docker/ci_commit_pins/vllm.txt approved_by: - pytorchbot mandatory_checks_name: diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index 70fea3c8cc1c7..238b897d3da63 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -83,6 +83,10 @@ jobs: repo-owner: triton-lang branch: main pin-folder: .ci/docker/ci_commit_pins + - repo-name: vllm + repo-owner: vllm-project + branch: main + pin-folder: .ci/docker/ci_commit_pins # Allow this to be triggered on either a schedule or on workflow_dispatch to allow for easier testing if: github.repository_owner == 'pytorch' && (github.event_name == 'schedule' || github.event_name == 'workflow_dispatch') steps: From f76f4abf3f10bd36a47e7cebdce90290ce76e564 Mon Sep 17 00:00:00 2001 From: Yang Wang Date: Fri, 18 Jul 2025 22:54:10 +0000 Subject: [PATCH 272/457] Track monitor (#156907) Tracking gpu mem allocation, we were tracking the gpu bandwidth memory, the mem allocation is the one reflect wether the gpu is oom or not, upcoming ui fix. UI fix: https://github.com/pytorch/test-infra/pull/6878/files Pull Request resolved: https://github.com/pytorch/pytorch/pull/156907 Approved by: https://github.com/huydhn --- .github/actions/linux-test/action.yml | 2 +- .../requirements/pip-requirements-macOS.txt | 2 +- .github/workflows/_linux-build.yml | 2 +- .github/workflows/_linux-test.yml | 2 +- .github/workflows/_mac-test.yml | 2 +- .github/workflows/_rocm-test.yml | 2 +- .github/workflows/_win-test.yml | 2 +- .github/workflows/_xpu-test.yml | 2 +- tools/stats/monitor.py | 42 ++++++++++++++++++- tools/stats/utilization_stats_lib.py | 6 ++- 10 files changed, 53 insertions(+), 11 deletions(-) diff --git a/.github/actions/linux-test/action.yml b/.github/actions/linux-test/action.yml index fb46709d9b0db..32fe1d7385b18 100644 --- a/.github/actions/linux-test/action.yml +++ b/.github/actions/linux-test/action.yml @@ -126,7 +126,7 @@ runs: shell: bash continue-on-error: true run: | - python3 -m pip install psutil==5.9.1 nvidia-ml-py==11.525.84 + python3 -m pip install psutil==5.9.8 nvidia-ml-py==11.525.84 python3 -m tools.stats.monitor > usage_log.txt 2>&1 & echo "monitor-script-pid=${!}" >> "${GITHUB_OUTPUT}" diff --git a/.github/requirements/pip-requirements-macOS.txt b/.github/requirements/pip-requirements-macOS.txt index 7eaa962995b79..7929ecfe1e4bb 100644 --- a/.github/requirements/pip-requirements-macOS.txt +++ b/.github/requirements/pip-requirements-macOS.txt @@ -16,7 +16,7 @@ packaging==25.0 parameterized==0.8.1 pillow==10.3.0 protobuf==5.29.4 -psutil==5.9.1 +psutil==5.9.8 pygments==2.15.0 pytest-cpp==2.3.0 pytest-flakefinder==1.1.0 diff --git a/.github/workflows/_linux-build.yml b/.github/workflows/_linux-build.yml index bce807018272b..1f1146fcde1be 100644 --- a/.github/workflows/_linux-build.yml +++ b/.github/workflows/_linux-build.yml @@ -225,7 +225,7 @@ jobs: MONITOR_DATA_COLLECT_INTERVAL: ${{ inputs.monitor-data-collect-interval }} run: | mkdir -p ../../usage_logs - python3 -m pip install psutil==5.9.1 dataclasses_json==0.6.7 + python3 -m pip install psutil==5.9.8 dataclasses_json==0.6.7 python3 -m tools.stats.monitor \ --log-interval "$MONITOR_LOG_INTERVAL" \ --data-collect-interval "$MONITOR_DATA_COLLECT_INTERVAL" \ diff --git a/.github/workflows/_linux-test.yml b/.github/workflows/_linux-test.yml index d19a7b51938ef..1848586d3cefd 100644 --- a/.github/workflows/_linux-test.yml +++ b/.github/workflows/_linux-test.yml @@ -205,7 +205,7 @@ jobs: MONITOR_LOG_INTERVAL: ${{ inputs.monitor-log-interval }} MONITOR_DATA_COLLECT_INTERVAL: ${{ inputs.monitor-data-collect-interval }} run: | - python3 -m pip install psutil==5.9.1 dataclasses_json==0.6.7 nvidia-ml-py==11.525.84 + python3 -m pip install psutil==5.9.8 dataclasses_json==0.6.7 nvidia-ml-py==11.525.84 python3 -m tools.stats.monitor --log-interval "$MONITOR_LOG_INTERVAL" --data-collect-interval "$MONITOR_DATA_COLLECT_INTERVAL" > usage_log.txt 2>&1 & echo "monitor-script-pid=${!}" >> "${GITHUB_OUTPUT}" diff --git a/.github/workflows/_mac-test.yml b/.github/workflows/_mac-test.yml index 550053de73256..063c97e449c75 100644 --- a/.github/workflows/_mac-test.yml +++ b/.github/workflows/_mac-test.yml @@ -136,7 +136,7 @@ jobs: MONITOR_LOG_INTERVAL: ${{ inputs.monitor-log-interval }} MONITOR_DATA_COLLECT_INTERVAL: ${{ inputs.monitor-data-collect-interval }} run: | - "$VENV_PATH/bin/python3" -m pip install psutil==5.9.1 dataclasses_json==0.6.7 + "$VENV_PATH/bin/python3" -m pip install psutil==5.9.8 dataclasses_sajson==0.6.7 "$VENV_PATH/bin/python3" -m tools.stats.monitor --log-interval "$MONITOR_LOG_INTERVAL" --data-collect-interval "$MONITOR_DATA_COLLECT_INTERVAL" > usage_log.txt 2>&1 & echo "monitor-script-pid=${!}" >> "${GITHUB_OUTPUT}" diff --git a/.github/workflows/_rocm-test.yml b/.github/workflows/_rocm-test.yml index 006ab43da29d6..dd3790c41a9e9 100644 --- a/.github/workflows/_rocm-test.yml +++ b/.github/workflows/_rocm-test.yml @@ -132,7 +132,7 @@ jobs: shell: bash continue-on-error: true run: | - python3 -m pip install psutil==5.9.1 dataclasses_json==0.6.7 + python3 -m pip install psutil==5.9.8 dataclasses_json==0.6.7 python3 -m tools.stats.monitor --log-interval "$MONITOR_LOG_INTERVAL" --data-collect-interval "$MONITOR_DATA_COLLECT_INTERVAL" > usage_log.txt 2>&1 & echo "monitor-script-pid=${!}" >> "${GITHUB_OUTPUT}" diff --git a/.github/workflows/_win-test.yml b/.github/workflows/_win-test.yml index 36b4e5cd753f6..0c95503928fb9 100644 --- a/.github/workflows/_win-test.yml +++ b/.github/workflows/_win-test.yml @@ -138,7 +138,7 @@ jobs: continue-on-error: true run: | # Windows conda doesn't have python3 binary, only python, but it's python3 - ${CONDA_RUN} python -m pip install psutil==5.9.1 dataclasses_json==0.6.7 nvidia-ml-py==11.525.84 + ${CONDA_RUN} python -m pip install psutil==5.9.8 dataclasses_json==0.6.7 nvidia-ml-py==11.525.84 ${CONDA_RUN} python -m tools.stats.monitor --log-interval "$MONITOR_LOG_INTERVAL" --data-collect-interval "$MONITOR_DATA_COLLECT_INTERVAL" > usage_log.txt 2>&1 & echo "monitor-script-pid=${!}" >> "${GITHUB_OUTPUT}" diff --git a/.github/workflows/_xpu-test.yml b/.github/workflows/_xpu-test.yml index de1be3115c932..177e6ca4bbe3c 100644 --- a/.github/workflows/_xpu-test.yml +++ b/.github/workflows/_xpu-test.yml @@ -133,7 +133,7 @@ jobs: MONITOR_LOG_INTERVAL: ${{ inputs.monitor-log-interval }} MONITOR_DATA_COLLECT_INTERVAL: ${{ inputs.monitor-data-collect-interval }} run: | - python3 -m pip install psutil==5.9.1 dataclasses_json==0.6.7 nvidia-ml-py==11.525.84 + python3 -m pip install psutil==5.9.8 dataclasses_json==0.6.7 nvidia-ml-py==11.525.84 python3 -m tools.stats.monitor --log-interval "$MONITOR_LOG_INTERVAL" --data-collect-interval "$MONITOR_DATA_COLLECT_INTERVAL" > usage_log.txt 2>&1 & echo "monitor-script-pid=${!}" >> "${GITHUB_OUTPUT}" diff --git a/tools/stats/monitor.py b/tools/stats/monitor.py index a5affc2510b77..38d1f94b178b2 100644 --- a/tools/stats/monitor.py +++ b/tools/stats/monitor.py @@ -78,6 +78,9 @@ class GpuData: uuid: str utilization: float mem_utilization: float + allocated_mem: float + allocated_mem_value: float + total_mem_value: float try: @@ -259,6 +262,7 @@ def _generate_stats(self, data_list: list[float]) -> UtilizationStats: return UtilizationStats( avg=round(avg, 2), max=round(maxi, 2), + raw=data_list, ) def _output_data(self) -> None: @@ -338,20 +342,33 @@ def _calculate_gpu_utilization(self, data_list: list[UsageData]) -> list[GpuUsag calculate_gpu = [] gpu_mem_utilization = defaultdict(list) gpu_utilization = defaultdict(list) + gpu_allocated_mem = defaultdict(list) + gpu_allocated_mem_values = defaultdict(list) + gpu_total_mem_values = defaultdict(float) for data in data_list: for gpu in data.gpu_list: gpu_mem_utilization[gpu.uuid].append(gpu.mem_utilization) gpu_utilization[gpu.uuid].append(gpu.utilization) + gpu_allocated_mem[gpu.uuid].append(gpu.allocated_mem) + gpu_allocated_mem_values[gpu.uuid].append(gpu.allocated_mem_value) + gpu_total_mem_values[gpu.uuid] = gpu.total_mem_value for gpu_uuid in gpu_utilization.keys(): gpu_util_stats = self._generate_stats(gpu_utilization[gpu_uuid]) gpu_mem_util_stats = self._generate_stats(gpu_mem_utilization[gpu_uuid]) + gpu_allocated_mem_stats = self._generate_stats(gpu_allocated_mem[gpu_uuid]) + gpu_allocated_mem_value_stats = self._generate_stats( + gpu_allocated_mem_values[gpu_uuid] + ) calculate_gpu.append( GpuUsage( uuid=gpu_uuid, util_percent=gpu_util_stats, mem_util_percent=gpu_mem_util_stats, + allocated_mem_percent=gpu_allocated_mem_stats, + allocated_mem_value=gpu_allocated_mem_value_stats, + total_mem_value=gpu_total_mem_values[gpu_uuid], ) ) return calculate_gpu @@ -382,11 +399,21 @@ def _collect_gpu_data(self) -> list[GpuData]: # see https://docs.nvidia.com/deploy/nvml-api/group__nvmlDeviceQueries.html gpu_utilization = pynvml.nvmlDeviceGetUtilizationRates(gpu_handle) gpu_uuid = pynvml.nvmlDeviceGetUUID(gpu_handle) + gpu_memory_info = pynvml.nvmlDeviceGetMemoryInfo(gpu_handle) + mem_utilization = gpu_utilization.memory + + allocate_mem_MB = gpu_memory_info.used / 1024**2 + total_mem_MB = gpu_memory_info.total / 1024**2 + allocate_mem_percent = allocate_mem_MB / total_mem_MB * 100 + gpu_data_list.append( GpuData( uuid=gpu_uuid, utilization=gpu_utilization.gpu, - mem_utilization=gpu_utilization.memory, + mem_utilization=mem_utilization, + allocated_mem=allocate_mem_percent, + allocated_mem_value=allocate_mem_MB, + total_mem_value=total_mem_MB, ) ) elif self._has_amdsmi: @@ -397,11 +424,20 @@ def _collect_gpu_data(self) -> list[GpuData]: gpu_uuid = amdsmi.amdsmi_get_gpu_device_uuid(handle) gpu_utilization = engine_usage["gfx_activity"] gpu_mem_utilization = gpu_utilization["umc_activity"] + mem_info = amdsmi.amdsmi_get_gpu_memory_usage(handle) + + allocate_mem_MB = mem_info["vram_usage"] / 1024**2 + total_mem_MB = mem_info["vram_total"] / 1024**2 + allocate_mem_percent = allocate_mem_MB / total_mem_MB * 100 + gpu_data_list.append( GpuData( uuid=gpu_uuid, utilization=gpu_utilization, mem_utilization=gpu_mem_utilization, + allocated_mem=allocate_mem_percent, + allocated_mem_value=allocate_mem_MB, + total_mem_value=total_mem_MB, ) ) return gpu_data_list @@ -499,7 +535,9 @@ def get_processes_running_python_tests() -> list[Any]: cmd = " ".join(process.cmdline()) processName = process.name() pid = process.pid - if "python" in processName and cmd.startswith("python"): + is_python = "python" in processName and "python" in cmd + is_pytest = "pytest" in cmd + if is_python or is_pytest: python_test_processes.append({"pid": pid, "cmd": cmd}) except Exception: pass diff --git a/tools/stats/utilization_stats_lib.py b/tools/stats/utilization_stats_lib.py index 740fe71f17688..33551fd55de5f 100644 --- a/tools/stats/utilization_stats_lib.py +++ b/tools/stats/utilization_stats_lib.py @@ -5,7 +5,7 @@ from dataclasses_json import DataClassJsonMixin -_DATA_MODEL_VERSION = 1.0 +_DATA_MODEL_VERSION = 1.5 # data model for test log usage @@ -13,6 +13,7 @@ class UtilizationStats: avg: Optional[float] = None max: Optional[float] = None + raw: Optional[list[float]] = None @dataclass @@ -36,6 +37,9 @@ class GpuUsage(DataClassJsonMixin): uuid: Optional[str] = None util_percent: Optional[UtilizationStats] = None mem_util_percent: Optional[UtilizationStats] = None + allocated_mem_percent: Optional[UtilizationStats] = None + allocated_mem_value: Optional[UtilizationStats] = None + total_mem_value: Optional[float] = None @dataclass From a835dbc096dd5206b91449b3ccc60c069d288506 Mon Sep 17 00:00:00 2001 From: fduwjj Date: Fri, 18 Jul 2025 12:41:13 -0700 Subject: [PATCH 273/457] [c10d][ez] Fix error message to reflect the correct API name (#158668) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158668 Approved by: https://github.com/VieEeEw --- torch/csrc/distributed/c10d/Backend.hpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torch/csrc/distributed/c10d/Backend.hpp b/torch/csrc/distributed/c10d/Backend.hpp index 070cdb7234b4c..29a6fddd87907 100644 --- a/torch/csrc/distributed/c10d/Backend.hpp +++ b/torch/csrc/distributed/c10d/Backend.hpp @@ -111,7 +111,9 @@ class TORCH_API Backend : public torch::CustomClassHolder { TORCH_CHECK( false, c10::str( - "Backend ", getBackendName(), " does not implement endCoalescing")); + "Backend ", + getBackendName(), + " does not implement getBackendOptions.")); } virtual c10::intrusive_ptr broadcast( From 60b9b06a53e13100709d66df8a7555ed167d5a1e Mon Sep 17 00:00:00 2001 From: Grace Cheng Date: Fri, 18 Jul 2025 23:12:26 +0000 Subject: [PATCH 274/457] [caffe2] Fix Missing override in get_buffer of NCCLSymmetricMemory (#158597) Summary: Fix the error that occurs in the devarm environment when compiling with Clang: ``` caffe2/torch/csrc/distributed/c10d/symm_mem/NCCLSymmetricMemory.cu:97:20: error: 'get_buffer' overrides a member function but is not marked 'override' [-Werror,-Winconsistent-missing-override] 97 | virtual at::Tensor get_buffer(int | ^ caffe2/torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp:56:20: note: overridden virtual function is here 56 | virtual at::Tensor get_buffer(int rank, c10::IntArrayRef sizes, c10::ScalarType dtype, int64_t storage_offset) = 0; | ^ 1 error generated. ``` Test Plan: See D78520305 Rollback Plan: Differential Revision: D78517953 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158597 Approved by: https://github.com/janeyx99 --- torch/csrc/distributed/c10d/symm_mem/NCCLSymmetricMemory.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/csrc/distributed/c10d/symm_mem/NCCLSymmetricMemory.cu b/torch/csrc/distributed/c10d/symm_mem/NCCLSymmetricMemory.cu index 4f69c49743864..55695ca27c8ec 100644 --- a/torch/csrc/distributed/c10d/symm_mem/NCCLSymmetricMemory.cu +++ b/torch/csrc/distributed/c10d/symm_mem/NCCLSymmetricMemory.cu @@ -98,7 +98,7 @@ class NCCLSymmetricMemory : public SymmetricMemory { int rank, c10::IntArrayRef sizes, c10::ScalarType dtype, - int64_t storage_offset) { + int64_t storage_offset) override { // TODO: deduplicate const size_t numel = std::accumulate( sizes.begin(), From 15ef4f28df0a14e9f0d55a57a4e2db415a303be7 Mon Sep 17 00:00:00 2001 From: AaronWang04 Date: Fri, 18 Jul 2025 23:24:21 +0000 Subject: [PATCH 275/457] Fused RMSNorm implementation (#153666) Relevant #72643 Benchmarked versus unfused torch implementation and torch.compile implementation. Around 9x speedup vs unfused implementation on cuda and slightly faster vs inductor compile on 5090. ```py import torch import torch.nn as nn class RMSNorm(nn.Module): def __init__(self, dim, eps=1e-5): super().__init__() self.eps = eps self.scale = nn.Parameter(torch.ones(dim)) def forward(self, x): norm_x = x.norm(2, dim=-1, keepdim=True) rms_x = norm_x * torch.rsqrt(torch.tensor(x.shape[-1], dtype=x.dtype)) x_normed = x / (rms_x + self.eps) return self.scale * x_normed def benchmark_rmsnorm_cuda(input_shape, normalized_dim, num_iterations=100, warmup_iterations=10, dtype=torch.float16): rms_norm_layer = torch.nn.RMSNorm(normalized_dim, device='cuda', dtype=dtype) input_data = torch.randn(input_shape, device='cuda', dtype=dtype) for _ in range(warmup_iterations): _ = rms_norm_layer(input_data) torch.cuda.synchronize() start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) start_event.record() for _ in range(num_iterations): _ = rms_norm_layer(input_data) end_event.record() torch.cuda.synchronize() elapsed_time_ms = start_event.elapsed_time(end_event) avg_time_ms = elapsed_time_ms / num_iterations print(f"--- RMSNorm CUDA Benchmark ---") print(f"Input Shape: {input_shape}") print(f"Normalized Dimension: {normalized_dim}") print(f"Benchmark Iterations: {num_iterations}") print(f"--- Fused Implementation ---") print(f"Average Time per Iteration: {avg_time_ms:.4f} ms") print(f"Total Time for {num_iterations} Iterations: {elapsed_time_ms:.3f} ms") compiled_rms_norm = torch.compile(RMSNorm(dim=normalized_dim)).cuda() for _ in range(warmup_iterations): _ = compiled_rms_norm(input_data) torch.cuda.synchronize() start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) start_event.record() for _ in range(num_iterations): _ = compiled_rms_norm(input_data) end_event.record() torch.cuda.synchronize() elapsed_time_ms = start_event.elapsed_time(end_event) avg_time_ms = elapsed_time_ms / num_iterations print(f"--- TorchCompile Implementation ---") print(f"Average Time per Iteration: {avg_time_ms:.4f} ms") print(f"Total Time for {num_iterations} Iterations: {elapsed_time_ms:.3f} ms") print("-" * 50) if __name__ == '__main__': parameter_sets = [ {'batch_size': 16, 'sequence_length': 256, 'hidden_features': 512, 'dtype': torch.float16}, {'batch_size': 32, 'sequence_length': 512, 'hidden_features': 768, 'dtype': torch.float16}, {'batch_size': 64, 'sequence_length': 1024, 'hidden_features': 1024, 'dtype': torch.float16}, {'batch_size': 32, 'sequence_length': 512, 'hidden_features': 768, 'dtype': torch.float32}, {'batch_size': 8, 'sequence_length': 2048, 'hidden_features': 2048, 'dtype': torch.float16}, ] num_benchmark_iterations = 200 num_warmup_iterations = 20 for params in parameter_sets: batch_size = params['batch_size'] sequence_length = params['sequence_length'] hidden_features = params['hidden_features'] data_type = params.get('dtype', torch.float16) shape = (batch_size, sequence_length, hidden_features) norm_dim_to_normalize = hidden_features print(f"Benchmarking with: BS={batch_size}, SeqLen={sequence_length}, Hidden={hidden_features}, DType={data_type}") benchmark_rmsnorm_cuda(input_shape=shape, normalized_dim=norm_dim_to_normalize, num_iterations=num_benchmark_iterations, warmup_iterations=num_warmup_iterations, dtype=data_type) ``` Here are the triton compile tests ran on a 5090 (comparing this branch vs main) ```py import torch import torch.nn as nn from torch._inductor.utils import run_and_get_code, run_fw_bw_and_get_code torch.manual_seed(0) device = torch.device("cuda") for batch in range(0, 9): for i in range(9, 16): normalized_shape_arg = (2**batch, 2**i) input_tensor = torch.randn(2**batch, 2**i, device=device, requires_grad=True) weight_tensor = torch.randn(2**batch, 2**i,device=device, requires_grad=True) model = torch.nn.functional.rms_norm compiled_model = torch.compile(model) loss = torch.randn_like(input_tensor) num_iter = 5 for j in range(num_iter): output = compiled_model(input_tensor, normalized_shape_arg, weight_tensor) output.backward(loss) start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) start_event.record() num_iter = 10 for j in range(num_iter): output = compiled_model(input_tensor, normalized_shape_arg, weight_tensor) output.backward(loss) end_event.record() torch.cuda.synchronize() elapsed_time_ms = start_event.elapsed_time(end_event) avg_time_ms = round(elapsed_time_ms / num_iter, 5) print(2**batch, 2**i, avg_time_ms) ``` main ``` 32 512 0.1812 32 1024 0.19021 32 2048 0.18871 32 4096 0.17019 32 8192 0.21944 32 16384 0.38871 32 32768 0.83282 64 512 0.14705 64 1024 0.13987 64 2048 0.14111 64 4096 0.21699 64 8192 0.43141 64 16384 0.90652 64 32768 2.18573 128 512 0.19361 128 1024 0.1963 128 2048 0.20122 128 4096 0.38888 128 8192 0.93795 128 16384 2.23437 128 32768 5.50079 256 512 0.16722 256 1024 0.22856 256 2048 0.39421 256 4096 0.96621 256 8192 2.48746 256 16384 5.53571 256 32768 11.97932 ``` current branch ``` 32 512 0.16328 32 1024 0.18104 32 2048 0.15508 32 4096 0.14356 32 8192 0.20111 32 16384 0.45974 32 32768 0.94799 64 512 0.16874 64 1024 0.18701 64 2048 0.16107 64 4096 0.20152 64 8192 0.46568 64 16384 0.96599 64 32768 2.21661 128 512 0.14982 128 1024 0.15565 128 2048 0.22241 128 4096 0.46128 128 8192 0.88883 128 16384 2.3097 128 32768 5.84448 256 512 0.14346 256 1024 0.2007 256 2048 0.45927 256 4096 0.87876 256 8192 2.10571 256 16384 5.73948 256 32768 12.98581 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/153666 Approved by: https://github.com/ngimel, https://github.com/eqy, https://github.com/albanD --- .../functorch/BatchRulesDecompositions.cpp | 1 + .../src/ATen/native/cuda/layer_norm_kernel.cu | 590 +++++++++++++----- aten/src/ATen/native/layer_norm.cpp | 77 ++- aten/src/ATen/native/layer_norm.h | 6 + .../src/ATen/native/mps/operations/RMSNorm.mm | 13 +- aten/src/ATen/native/native_functions.yaml | 8 +- ...asDecompTest.test_has_decomposition.expect | 1 - .../check_forward_backward_compatibility.py | 2 + test/test_decomp.py | 29 +- tools/autograd/derivatives.yaml | 5 + torch/_decomp/__init__.py | 1 + torch/_decomp/decompositions.py | 75 +++ torch/csrc/autograd/FunctionsManual.cpp | 189 ++++++ torch/csrc/autograd/FunctionsManual.h | 23 + .../aoti_torch/generated/c_shim_cpu.h | 1 + .../aoti_torch/generated/c_shim_cuda.h | 1 + .../aoti_torch/generated/c_shim_mps.h | 2 +- .../aoti_torch/generated/c_shim_xpu.h | 1 + torch/overrides.py | 1 + 19 files changed, 843 insertions(+), 183 deletions(-) diff --git a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp index 4b66b30b62e7f..d58d436c511d1 100644 --- a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp +++ b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp @@ -158,6 +158,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) { OP_DECOMPOSE(kron); OP_DECOMPOSE(l1_loss); m.impl("layer_norm", native::layer_norm_symint); + m.impl("_fused_rms_norm", native::rms_norm_composite); OP_DECOMPOSE2(ldexp, Tensor); OP_DECOMPOSE2(less_equal, Tensor ); OP_DECOMPOSE2(less, Tensor ); diff --git a/aten/src/ATen/native/cuda/layer_norm_kernel.cu b/aten/src/ATen/native/cuda/layer_norm_kernel.cu index bdb169e26b142..f765b515cd0bc 100644 --- a/aten/src/ATen/native/cuda/layer_norm_kernel.cu +++ b/aten/src/ATen/native/cuda/layer_norm_kernel.cu @@ -50,7 +50,7 @@ bool can_vectorize(const T * ptr, int alignment) { }; -template +template __global__ void RowwiseMomentsCUDAKernel( int64_t N, T_ACC eps, @@ -84,12 +84,17 @@ __global__ void RowwiseMomentsCUDAKernel( T_ACC m1; T_ACC m2; thrust::tie(m2, m1) = welford_op.project(val); - mean[i] = m1; - rstd[i] = c10::cuda::compat::rsqrt(m2 + eps); + if constexpr (!rms_norm){ + mean[i] = m1; + rstd[i] = c10::cuda::compat::rsqrt(m2 + eps); + } else { + rstd[i] = c10::cuda::compat::rsqrt(m2 + m1 * m1 + eps); + } + } } -template +template __global__ void LayerNormForwardCUDAKernel( int64_t N, const T* X, @@ -103,11 +108,15 @@ __global__ void LayerNormForwardCUDAKernel( const int64_t index = i * N + j; const T_ACC gamma_v = gamma == nullptr ? T_ACC(1) : static_cast(gamma[j]); - const T_ACC beta_v = - beta == nullptr ? T_ACC(0) : static_cast(beta[j]); - Y[index] = (static_cast(X[index]) - static_cast(mean[i])) * - static_cast(rstd[i]) * gamma_v + - beta_v; + if constexpr (!rms_norm){ + const T_ACC beta_v = + beta == nullptr ? T_ACC(0) : static_cast(beta[j]); + Y[index] = (static_cast(X[index]) - static_cast(mean[i])) * + static_cast(rstd[i]) * gamma_v + + beta_v; + } else { + Y[index] = (static_cast(X[index])) * static_cast(rstd[i]) * gamma_v; + } } } @@ -119,40 +128,48 @@ struct WelfordDataLN{ C10_HOST_DEVICE WelfordDataLN(float mean, float sigma2, float count): mean(mean), sigma2(sigma2), count(count) {} }; -template __device__ +template __device__ WelfordDataLN cuWelfordOnlineSum( const U val, const WelfordDataLN& curr_sum) { - U delta = val - curr_sum.mean; - U new_count = curr_sum.count + 1.f; - U new_mean = curr_sum.mean + delta * (1.f/new_count); //proper division is slow, this is less accurate but noticeably faster - return {new_mean, curr_sum.sigma2 + delta * (val - new_mean), new_count}; + if constexpr (!rms_norm){ + U delta = val - curr_sum.mean; + U new_count = curr_sum.count + 1.f; + U new_mean = curr_sum.mean + delta * (1.f/new_count); //proper division is slow, this is less accurate but noticeably faster + return {new_mean, curr_sum.sigma2 + delta * (val - new_mean), new_count}; + } else{ + return {0.f, curr_sum.sigma2 + val * val, 0}; + } } -__device__ +template __device__ WelfordDataLN cuWelfordCombine( const WelfordDataLN dataB, const WelfordDataLN dataA ) { - using U = decltype(dataB.count); - U delta = dataB.mean - dataA.mean; - U count = dataA.count + dataB.count; - U mean, sigma2; - if (count > decltype(dataB.count){0}) { - auto coef = 1.f/count; //NB we don't use --use_fast_math, but this is emulation, 1./count goes to intrinsic, `* coef` is multiplication, instead of slow fp division - auto nA = dataA.count * coef; - auto nB = dataB.count * coef; - mean = nA*dataA.mean + nB*dataB.mean; - sigma2 = dataA.sigma2 + dataB.sigma2 + delta * delta * dataA.count * nB; + if constexpr (!rms_norm){ + using U = decltype(dataB.count); + U delta = dataB.mean - dataA.mean; + U count = dataA.count + dataB.count; + U mean, sigma2; + if (count > decltype(dataB.count){0}) { + auto coef = 1.f/count; //NB we don't use --use_fast_math, but this is emulation, 1./count goes to intrinsic, `* coef` is multiplication, instead of slow fp division + auto nA = dataA.count * coef; + auto nB = dataB.count * coef; + mean = nA*dataA.mean + nB*dataB.mean; + sigma2 = dataA.sigma2 + dataB.sigma2 + delta * delta * dataA.count * nB; + } else { + mean = U(0); + sigma2 = U(0); + } + return {mean, sigma2, count}; } else { - mean = U(0); - sigma2 = U(0); + return {0.f, dataB.sigma2 + dataA.sigma2, 0}; } - return {mean, sigma2, count}; } -template +template __device__ WelfordDataLN compute_stats( const T* __restrict__ X, const int N, @@ -171,14 +188,13 @@ __device__ WelfordDataLN compute_stats( vec_t data = X_vec[i]; #pragma unroll for (int ii=0; ii < vec_size; ii++){ - wd = cuWelfordOnlineSum(static_cast(data.val[ii]), wd); + wd = cuWelfordOnlineSum(static_cast(data.val[ii]), wd); } } // intra-warp reduction for (int offset = (C10_WARP_SIZE >> 1); offset > 0; offset >>= 1) { - WelfordDataLN wdB{WARP_SHFL_DOWN(wd.mean, offset), - WARP_SHFL_DOWN(wd.sigma2, offset), WARP_SHFL_DOWN(wd.count, offset)}; - wd = cuWelfordCombine(wd, wdB); + WelfordDataLN wdB{WARP_SHFL_DOWN(wd.mean, offset), WARP_SHFL_DOWN(wd.sigma2, offset), WARP_SHFL_DOWN(wd.count, offset)}; + wd = cuWelfordCombine(wd, wdB); } // threadIdx.x == 0 has correct values for each warp // inter-warp reductions @@ -199,7 +215,7 @@ __device__ WelfordDataLN compute_stats( WelfordDataLN wdB{meansigmabuf[2*threadIdx.y], meansigmabuf[2*threadIdx.y+1], countbuf[threadIdx.y]}; - wd = cuWelfordCombine(wd, wdB); + wd = cuWelfordCombine(wd, wdB); } __syncthreads(); } @@ -216,7 +232,7 @@ __device__ WelfordDataLN compute_stats( } -template , int> = 0> __device__ __inline__ void vectorized_layer_norm_kernel_impl( const int N, @@ -231,7 +247,7 @@ __device__ __inline__ void vectorized_layer_norm_kernel_impl( //as one thread would have to write 3 consecutive floats auto i1 = blockIdx.x; const T * block_row = X + i1 * N; - WelfordDataLN wd = compute_stats(block_row, N, s_data); + WelfordDataLN wd = compute_stats(block_row, N, s_data); using vec_t = aligned_vector; const vec_t * X_vec = reinterpret_cast(block_row); @@ -254,34 +270,48 @@ __device__ __inline__ void vectorized_layer_norm_kernel_impl( if (gamma_vec != nullptr && beta_vec != nullptr) { #pragma unroll for (int ii=0; ii < vec_size; ii++){ - out.val[ii] = static_cast(gamma_vec[i].val[ii]) * (rstd_val * (static_cast(data.val[ii]) - wd.mean)) - + static_cast(beta_vec[i].val[ii]); + if constexpr (!rms_norm){ + out.val[ii] = static_cast(gamma_vec[i].val[ii]) * (rstd_val * (static_cast(data.val[ii]) - wd.mean)) + + static_cast(beta_vec[i].val[ii]); + } else { + out.val[ii] = static_cast(gamma_vec[i].val[ii]) * (rstd_val * static_cast(data.val[ii])); + } } } else if (gamma_vec != nullptr) { #pragma unroll for (int ii=0; ii < vec_size; ii++){ - out.val[ii] = static_cast(gamma_vec[i].val[ii]) * (rstd_val * (static_cast(data.val[ii]) - wd.mean)); + if constexpr (!rms_norm){ + out.val[ii] = static_cast(gamma_vec[i].val[ii]) * (rstd_val * (static_cast(data.val[ii]) - wd.mean)); + } else { + out.val[ii] = static_cast(gamma_vec[i].val[ii]) * (rstd_val * static_cast(data.val[ii])); + } } } else if (beta_vec != nullptr) { #pragma unroll for (int ii=0; ii < vec_size; ii++){ - out.val[ii] = (rstd_val * (static_cast(data.val[ii]) - wd.mean)) + static_cast(beta_vec[i].val[ii]); + out.val[ii] = (rstd_val * (static_cast(data.val[ii]) - wd.mean)) + static_cast(beta_vec[i].val[ii]); } } else { #pragma unroll for (int ii=0; ii < vec_size; ii++){ - out.val[ii] = rstd_val * (static_cast(data.val[ii]) - wd.mean); + if constexpr (!rms_norm){ + out.val[ii] = rstd_val * (static_cast(data.val[ii]) - wd.mean); + } else { + out.val[ii] = rstd_val * static_cast(data.val[ii]); + } } } Y_vec[i] = out; } if (thrx == 0) { - mean[i1] = wd.mean; + if constexpr (!rms_norm){ + mean[i1] = wd.mean; + } rstd[i1] = rstd_val; } } -template , int> = 0> __device__ __inline__ void vectorized_layer_norm_kernel_impl( const int /*N*/, @@ -296,7 +326,7 @@ __device__ __inline__ void vectorized_layer_norm_kernel_impl( } //to avoid windows SFINAE errors -template +template __global__ void vectorized_layer_norm_kernel( const int N, T_ACC eps, @@ -306,11 +336,11 @@ __global__ void vectorized_layer_norm_kernel( T_ACC* mean, T_ACC* rstd, T* Y){ - vectorized_layer_norm_kernel_impl(N, eps, X, gamma, beta, mean, rstd, Y); + vectorized_layer_norm_kernel_impl(N, eps, X, gamma, beta, mean, rstd, Y); } -template +template __device__ __inline__ void compute_gI( const T* __restrict__ dY, const T* __restrict__ X, @@ -321,7 +351,10 @@ __device__ __inline__ void compute_gI( const int N, T_ACC * buf){ const auto i1 = blockIdx.x; - const T_ACC mean_val = mean[i1]; + T_ACC mean_val = 0; + if constexpr (!rms_norm){ + mean_val = mean[i1]; + } const T_ACC rstd_val = rstd[i1]; T_ACC stats_x1{0}, stats_x2{0}; constexpr int unroll = 4; @@ -337,26 +370,39 @@ __device__ __inline__ void compute_gI( const auto gamma_val = (gamma != nullptr) ? static_cast(gamma[l+k]) : T_ACC(1); const auto c_h = static_cast(X_i[l+k]); const auto c_loss = static_cast(dY_i[l+k]); - stats_x1 += c_loss * gamma_val; - stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; + if constexpr (!rms_norm){ + stats_x1 += c_loss * gamma_val; + stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; + } else { + stats_x2 += c_loss * gamma_val * (c_h) * rstd_val; + } } } for (; l < N; l ++) { const auto gamma_val = (gamma != nullptr) ? static_cast(gamma[l]) : T_ACC(1); const auto c_h = static_cast(X_i[l]); const auto c_loss = static_cast(dY_i[l]); - stats_x1 += c_loss * gamma_val; - stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; + if constexpr (!rms_norm){ + stats_x1 += c_loss * gamma_val; + stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; + } else { + stats_x2 += c_loss * gamma_val * (c_h) * rstd_val; + } + } + if constexpr (!rms_norm){ + stats_x1 = cuda_utils::BlockReduceSum(stats_x1, buf); } - - stats_x1 = cuda_utils::BlockReduceSum(stats_x1, buf); stats_x2 = cuda_utils::BlockReduceSum(stats_x2, buf); if (threadIdx.x == 0) { - buf[0] = stats_x1; + if constexpr (!rms_norm){ + buf[0] = stats_x1; + } buf[1] = stats_x2; } __syncthreads(); - stats_x1 = buf[0]; + if constexpr (!rms_norm){ + stats_x1 = buf[0]; + } stats_x2 = buf[1]; T_ACC fH = N; T_ACC term1 = (T_ACC(1) / fH) * rstd_val; @@ -367,15 +413,20 @@ __device__ __inline__ void compute_gI( const auto gamma_val = (gamma != nullptr) ? static_cast(gamma[l]) : T_ACC(1); T_ACC f_grad_input = fH * gamma_val * dy; - f_grad_input -= (x - mean_val) * rstd_val * stats_x2; - f_grad_input -= stats_x1; + if constexpr (!rms_norm){ + f_grad_input -= (x - mean_val) * rstd_val * stats_x2; + f_grad_input -= stats_x1; + } else { + f_grad_input -= (x) * rstd_val * stats_x2; + } + f_grad_input *= term1; dX_i[l] = f_grad_input; } } -template +template __global__ void layer_norm_grad_input_kernel( const T* __restrict__ dY, const T* __restrict__ X, @@ -387,7 +438,7 @@ __global__ void layer_norm_grad_input_kernel( alignas(sizeof(double)) extern __shared__ char s_data1[]; T_ACC * buf = reinterpret_cast(&s_data1); - compute_gI(dY, X, mean, rstd, gamma, dX, N, buf); + compute_gI(dY, X, mean, rstd, gamma, dX, N, buf); } @@ -396,7 +447,7 @@ __global__ void layer_norm_grad_input_kernel( // faster measured at PT operator level, with cases seeing a 2X speedup (where N >> M). // There are no noticeable regressions on the rest of the sizes. -template +template __global__ void layer_norm_grad_input_kernel_vectorized( const T* __restrict__ dY, const T* __restrict__ X, @@ -409,7 +460,10 @@ __global__ void layer_norm_grad_input_kernel_vectorized( T_ACC* reduce_buf = reinterpret_cast(&shared_data); const auto bIdx = blockIdx.x; - const T_ACC mean_val = mean[bIdx]; + T_ACC mean_val = 0; + if constexpr (!rms_norm){ + mean_val = mean[bIdx]; + } const T_ACC rstd_val = rstd[bIdx]; const T* X_i = X + bIdx * N; const T* dY_i = dY + bIdx * N; @@ -441,8 +495,12 @@ __global__ void layer_norm_grad_input_kernel_vectorized( const auto gamma_val = static_cast(gamma_vec_reg.val[k]); const auto c_h = static_cast(X_i_vec_reg.val[k]); const auto c_loss = static_cast(dY_i_vec_reg.val[k]); - stats_x1 += c_loss * gamma_val; - stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; + if constexpr (!rms_norm){ + stats_x1 += c_loss * gamma_val; + stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; + } else { + stats_x2 += c_loss * gamma_val * (c_h) * rstd_val; + } } } @@ -451,19 +509,29 @@ __global__ void layer_norm_grad_input_kernel_vectorized( const auto gamma_val = (gamma != nullptr) ? static_cast(gamma[l]) : T_ACC(1); const auto c_h = static_cast(X_i[l]); const auto c_loss = static_cast(dY_i[l]); - stats_x1 += c_loss * gamma_val; - stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; + if constexpr (!rms_norm){ + stats_x1 += c_loss * gamma_val; + stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; + } else{ + stats_x2 += c_loss * gamma_val * (c_h) * rstd_val; + } } // Reduction in Shared Memory - stats_x1 = cuda_utils::BlockReduceSum(stats_x1, reduce_buf); + if constexpr (!rms_norm){ + stats_x1 = cuda_utils::BlockReduceSum(stats_x1, reduce_buf); + } stats_x2 = cuda_utils::BlockReduceSum(stats_x2, reduce_buf); if (threadIdx.x == 0) { - reduce_buf[0] = stats_x1; + if constexpr (!rms_norm){ + reduce_buf[0] = stats_x1; + } reduce_buf[1] = stats_x2; } __syncthreads(); - stats_x1 = reduce_buf[0]; + if constexpr (!rms_norm){ + stats_x1 = reduce_buf[0]; + } stats_x2 = reduce_buf[1]; T_ACC fH = N; @@ -485,8 +553,12 @@ __global__ void layer_norm_grad_input_kernel_vectorized( const auto dy = static_cast(dY_i_vec_reg.val[k]); T_ACC f_grad_input = fH * gamma_val * dy; - f_grad_input -= (x - mean_val) * rstd_val * stats_x2; - f_grad_input -= stats_x1; + if constexpr (!rms_norm){ + f_grad_input -= (x - mean_val) * rstd_val * stats_x2; + f_grad_input -= stats_x1; + } else { + f_grad_input -= (x) * rstd_val * stats_x2; + } f_grad_input *= term1; dX_i_vec_reg.val[k] = f_grad_input; } @@ -501,15 +573,19 @@ __global__ void layer_norm_grad_input_kernel_vectorized( const auto gamma_val = (gamma != nullptr) ? static_cast(gamma[l]) : T_ACC(1); T_ACC f_grad_input = fH * gamma_val * dy; - f_grad_input -= (x - mean_val) * rstd_val * stats_x2; - f_grad_input -= stats_x1; + if constexpr (!rms_norm){ + f_grad_input -= (x - mean_val) * rstd_val * stats_x2; + f_grad_input -= stats_x1; + } else { + f_grad_input -= (x) * rstd_val * stats_x2; + } f_grad_input *= term1; dX_i[l] = f_grad_input; } } -template +template __global__ void GammaBetaBackwardSimpleCUDAKernel( int64_t M, int64_t N, @@ -525,17 +601,25 @@ __global__ void GammaBetaBackwardSimpleCUDAKernel( T_ACC sum2 = 0; for (int64_t i = 0; i < M; ++i) { const int64_t index = i * N + j; - sum1 += dg == nullptr ? T_ACC(0) - : static_cast(dY[index]) * - (static_cast(X[index]) - static_cast(mean[i])) * - static_cast(rstd[i]); - sum2 += db == nullptr ? T_ACC(0) : static_cast(dY[index]); + if constexpr (!rms_norm){ + sum1 += dg == nullptr ? T_ACC(0) + : static_cast(dY[index]) * + (static_cast(X[index]) - static_cast(mean[i])) * + static_cast(rstd[i]); + sum2 += db == nullptr ? T_ACC(0) : static_cast(dY[index]); + } else { + sum1 += dg == nullptr ? T_ACC(0) + : static_cast(dY[index]) * + (static_cast(X[index])) * static_cast(rstd[i]); + } } if (dg != nullptr) { dg[j] = sum1; } if (db != nullptr) { - db[j] = sum2; + if constexpr (!rms_norm){ + db[j] = sum2; + } } } } @@ -545,7 +629,8 @@ unsigned int block_dim_x, unsigned int block_dim_y, unsigned int rows_per_block_y, bool check_x, -bool check_y> +bool check_y, +bool rms_norm> __device__ __forceinline__ void @@ -569,7 +654,9 @@ blockReduceGammaBetaBackwardsHelper( int64_t mean_index = M_start + threadIdx.y * rows_per_thread_y; T_ACC warp_mean = 0, warp_rstd = 0; if (lane_id < rows_per_thread_y && mean_index + lane_id < M) { - warp_mean = mean[mean_index + lane_id]; + if constexpr (!rms_norm){ + warp_mean = mean[mean_index + lane_id]; + } warp_rstd = rstd[mean_index + lane_id]; } // We do a WARP_SYNC() here because we use WARP_SHFL below to access @@ -596,10 +683,14 @@ blockReduceGammaBetaBackwardsHelper( #pragma unroll for (int i = 0; i < rows_per_thread_y; ++i) { - T_ACC mean_reg = WARP_SHFL(warp_mean, i, kWarpSize); T_ACC rstd_reg = WARP_SHFL(warp_rstd, i, kWarpSize); - dg_sum += dY_regs[i] * (X_regs[i] - mean_reg) * rstd_reg; - db_sum += dY_regs[i]; + if constexpr (!rms_norm){ + T_ACC mean_reg = WARP_SHFL(warp_mean, i, kWarpSize); + dg_sum += dY_regs[i] * (X_regs[i] - mean_reg) * rstd_reg; + db_sum += dY_regs[i]; + } else{ + dg_sum += dY_regs[i] * (X_regs[i]) * rstd_reg; + } } } @@ -608,7 +699,8 @@ unsigned int block_dim_x, unsigned int block_dim_y, unsigned int rows_per_block_y, bool check_x, -bool check_y> +bool check_y, +bool rms_norm> __device__ __forceinline__ void @@ -629,10 +721,10 @@ blockReduceGammaBetaBackwardsWithChecks( M_start += rows_per_block_y * gridDim.y) { int64_t M_end = M_start + rows_per_block_y - 1; if (!check_y || M_end < M) { - blockReduceGammaBetaBackwardsHelper + blockReduceGammaBetaBackwardsHelper (M_start, M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); } else { - blockReduceGammaBetaBackwardsHelper + blockReduceGammaBetaBackwardsHelper (M_start, M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); } } @@ -654,7 +746,8 @@ template __global__ void @@ -679,7 +772,7 @@ __launch_bounds__(block_dim_x * block_dim_y) // When N and M align perfectly with block_dim_x and block_dim_y, we // can skip boundary condition checks that waste instruction issue slots. blockReduceGammaBetaBackwardsWithChecks - + (M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); } else { // In the general case we need to check boundary conditions in the M @@ -687,11 +780,11 @@ __launch_bounds__(block_dim_x * block_dim_y) // for the inner blocks. So try to avoid those checks when possible. if (blockIdx.x * block_dim_x + block_dim_x - 1 < N) { blockReduceGammaBetaBackwardsWithChecks - + (M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); } else { blockReduceGammaBetaBackwardsWithChecks - + (M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); } } @@ -706,7 +799,7 @@ __launch_bounds__(block_dim_x * block_dim_y) if (dg) { dg[thread_y * N + thread_x] = dg_sum; } - if (db) { + if (db && !rms_norm) { db[thread_y * N + thread_x] = db_sum; } } @@ -752,7 +845,7 @@ __launch_bounds__(block_dim_x * block_dim_y) if (dg) { dg[out_index] = reg_dg; } - if (db) { + if (db && !rms_norm) { db[out_index] = reg_db; } } @@ -763,7 +856,8 @@ __launch_bounds__(block_dim_x * block_dim_y) template +bool partial_reduction, +bool rms_norm> void LaunchAndCheckGammaBetaBackwardKernel( bool aligned_grid, dim3 blocks, @@ -779,7 +873,7 @@ void LaunchAndCheckGammaBetaBackwardKernel( T* dgamma_data, T* dbeta_data) { if (aligned_grid) { - GammaBetaBackwardCUDAKernelTemplate + GammaBetaBackwardCUDAKernelTemplate <<>>( M, N, @@ -790,7 +884,7 @@ if (aligned_grid) { dgamma_data, dbeta_data); } else { - GammaBetaBackwardCUDAKernelTemplate + GammaBetaBackwardCUDAKernelTemplate <<>>( M, N, @@ -806,7 +900,7 @@ if (aligned_grid) { template +int rows_per_block_y, bool rms_norm> void ConfigureAndLaunchGammaBetaBackwardKernel( const T* dY_data, const T* X_data, @@ -829,16 +923,16 @@ void ConfigureAndLaunchGammaBetaBackwardKernel( if (blocks.y == 1 && threads.y == 1) { // Optimization: since there is just one thread doing all the summation, we don't need a reduction // across threads. So we set partial_reduction to true. - LaunchAndCheckGammaBetaBackwardKernel( + LaunchAndCheckGammaBetaBackwardKernel( aligned_grid, blocks, threads, shmem_sz, cuda_stream, dY_data, X_data, mean_data, rstd_data, M, N, dgamma_data, dbeta_data); } else { - LaunchAndCheckGammaBetaBackwardKernel( + LaunchAndCheckGammaBetaBackwardKernel( aligned_grid, blocks, threads, shmem_sz, cuda_stream, dY_data, X_data, mean_data, rstd_data, M, N, dgamma_data, dbeta_data); } } -template +template void LaunchGammaBetaBackwardCUDAKernel( const T* dY_data, const T* X_data, @@ -876,19 +970,21 @@ void LaunchGammaBetaBackwardCUDAKernel( dgamma_blocks = at::empty({blocks.y * threads.y, dgamma->size(-1)}, options); dgamma_blocks_ptr = dgamma_blocks.data_ptr(); } - if (dbeta->defined()) { + if (dbeta->defined() && !rms_norm) { auto options = dbeta->options(); dbeta_blocks = at::empty({blocks.y * threads.y, dgamma->size(-1)}, options); dbeta_blocks_ptr = dbeta_blocks.data_ptr(); } - LaunchAndCheckGammaBetaBackwardKernel( + LaunchAndCheckGammaBetaBackwardKernel( aligned_grid, blocks, threads, 0, cuda_stream, dY_data, X_data, mean_data, rstd_data, M, N, dgamma_blocks_ptr, dbeta_blocks_ptr); if (dgamma_blocks.defined()) { *dgamma = dgamma_blocks.sum(0); } - if (dbeta_blocks.defined()) { - *dbeta = dbeta_blocks.sum(0); + if constexpr (!rms_norm){ + if (dbeta_blocks.defined()) { + *dbeta = dbeta_blocks.sum(0); + } } } else { // We are in the normal case where M is not that large. @@ -896,18 +992,18 @@ void LaunchGammaBetaBackwardCUDAKernel( // For small M it is faster to have a smaller tile, otherwise we could have idle threads. // For larger M we use a bigger tile size. if (M < 64) { - ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); + ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); } else if (M < 128) { - ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); + ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); } else if (M < 256) { - ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); + ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); } else { - ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); + ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); } } } -template +template void launch_vectorized_layer_norm_kernel( int N, int64_t M, @@ -936,7 +1032,7 @@ void launch_vectorized_layer_norm_kernel( TORCH_INTERNAL_ASSERT_DEBUG_ONLY(threads.y % 2 == 0 || threads.y == 1); int nshared = threads.y > 1 ? threads.y * 3/2 *sizeof(T_ACC) : 0; - vectorized_layer_norm_kernel<<>>(N, eps, X_data, + vectorized_layer_norm_kernel<<>>(N, eps, X_data, gamma_data, beta_data, mean_data, rstd_data, Y_data); C10_CUDA_KERNEL_LAUNCH_CHECK(); @@ -958,7 +1054,7 @@ void launch_vectorized_layer_norm_kernel( blocks.x = (remaining > blocks.x) ? blocks.x : remaining; - vectorized_layer_norm_kernel<<>>(N, eps, X_data2, + vectorized_layer_norm_kernel<<>>(N, eps, X_data2, gamma_data, beta_data, mean_data2, rstd_data2, Y_data2); C10_CUDA_KERNEL_LAUNCH_CHECK(); @@ -968,7 +1064,7 @@ void launch_vectorized_layer_norm_kernel( } -template +template void LayerNormKernelImplInternal( const Tensor& X, const Tensor& gamma, @@ -987,7 +1083,7 @@ void LayerNormKernelImplInternal( const T* gamma_data = gamma.defined() ? gamma.const_data_ptr() : nullptr; const T* beta_data = beta.defined() ? beta.const_data_ptr() : nullptr; T* Y_data = Y->data_ptr(); - T_ACC* mean_data = mean->data_ptr(); + T_ACC* mean_data = !rms_norm ? mean->data_ptr() : nullptr; T_ACC* rstd_data = rstd->data_ptr(); // check if can take fast path - all tensors are properly aligned, N is less than 2^24 (to use float count), @@ -1002,14 +1098,14 @@ void LayerNormKernelImplInternal( if ((std::is_same_v || std::is_same_v || std::is_same_v) && N <= static_cast(1ULL << std::numeric_limits::digits) && N % num_vec_elems == 0 && can_vec_X && can_vec_Y && can_vec_gamma && can_vec_beta) { - launch_vectorized_layer_norm_kernel(static_cast(N), M, eps, X_data, gamma_data, beta_data, Y_data, mean_data, rstd_data); + launch_vectorized_layer_norm_kernel(static_cast(N), M, eps, X_data, gamma_data, beta_data, Y_data, mean_data, rstd_data); } else { cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream(); - RowwiseMomentsCUDAKernel + RowwiseMomentsCUDAKernel <<>>( N, eps, X_data, mean_data, rstd_data); C10_CUDA_KERNEL_LAUNCH_CHECK(); - LayerNormForwardCUDAKernel<<>>( + LayerNormForwardCUDAKernel<<>>( N, X_data, mean_data, rstd_data, gamma_data, beta_data, Y_data); C10_CUDA_KERNEL_LAUNCH_CHECK(); } @@ -1037,7 +1133,29 @@ void LayerNormKernelImpl( }); } -template __device__ +void RmsNormKernelImpl( + const Tensor& X, + const Tensor& gamma, + int64_t M, + int64_t N, + double eps, + Tensor* Y, + Tensor* rstd) { +AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + X.scalar_type(), + "LayerNormKernelImpl", + [&]() { + using acc_t = acc_type; + // rms_norm = true + LayerNormKernelImplInternal( + // pass in at::Tensor() for gamma and nullptr for mean, it won't be accessed with rms_norm = True + X, gamma, at::Tensor(), M, N, static_cast(eps), Y, nullptr, rstd); + }); +} + +template __device__ void cuLoadWriteStridedInputs( const int i1_block, const int thr_load_row_off, @@ -1055,7 +1173,10 @@ void cuLoadWriteStridedInputs( { int i1 = i1_block+thr_load_row_off; if (i1 < i1_end) { - T_ACC curr_mean = mean[i1]; + T_ACC curr_mean = 0; + if constexpr (!rms_norm){ + curr_mean = mean[i1]; + } T_ACC curr_rstd = rstd[i1]; for (int k = 0; k < blockDim.y; ++k) { int i2 = i2_off + k; @@ -1080,7 +1201,7 @@ void cuLoadWriteStridedInputs( } } -template __device__ +template __device__ void cuLoadAddStridedInputs( const int i1_block, const int thr_load_row_off, @@ -1098,7 +1219,11 @@ void cuLoadAddStridedInputs( { int i1 = i1_block+thr_load_row_off; if (i1 < i1_end) { - T_ACC curr_mean = mean[i1]; + + T_ACC curr_mean = 0; + if constexpr (!rms_norm){ + curr_mean = mean[i1]; + } T_ACC curr_rstd = rstd[i1]; for (int k = 0; k < blockDim.y; ++k) { int i2 = i2_off + k; @@ -1114,7 +1239,7 @@ void cuLoadAddStridedInputs( } } -template __global__ +template __global__ void cuComputePartGradGammaBeta( const T* __restrict__ dout, const T* __restrict__ input, @@ -1140,9 +1265,9 @@ void cuComputePartGradGammaBeta( T_ACC* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride; // compute partial sums from strided inputs // do this to increase number of loads in flight - cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,N,mean,rstd); + cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,N,mean,rstd); for (int i1_block = i1_beg+blockDim.y*blockDim.y; i1_block < i1_end; i1_block+=blockDim.y*blockDim.y) { - cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,N,mean,rstd); + cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,N,mean,rstd); } __syncthreads(); // inter-warp reductions @@ -1181,7 +1306,7 @@ void cuComputePartGradGammaBeta( } } -template __global__ +template __global__ void cuComputeGradGammaBeta( const T_ACC* part_grad_gamma, const T_ACC* part_grad_beta, @@ -1206,7 +1331,9 @@ void cuComputeGradGammaBeta( if (i2 < N) { for (int warp_offset = 0; warp_offset < num_warp_reductions; ++warp_offset) { sum_gamma += part_grad_gamma_ptr[warp_offset*N]; - sum_beta += part_grad_beta_ptr[warp_offset*N]; + if constexpr (!rms_norm){ + sum_beta += part_grad_beta_ptr[warp_offset*N]; + } } } @@ -1224,7 +1351,9 @@ void cuComputeGradGammaBeta( if (threadIdx.y < offset) { const int read_idx = threadIdx.y * blockDim.x + threadIdx.x; sum_gamma += buf[read_idx]; - sum_beta += buf[read_idx+nbsize3]; + if constexpr (!rms_norm){ + sum_beta += buf[read_idx+nbsize3]; + } } __syncthreads(); } @@ -1235,12 +1364,14 @@ void cuComputeGradGammaBeta( grad_gamma[i2] = sum_gamma; } if (grad_beta) { - grad_beta[i2] = sum_beta; + if constexpr (!rms_norm){ + grad_beta[i2] = sum_beta; + } } } } -template __global__ +template __global__ void cuComputeGradInput( const T* __restrict__ dout, const T* __restrict__ input, @@ -1254,7 +1385,10 @@ void cuComputeGradInput( for (int i1=blockIdx.y; i1 < M; i1 += gridDim.y) { T_ACC sum_loss1 = T_ACC(0); T_ACC sum_loss2 = T_ACC(0); - T_ACC c_mean = mean[i1]; + T_ACC c_mean = 0; + if constexpr (!rms_norm){ + c_mean = mean[i1]; + } const T_ACC c_rstd = rstd[i1]; const T* k_input = input + i1*N; const T* k_dout = dout + i1*N; @@ -1267,21 +1401,31 @@ void cuComputeGradInput( const T_ACC gamma_idx = static_cast((idx((idx((idx((idx((idx 0; mask /= 2) { - sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask); + if constexpr (!rms_norm){ + sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask); + } sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask); } // inter-warp reductions @@ -1292,25 +1436,33 @@ void cuComputeGradInput( // upper half of warps write to shared if (threadIdx.y >= offset && threadIdx.y < 2*offset) { const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x; - buf[2*wrt_i] = sum_loss1; + if constexpr (!rms_norm){ + buf[2*wrt_i] = sum_loss1; + } buf[2*wrt_i+1] = sum_loss2; } __syncthreads(); // lower half merges if (threadIdx.y < offset) { const int read_i = threadIdx.y * blockDim.x + threadIdx.x; - sum_loss1 += buf[2*read_i]; + if constexpr (!rms_norm){ + sum_loss1 += buf[2*read_i]; + } sum_loss2 += buf[2*read_i+1]; } __syncthreads(); } if (threadIdx.y == 0) { - buf[2*threadIdx.x] = sum_loss1; + if constexpr (!rms_norm){ + buf[2*threadIdx.x] = sum_loss1; + } buf[2*threadIdx.x+1] = sum_loss2; } __syncthreads(); if (threadIdx.y !=0) { - sum_loss1 = buf[2*threadIdx.x]; + if constexpr (!rms_norm){ + sum_loss1 = buf[2*threadIdx.x]; + } sum_loss2 = buf[2*threadIdx.x+1]; } } @@ -1323,8 +1475,12 @@ void cuComputeGradInput( const T_ACC c_h = static_cast(k_input[l]); const T_ACC c_loss = static_cast(k_dout[l]); T_ACC f_grad_input = fH * c_loss * gamma[l]; - f_grad_input -= sum_loss1; - f_grad_input -= (c_h - c_mean) * c_rstd * sum_loss2; + if constexpr (!rms_norm){ + f_grad_input -= sum_loss1; + f_grad_input -= (c_h - c_mean) * c_rstd * sum_loss2; + } else { + f_grad_input -= (c_h) * c_rstd * sum_loss2; + } f_grad_input *= term1; k_grad_input[l] = static_cast(f_grad_input); } @@ -1333,8 +1489,12 @@ void cuComputeGradInput( const T_ACC c_h = static_cast(k_input[l]); const T_ACC c_loss = static_cast(k_dout[l]); T_ACC f_grad_input = fH * c_loss; - f_grad_input -= sum_loss1; - f_grad_input -= (c_h - c_mean) * c_rstd * sum_loss2; + if constexpr (!rms_norm){ + f_grad_input -= sum_loss1; + f_grad_input -= (c_h - c_mean) * c_rstd * sum_loss2; + } else { + f_grad_input -= (c_h) * c_rstd * sum_loss2; + } f_grad_input *= term1; k_grad_input[l] = static_cast(f_grad_input); } @@ -1344,7 +1504,7 @@ void cuComputeGradInput( } } -template +template void LayerNormBackwardKernelImplInternal( const Tensor& dY, const Tensor& X, @@ -1358,7 +1518,9 @@ void LayerNormBackwardKernelImplInternal( Tensor* dbeta) { using T_ACC = acc_type; TORCH_CHECK(dY.numel() == M * N); - TORCH_CHECK(mean.numel() == M); + if constexpr (!rms_norm){ + TORCH_CHECK(mean.numel() == M); + } TORCH_CHECK(rstd.numel() == M); TORCH_CHECK(M <= at::cuda::getCurrentDeviceProperties()->maxGridSize[0], "M should be less than maximum CUDA grid size, \ file a support request to support bigger batches"); @@ -1384,7 +1546,7 @@ void LayerNormBackwardKernelImplInternal( threads1.y > 1 ? threads1.y*threads1.x*sizeof(T_ACC) : 0; - cuComputeGradInput<<>>( + cuComputeGradInput<<>>( dY_data, X_data, M, N, @@ -1396,7 +1558,7 @@ void LayerNormBackwardKernelImplInternal( } else { const dim3 blocks(M); int nshared = (num_threads()/warp_size) * sizeof(T_ACC); - layer_norm_grad_input_kernel<<>>(dY_data, + layer_norm_grad_input_kernel<<>>(dY_data, X_data, mean_data, rstd_data, gamma_data, dX_data, N); C10_CUDA_KERNEL_LAUNCH_CHECK(); } @@ -1410,13 +1572,12 @@ void LayerNormBackwardKernelImplInternal( const unsigned int alignment = sizeof(T) * vec_size; bool bAlignedBuffers = can_vectorize(dY_data, alignment) && can_vectorize(X_data, alignment) && can_vectorize(gamma_data, alignment) && can_vectorize(dX_data, alignment); - if (bAlignedBuffers && bTargetDataTypes && bVectorSizeMultiple) { - layer_norm_grad_input_kernel_vectorized<<>>(dY_data, + layer_norm_grad_input_kernel_vectorized<<>>(dY_data, X_data, mean_data, rstd_data, gamma_data, dX_data, N); C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { - layer_norm_grad_input_kernel<<>>(dY_data, + layer_norm_grad_input_kernel<<>>(dY_data, X_data, mean_data, rstd_data, gamma_data, dX_data, N); C10_CUDA_KERNEL_LAUNCH_CHECK(); } @@ -1432,7 +1593,7 @@ void LayerNormBackwardKernelImplInternal( if (M < 128) { // For small batch size, do colwise reduce directly. const int64_t B = (N + kCUDANumThreads - 1) / kCUDANumThreads; - GammaBetaBackwardSimpleCUDAKernel + GammaBetaBackwardSimpleCUDAKernel <<>>( M, N, @@ -1456,7 +1617,7 @@ void LayerNormBackwardKernelImplInternal( Tensor part_grad_gamma = at::empty({part_size,N}, gamma.options().dtype(part_grad_dtype)); Tensor part_grad_beta = at::native::empty_like(part_grad_gamma); - cuComputePartGradGammaBeta<<>>( + cuComputePartGradGammaBeta<<>>( dY_data, X_data, M,N, @@ -1470,7 +1631,7 @@ void LayerNormBackwardKernelImplInternal( const dim3 blocks3((N + threads3.x - 1) / threads3.x, 1, 1); const int nshared3 = threads3.x * threads3.y * sizeof(T_ACC); - cuComputeGradGammaBeta<<>>( + cuComputeGradGammaBeta<<>>( part_grad_gamma.template data_ptr(), part_grad_beta.template data_ptr(), part_size, @@ -1480,7 +1641,7 @@ void LayerNormBackwardKernelImplInternal( C10_CUDA_KERNEL_LAUNCH_CHECK(); } #else - LaunchGammaBetaBackwardCUDAKernel( + LaunchGammaBetaBackwardCUDAKernel( dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); #endif } @@ -1508,8 +1669,29 @@ void LayerNormBackwardKernelImpl( }); } +void RMSNormBackwardKernelImpl( + const Tensor& dY, + const Tensor& X, + const Tensor& rstd, + const Tensor& gamma, + int64_t M, + int64_t N, + Tensor* dX, + Tensor* dgamma) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + X.scalar_type(), + "LayerNormBackwardKernelImpl", + [&]() { + LayerNormBackwardKernelImplInternal( + dY.contiguous(), X, rstd, rstd, gamma, M, N, dX, dgamma, dgamma); + }); +} + } // namespace + std::tuple layer_norm_cuda( const Tensor& input, IntArrayRef normalized_shape, @@ -1638,6 +1820,108 @@ std::tuple layer_norm_backward_cuda( return std::make_tuple(std::move(dX), std::move(dgamma), std::move(dbeta)); } +/* RMSNorm is implemented by reusing layer_norm's kernels */ +std::tuple _fused_rms_norm_cuda( + const Tensor& input, + IntArrayRef normalized_shape, + const std::optional& weight_opt /* optional */, + std::optional eps){ + + c10::MaybeOwned weight_maybe_owned = + at::borrow_from_optional_tensor(weight_opt); + const Tensor& weight = *weight_maybe_owned; + auto M_N = _check_layer_norm_inputs(input, normalized_shape, weight, weight); + auto M = M_N.first; + auto N = M_N.second; + auto X = input.expect_contiguous(); + auto gamma = weight.expect_contiguous(); + + double eps_val = eps.value_or(std::numeric_limits::epsilon()); + + Tensor Y = at::native::empty_like( + *X, + std::nullopt /* dtype */, + std::nullopt /* layout */, + std::nullopt /* device */, + std::nullopt /* pin_memory */, + LEGACY_CONTIGUOUS_MEMORY_FORMAT); + auto acc_type = at::toAccumulateType(input.scalar_type(), /*is_cuda=*/true); + Tensor rstd = at::empty({M}, X->options().dtype(acc_type)); + + if (M > 0) { + RmsNormKernelImpl(*X, *gamma, M, N, eps_val, &Y, &rstd); + } + + const auto input_shape = input.sizes(); + const size_t axis = input.dim() - normalized_shape.size(); + + std::vector stat_shape; + for (const auto idx: c10::irange(axis)) { + stat_shape.push_back(input_shape[idx]); + } + for ([[maybe_unused]] const auto idx : c10::irange(axis, input.dim())) { + stat_shape.push_back(1); + } + + rstd = rstd.view(stat_shape); + + return std::make_tuple(std::move(Y), std::move(rstd)); +} + + +std::tuple _fused_rms_norm_backward_cuda( + const Tensor& dY, + const Tensor& input, + IntArrayRef normalized_shape, + const Tensor& rstd, + const std::optional& weight_opt /* optional */, + std::array grad_input_mask) { + + c10::MaybeOwned weight_maybe_owned = + at::borrow_from_optional_tensor(weight_opt); + const Tensor& weight = *weight_maybe_owned; + + auto M_N = _check_layer_norm_inputs(input, normalized_shape, weight, weight); + auto M = M_N.first; + auto N = M_N.second; + auto X = input.expect_contiguous(); + auto gamma = weight.expect_contiguous(); + + Tensor dX; + Tensor dgamma; + if (grad_input_mask[0]) { + dX = at::native::empty_like( + *X, + std::nullopt /* dtype */, + std::nullopt /* layout */, + std::nullopt /* device */, + std::nullopt /* pin_memory */, + LEGACY_CONTIGUOUS_MEMORY_FORMAT); + } + if (grad_input_mask[1]) { + dgamma = M > 0 ? at::native::empty_like( + *gamma, + std::nullopt /* dtype */, + std::nullopt /* layout */, + std::nullopt /* device */, + std::nullopt /* pin_memory */, + LEGACY_CONTIGUOUS_MEMORY_FORMAT) + : at::native::zeros_like( + *gamma, + std::nullopt /* dtype */, + std::nullopt /* layout */, + std::nullopt /* device */, + std::nullopt /* pin_memory */, + LEGACY_CONTIGUOUS_MEMORY_FORMAT); + } + + if (M > 0 && N > 0) { + RMSNormBackwardKernelImpl( + dY, *X, rstd, *gamma, M, N, &dX, &dgamma); + } + return std::make_tuple(std::move(dX), std::move(dgamma)); +} + REGISTER_DISPATCH(LayerNormKernel, &LayerNormKernelImpl) REGISTER_DISPATCH(LayerNormBackwardKernel, &LayerNormBackwardKernelImpl) diff --git a/aten/src/ATen/native/layer_norm.cpp b/aten/src/ATen/native/layer_norm.cpp index da6bb5fec39e8..207f092a676a7 100644 --- a/aten/src/ATen/native/layer_norm.cpp +++ b/aten/src/ATen/native/layer_norm.cpp @@ -261,30 +261,11 @@ std::tuple math_native_layer_norm( return outputs; } -Tensor rms_norm_symint( +std::tuple rms_norm_composite( const Tensor& input, - c10::SymIntArrayRef normalized_shape, + IntArrayRef normalized_shape, const std::optional& weight_opt /* optional */, std::optional eps) { - // See [Note: hacky wrapper removal for optional tensor] - c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); - const Tensor& weight = *weight_maybe_owned; - _check_rms_norm_inputs_symint(input, normalized_shape, weight); - -#ifdef USE_MPS - if (input.device().type() == DeviceType::MPS && weight_opt.has_value()) { - const Tensor weight = weight_opt.value(); - const bool any_nested = input.is_nested() || weight.is_nested(); - const bool any_inputs_require_grad = input.requires_grad() || weight.requires_grad(); - const bool is_input_fp = isFloatingType(input.scalar_type()); - const bool is_weight_fp = isFloatingType(weight.scalar_type()); - - if (!(GradMode::is_enabled() && any_inputs_require_grad) && !any_nested && is_input_fp && is_weight_fp) { - auto eps_val = eps.value_or(std::numeric_limits::epsilon()); - return at::_fused_rms_norm(input.contiguous(), normalized_shape.size(), weight.contiguous(), eps_val); - } - } -#endif std::vector dims_to_reduce; for (const auto i : c10::irange(normalized_shape.size())) { @@ -321,10 +302,60 @@ Tensor rms_norm_symint( upcasted_result = upcasted_result.mul(weight_opt.value()); } - return upcasted_result; + // if nested do not make contiguous + if(input.is_nested() || (weight_opt.has_value() && weight_opt.value().is_nested())){ + return std::make_tuple(upcasted_result, rqrst_input); + } + + if(input.suggest_memory_format() == c10::MemoryFormat::ChannelsLast || input.suggest_memory_format() == c10::MemoryFormat::ChannelsLast3d){ + return std::make_tuple(upcasted_result, rqrst_input); + } + + return std::make_tuple(upcasted_result.contiguous(), rqrst_input.contiguous()); }); + return std::make_tuple( + std::get<0>(result).type_as(input), // Cast normalized result to original input type + std::get<1>(result) // rsqrt_val + ); +} + - return result.type_as(input); +Tensor rms_norm_symint( + const Tensor& input, + c10::SymIntArrayRef normalized_shape, + const std::optional& weight_opt /* optional */, + const std::optional eps) { + + c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); + const Tensor& weight = *weight_maybe_owned; + _check_rms_norm_inputs_symint(input, normalized_shape, weight); + + // composite fallback for channels last + if(input.suggest_memory_format() == c10::MemoryFormat::ChannelsLast || input.suggest_memory_format() == c10::MemoryFormat::ChannelsLast3d){ + return std::get<0>(rms_norm_composite(input, IntArrayRef(reinterpret_cast(normalized_shape.data()), normalized_shape.size()), weight_opt, eps)); + } + + // composite fallback for complex datatypes + if(input.is_complex()){ + return std::get<0>(rms_norm_composite(input, IntArrayRef(reinterpret_cast(normalized_shape.data()), normalized_shape.size()), weight_opt, eps)); + } + + #ifdef USE_MPS + if (input.device().type() == DeviceType::MPS && weight_opt.has_value()) { + const Tensor weight = weight_opt.value(); + const bool any_inputs_require_grad = input.requires_grad() || weight.requires_grad(); + if (!(GradMode::is_enabled() && any_inputs_require_grad)) { + return std::get<0>(at::_fused_rms_norm(input.contiguous(), IntArrayRef(reinterpret_cast(normalized_shape.data()), normalized_shape.size()), weight_opt, eps)); + } + } + + if (input.device().type() == DeviceType::MPS){ + return std::get<0>(rms_norm_composite(input, IntArrayRef(reinterpret_cast(normalized_shape.data()), normalized_shape.size()), weight_opt, eps)); + } + #endif + + return std::get<0>(at::_fused_rms_norm(input, IntArrayRef(reinterpret_cast(normalized_shape.data()), normalized_shape.size()), weight_opt, eps)); } + } // namespace at::native diff --git a/aten/src/ATen/native/layer_norm.h b/aten/src/ATen/native/layer_norm.h index 0181f35fd6ed4..0debe942dd0a6 100644 --- a/aten/src/ATen/native/layer_norm.h +++ b/aten/src/ATen/native/layer_norm.h @@ -106,6 +106,12 @@ void layer_norm_cpu_out( int64_t M, int64_t N); +std::tuple rms_norm_composite( + const Tensor& input, + IntArrayRef normalized_shape, + const std::optional& weight_opt /* optional */, + std::optional eps); + Tensor rms_norm_symint( const Tensor& input, c10::SymIntArrayRef normalized_shape, diff --git a/aten/src/ATen/native/mps/operations/RMSNorm.mm b/aten/src/ATen/native/mps/operations/RMSNorm.mm index 71128297d5bfc..7948b5acd8e93 100644 --- a/aten/src/ATen/native/mps/operations/RMSNorm.mm +++ b/aten/src/ATen/native/mps/operations/RMSNorm.mm @@ -19,7 +19,14 @@ #include #endif -Tensor _fused_rms_norm_mps(const Tensor& input, const int64_t normalized_ndim, const Tensor& weight, const double eps) { +std::tuple _fused_rms_norm_mps(const Tensor& input, + IntArrayRef normalized_shape, + const std::optional& weight_opt, + const std::optional eps) { + const Tensor weight = weight_opt.value().contiguous(); + const int64_t normalized_ndim = normalized_shape.size(); + auto eps_val = eps.value_or(std::numeric_limits::epsilon()); + TORCH_CHECK(input.is_contiguous() && weight.is_contiguous(), "Expected contiguous input and weight tensors"); auto output = at::empty_like(input); const auto input_shape = input.sizes(); @@ -41,7 +48,7 @@ Tensor _fused_rms_norm_mps(const Tensor& input, const int64_t normalized_ndim, c const std::string kernel = fmt::format("{}_{}", name, scalarToMetalTypeString(output)); id rms_norm_pso = lib.getPipelineStateForFunc(kernel); [computeEncoder setComputePipelineState:rms_norm_pso]; - mtl_setArgs(computeEncoder, input, weight, output, eps, N, 1); + mtl_setArgs(computeEncoder, input, weight, output, eps_val, N, 1); const auto maxThreadsPerGroup = static_cast([rms_norm_pso maxTotalThreadsPerThreadgroup]); size_t threadgroup_size = maxThreadsPerGroup; @@ -58,7 +65,7 @@ Tensor _fused_rms_norm_mps(const Tensor& input, const int64_t normalized_ndim, c } }); - return output; + return std::make_tuple(output, Tensor()); } } // namespace at::native diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 79b7e07e2284b..ce13e03fb9f6c 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -3314,9 +3314,15 @@ dispatch: CompositeImplicitAutograd: rms_norm_symint -- func: _fused_rms_norm(Tensor input, int normalized_shape_ndim, Tensor weight, float eps) -> Tensor +- func: _fused_rms_norm(Tensor input, int[] normalized_shape, Tensor? weight, float? eps) -> (Tensor, Tensor) dispatch: + CUDA: _fused_rms_norm_cuda MPS: _fused_rms_norm_mps + CompositeImplicitAutograd: rms_norm_composite + +- func: _fused_rms_norm_backward(Tensor grad_out, Tensor input, int[] normalized_shape, Tensor rstd, Tensor? weight, bool[2] output_mask) -> (Tensor, Tensor) + dispatch: + CUDA: _fused_rms_norm_backward_cuda - func: nan_to_num(Tensor self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor variants: function, method diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index 042959c22cd4a..a590713ad0f83 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -374,7 +374,6 @@ aten::_fused_adamw_.tensor_lr aten::_fused_moving_avg_obs_fq_helper aten::_fused_moving_avg_obs_fq_helper.out aten::_fused_moving_avg_obs_fq_helper_functional -aten::_fused_rms_norm aten::_fused_sdp_choice aten::_fused_sgd aten::_fused_sgd.out diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py index d6cf2df4343ff..5a962dfa57c05 100644 --- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py +++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py @@ -139,6 +139,8 @@ # These ops are defined in torch/csrc/distributed/c10d/Ops.cpp # TODO: add back restriction when c10d ops can be exported ("c10d::.*", datetime.date(9999, 1, 1)), + # Previously MPS_only did not support backward + ("aten::_fused_rms_norm", datetime.date(2025, 12, 30)), ] ALLOW_LIST_COMPILED = [ diff --git a/test/test_decomp.py b/test/test_decomp.py index 5d641e32e422e..dcd6e69af997c 100644 --- a/test/test_decomp.py +++ b/test/test_decomp.py @@ -15,7 +15,7 @@ from torch._export.utils import _is_cia_op from torch._ops import DispatchKey from torch.testing import make_tensor -from torch.testing._internal.common_cuda import tf32_off +from torch.testing._internal.common_cuda import SM70OrLater, tf32_off from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, onlyCPU, @@ -1226,6 +1226,33 @@ def f(x, w, b): for o_ref, o in zip(out_ref, out): self.assertEqual(o_ref.dtype, o.dtype) + @onlyCUDA + @unittest.skipIf(not SM70OrLater, "triton") + def test_rms_norm_decomp_cuda(self, device): + @torch.compile + def rms_norm_sinh(a, b, c): + output = torch.nn.functional.rms_norm(a, b, c) + return torch.sinh(output) + + normalized_shape_arg = (3, 3, 3) + input_tensor = torch.randn(3, 3, 3, device=device, requires_grad=True) + weight_tensor = torch.randn(3, 3, 3, device=device, requires_grad=True) + + def forward_pass_fn(): + return rms_norm_sinh(input_tensor, normalized_shape_arg, weight_tensor) + + model_output, generated_codes = torch._inductor.utils.run_fw_bw_and_get_code( + forward_pass_fn + ) + + # check RMSNorm was fused with sinh + self.assertTrue( + "triton_per_fused_add_mean_mul_pow_rsqrt_sinh" in generated_codes[0] + ) + self.assertTrue( + "triton_per_fused__fused_rms_norm_backward_cosh_mul" in generated_codes[1] + ) + instantiate_device_type_tests(DecompOneOffTests, globals()) diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index e2419aab268b1..f0349c2484b61 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -1267,6 +1267,11 @@ mean: not_implemented("native_layer_norm_backward mean") rstd: not_implemented("native_layer_norm_backward rstd") +- name: _fused_rms_norm(Tensor input, int[] normalized_shape, Tensor? weight, float? eps) -> (Tensor, Tensor) + input, weight: "GradMode::is_enabled() || grads[1].defined() ? infinitely_differentiable_native_rms_norm_backward(grads[0], grads[1], input, normalized_shape, result1, weight, grad_input_mask) : (grads[0].defined() ? _fused_rms_norm_backward(grads[0], input, normalized_shape, result1, weight, grad_input_mask) : std::tuple())" + result0: rms_norm_jvp(input_p, input_t, weight_p, weight_t, result1, normalized_shape) + result1: rms_norm_rstd_jvp(input_p, input_t, result1, normalized_shape) + - name: native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor) input, weight, bias: "GradMode::is_enabled() || grads[1].defined() || grads[2].defined() ? infinitely_differentiable_native_group_norm_backward(grads[0], grads[1], grads[2], input, result1, result2, weight, N, C, HxW, group, eps, grad_input_mask) : (grads[0].defined() ? native_group_norm_backward_symint(grads[0].device().is_xpu() ? grads[0] : grads[0].contiguous(grads[0].device().is_cpu() ? input.suggest_memory_format() : c10::MemoryFormat::Contiguous), input.device().is_xpu() ? input : input.contiguous(input.device().is_cpu() ? input.suggest_memory_format() : c10::MemoryFormat::Contiguous), result1, result2, weight, N, C, HxW, group, grad_input_mask) : std::tuple())" result0: group_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, result1, result2, group) diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py index abb94b109cc0c..8e9796d2f7c1b 100644 --- a/torch/_decomp/__init__.py +++ b/torch/_decomp/__init__.py @@ -418,6 +418,7 @@ def _core_aten_decompositions_post_autograd() -> dict[ aten.native_dropout_backward, aten.native_group_norm_backward, aten.native_layer_norm_backward, + aten._fused_rms_norm_backward, aten.new_empty, aten.new_full, aten.new_ones, diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index f93a0bf84fb4b..832928ebf8aee 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -1743,6 +1743,81 @@ def native_layer_norm_backward_out( return grad_input +@register_decomposition(aten._fused_rms_norm_backward.default) +def _fused_rms_norm_backward( + grad_out: Tensor, + input: Tensor, + normalized_shape: list[int], + rstd: Tensor, + weight: Optional[Tensor], + output_mask: list[bool], +) -> tuple[Optional[Tensor], Optional[Tensor]]: + input_shape = input.shape + input_ndim = input.dim() + computation_dtype = utils.get_computation_dtype(input.dtype) + + grad_out_cast = grad_out.to( + computation_dtype, memory_format=torch.contiguous_format + ) + input_cast = input.to(computation_dtype, memory_format=torch.contiguous_format) + weight_cast = ( + weight.to(computation_dtype, memory_format=torch.contiguous_format) + if weight is not None + else None + ) + assert grad_out_cast is not None + + axis = input_ndim - len(normalized_shape) + inner_dims = input_shape[axis:] + outer_dims = input_shape[:axis] + inner_dim_indices: list[int] = [] + outer_dim_indices: list[int] = [] + for i in range(input_ndim): + if i >= axis: + inner_dim_indices.append(i) + else: + outer_dim_indices.append(i) + + N = prod(inner_dims) # type: ignore[arg-type] + M = prod(outer_dims) # type: ignore[arg-type] + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + if guard_size_oblivious(M <= 0) or guard_size_oblivious(N <= 0): + return ( + input.new_zeros(input_shape) if output_mask[0] else None, + input.new_zeros(input_shape[axis:]) if output_mask[1] else None, + ) + + rstd = _unsqueeze_to_dim(rstd, input_cast.dim()) # type: ignore[union-attr] + if weight_cast is not None: + grad_x_hat = grad_out_cast * weight_cast + else: + grad_x_hat = grad_out_cast + + d_input: Optional[Tensor] = None + d_weight: Optional[Tensor] = None + + x_hat = input_cast * rstd + + if output_mask[0]: + sum_val = torch.sum(x_hat * grad_x_hat, dim=inner_dim_indices, keepdim=True) + d_input = (grad_x_hat - (x_hat / N) * sum_val) * rstd + + if output_mask[1] and weight_cast is not None: + d_weight_full_shape = grad_out_cast * x_hat + if len(outer_dim_indices) > 0: + d_weight = torch.sum( + d_weight_full_shape, dim=outer_dim_indices, keepdim=False + ) + else: + d_weight = d_weight_full_shape + + return ( + _maybe_cast(d_input, input.dtype), + _maybe_cast(d_weight, input.dtype), + ) + + def native_batch_norm_helper( input: Tensor, weight: Optional[Tensor], diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index 908a980cfee9c..8e13d4267edb5 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -5023,6 +5023,103 @@ std::tuple layer_norm_double_backward( return std::tuple{gI, gG, ggO}; } +std::tuple infinitely_differentiable_native_rms_norm_backward( + const Tensor& dY, + const Tensor& drstd, + const Tensor& input, + IntArrayRef normalized_shape, + const Tensor& rstd, + const std::optional& weight_opt, + std::array grad_input_mask) { + c10::MaybeOwned weight_maybe_owned = + at::borrow_from_optional_tensor(weight_opt); + const Tensor& weight = *weight_maybe_owned; + + const auto input_shape = input.sizes(); + const auto input_ndim = input.dim(); + const int normalized_ndim = normalized_shape.size(); + const int axis = input_ndim - normalized_ndim; + + int64_t N_rms = 1; + for (int i = 0; i < normalized_ndim; ++i) { + N_rms *= input_shape[axis + i]; + } + + Tensor dX; + Tensor dgamma; + + std::vector rstd_view_shape = rstd.sizes().vec(); + for (int i = 0; + i < std::max(static_cast(normalized_ndim - rstd.dim()), 0); + ++i) { + rstd_view_shape.push_back(1); + } + Tensor rstd_broadcast = rstd.view(rstd_view_shape); + Tensor rstd_pow3 = rstd_broadcast.pow(3); + Tensor grad_x_hat; + + if (dY.defined()) { + if (weight.defined()) { + grad_x_hat = dY * weight; + } else { + grad_x_hat = dY; + } + } + + if (grad_input_mask[0]) { + Tensor dX_from_dY_path; + Tensor dX_from_drstd_path; + + std::vector inner_sum_dims; + inner_sum_dims.reserve(normalized_ndim); + for (int i = 0; i < normalized_ndim; ++i) { + inner_sum_dims.push_back(axis + i); + } + + if (dY.defined() && grad_x_hat.defined()) { + Tensor sum_input_times_grad_x_hat = + sum(input * grad_x_hat, inner_sum_dims, /*keepdim=*/true); + dX_from_dY_path = rstd_broadcast * grad_x_hat - + (input * rstd_pow3 / static_cast(N_rms)) * + sum_input_times_grad_x_hat; + } + + if (drstd.defined()) { + Tensor drstd_broadcast = drstd.view(rstd_view_shape); + dX_from_drstd_path = + -(input * rstd_pow3 / static_cast(N_rms)) * drstd_broadcast; + } + + if (dX_from_dY_path.defined() && dX_from_drstd_path.defined()) { + dX = dX_from_dY_path + dX_from_drstd_path; + } else if (dX_from_dY_path.defined()) { + dX = dX_from_dY_path; + } else if (dX_from_drstd_path.defined()) { + dX = dX_from_drstd_path; + } + } + + if (grad_input_mask[1] && weight.defined()) { + if (dY.defined()) { + Tensor x_hat = input * rstd_broadcast; + Tensor dgamma_full_shape = dY * x_hat; + + if (axis > 0) { + std::vector outer_sum_dims; + outer_sum_dims.reserve(axis); + for (int i = 0; i < axis; ++i) { + outer_sum_dims.push_back(i); + } + dgamma = sum(dgamma_full_shape, outer_sum_dims, /*keepdim=*/false); + } else { + dgamma = dgamma_full_shape; + } + } + } + + return std::make_tuple(dX, dgamma); +} + std::tuple infinitely_differentiable_native_group_norm_backward( const Tensor& dY, @@ -6377,6 +6474,98 @@ Tensor layer_norm_jvp( bias_t.defined() ? bias_t.view(view_size_affine) : bias_t); } +Tensor rms_norm_jvp( + const Tensor& input_p, + const Tensor& input_t, + const Tensor& weight_p, + const Tensor& weight_t, + const Tensor& saved_rstd, + IntArrayRef normalized_shape) { + auto dims = std::vector{}; + auto view_size = input_t.sizes().vec(); + auto view_size_affine = input_t.sizes().vec(); + + int64_t numel = 1; + for (const auto i : c10::irange(view_size.size())) { + if (i < view_size.size() - normalized_shape.size()) { + view_size_affine[i] = 1; + } else { + numel *= input_t.size(static_cast(i)); + view_size[i] = 1; + dims.push_back(static_cast(i)); + } + } + + auto rstd_p = saved_rstd.view(view_size); + + Tensor rstd_t; + if (areAnyTensorSubclassLike({input_t, input_p, rstd_p}) || + input_t._is_zerotensor()) { + rstd_t = -rstd_p.pow(3) * (input_t) * (input_p); + } else { + rstd_t = input_t * input_p; + rstd_t *= -rstd_p.pow(3); + } + rstd_t = rstd_t.sum(dims, true); + rstd_t /= numel; + + Tensor result_t; + if (areAnyTensorSubclassLike({input_t, input_p, rstd_p}) || + input_t._is_zerotensor()) { + result_t = (input_t)*rstd_p + (input_p)*rstd_t; + } else { + result_t = input_t * rstd_p; + auto temp = input_p * rstd_t; + result_t += temp; + } + + std::optional result_p = std::nullopt; + if (weight_p.defined()) { + result_p = std::optional(input_p * rstd_p); + } + + return _affine_jvp( + result_p, + result_t, + weight_p.defined() ? weight_p.view(view_size_affine) : weight_p, + weight_t.defined() ? weight_t.view(view_size_affine) : weight_t, + Tensor()); +} + +Tensor rms_norm_rstd_jvp( + const Tensor& input_p, + const Tensor& input_t, + const Tensor& saved_rstd, + IntArrayRef normalized_shape) { + auto dims = std::vector{}; + auto view_size = input_t.sizes().vec(); + auto view_size_affine = input_t.sizes().vec(); + + int64_t numel = 1; + for (const auto i : c10::irange(view_size.size())) { + if (i < view_size.size() - normalized_shape.size()) { + view_size_affine[i] = 1; + } else { + numel *= input_t.size(static_cast(i)); + view_size[i] = 1; + dims.push_back(static_cast(i)); + } + } + + auto rstd_p = saved_rstd.view(view_size); + Tensor rstd_t; + if (areAnyTensorSubclassLike({input_t, input_p, rstd_p}) || + input_t._is_zerotensor()) { + rstd_t = -rstd_p.pow(3) * (input_t) * (input_p); + } else { + rstd_t = input_t * input_p; + rstd_t *= -rstd_p.pow(3); + } + rstd_t = rstd_t.sum(dims, true); + rstd_t /= numel; + return rstd_t; +} + Tensor group_norm_jvp( const Tensor& input_p, const Tensor& input_t, diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h index 0b659973ec345..96864e165a95a 100644 --- a/torch/csrc/autograd/FunctionsManual.h +++ b/torch/csrc/autograd/FunctionsManual.h @@ -826,6 +826,15 @@ std::tuple layer_norm_double_backward( c10::SymIntArrayRef normalized_shape, std::array output_mask); +std::tuple infinitely_differentiable_native_rms_norm_backward( + const Tensor& dY, + const Tensor& drstd, + const Tensor& input, + IntArrayRef normalized_shape, + const Tensor& rstd, + const std::optional& weight_opt, + std::array grad_input_mask); + std::tuple householder_product_backward( const Tensor& grad, const Tensor& result, @@ -965,6 +974,20 @@ Tensor layer_norm_jvp( const Tensor& saved_invstd, c10::SymIntArrayRef normalized_shape); +Tensor rms_norm_jvp( + const Tensor& input_p, + const Tensor& input_t, + const Tensor& weight_p, + const Tensor& weight_t, + const Tensor& saved_rstd, + IntArrayRef normalized_shape); + +Tensor rms_norm_rstd_jvp( + const Tensor& input_p, + const Tensor& input_t, + const Tensor& saved_rstd, + IntArrayRef normalized_shape); + Tensor group_norm_jvp( const Tensor& input_p, const Tensor& input_t, diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h index 2aa09cb802ecd..aced2b2f539de 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h @@ -29,6 +29,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__fft_c2c(AtenTensorHandle self, AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__fft_r2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t onesided, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__fused_moving_avg_obs_fq_helper(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__fused_rms_norm(AtenTensorHandle input, const int64_t* normalized_shape, int64_t normalized_shape_len_, AtenTensorHandle* weight, double* eps, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__histogramdd_from_bin_cts(AtenTensorHandle self, const int64_t* bins, int64_t bins_len_, const double** range, int64_t range_len_, AtenTensorHandle* weight, int32_t density, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__int_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__pdist_backward(AtenTensorHandle grad, AtenTensorHandle self, double p, AtenTensorHandle pdist, AtenTensorHandle* ret0); diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h index e0607f984b3d0..92d30ded855f8 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h @@ -32,6 +32,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__flash_attention_backward(AtenT AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__flash_attention_forward(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* cum_seq_q, AtenTensorHandle* cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, int64_t* window_size_left, int64_t* window_size_right, AtenTensorHandle* seqused_k, AtenTensorHandle* alibi_slopes, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__fused_moving_avg_obs_fq_helper(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__fused_rms_norm(AtenTensorHandle input, const int64_t* normalized_shape, int64_t normalized_shape_len_, AtenTensorHandle* weight, double* eps, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__int_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__pdist_backward(AtenTensorHandle grad, AtenTensorHandle self, double p, AtenTensorHandle pdist, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__pdist_forward(AtenTensorHandle self, double p, AtenTensorHandle* ret0); diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_mps.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_mps.h index a5d654c518840..c76ee685c25da 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_mps.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_mps.h @@ -18,7 +18,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__efficientzerotensor(const int64 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fft_c2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t forward, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fft_r2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t onesided, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5); -AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fused_rms_norm(AtenTensorHandle input, int64_t normalized_shape_ndim, AtenTensorHandle weight, double eps, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fused_rms_norm(AtenTensorHandle input, const int64_t* normalized_shape, int64_t normalized_shape_len_, AtenTensorHandle* weight, double* eps, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__histogramdd_from_bin_cts(AtenTensorHandle self, const int64_t* bins, int64_t bins_len_, const double** range, int64_t range_len_, AtenTensorHandle* weight, int32_t density, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_mask, double dropout_p, int32_t is_causal, AtenTensorHandle* dropout_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8); diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h index 243bfb5fc87aa..6fc51bd0c8f8d 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h @@ -13,6 +13,7 @@ extern "C" { AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__addmm_activation(AtenTensorHandle self, AtenTensorHandle mat1, AtenTensorHandle mat2, double beta, double alpha, int32_t use_gelu, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__fused_rms_norm(AtenTensorHandle input, const int64_t* normalized_shape, int64_t normalized_shape_len_, AtenTensorHandle* weight, double* eps, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0); diff --git a/torch/overrides.py b/torch/overrides.py index f29ffe57e36a6..2e696b2d96e4d 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -820,6 +820,7 @@ def get_testing_overrides() -> dict[Callable, Callable]: torch._native_batch_norm_legit: lambda input, weight, bias, training, momentum, eps: -1, torch.native_dropout: lambda input, p, train: -1, torch.native_layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05: -1, + torch._fused_rms_norm: lambda input, normalized_shape, weight=None, eps=1e-05: -1, torch.native_group_norm: lambda input, weight, bias, N, C, HxW, group, eps: -1, torch.native_norm: lambda input, p=2, dim=None, keepdim=False, dtype=None: -1, torch.native_channel_shuffle: lambda input, groups: -1, From 36bddcd18c3f42b8fd1f4547adaa5c10a2acb6fc Mon Sep 17 00:00:00 2001 From: Will Constable Date: Fri, 18 Jul 2025 10:37:19 -0700 Subject: [PATCH 276/457] [DTensor] Fix default_strategy and rename for clarity (#158490) Fixes several bugs in the original. - foremost, fixes a serious bug where we returned incorrect strategies by mixing input_specs that were frozen from select_strategy.strategies[0] with output_specs that varied across select_strategy.strategies[0..N] (e.g. we could create a nonsense strategy like input:Shard(0) output(Replicate) for an op like clone - fixes the redistribute costs: they should not actually be 0, they should be the cost of redistributing our single input from another strategy to the current strategy, in our list of output strategies - adds a note, wondering if we should have just literally returned the input strategy instead of creating this new object - Currently, using default_strategy is incorrect becuase it maps 'self' tensor's strategies directly onto 'src' tensor without accounting for the fact that copy_ supports broadcasting a smaller rank tensor into a larger one. Separates out copy_ op from default strategy, adds missing test case, but does not fix the underlying issue with copy_, leaves that for future PR Renames to `propagate_single_input_strategy` since that's more descriptive Pull Request resolved: https://github.com/pytorch/pytorch/pull/158490 Approved by: https://github.com/wanchaol, https://github.com/XilunWu --- test/distributed/tensor/test_tensor_ops.py | 32 ++++++ torch/distributed/tensor/_ops/_tensor_ops.py | 109 +++++++++++++------ 2 files changed, 108 insertions(+), 33 deletions(-) diff --git a/test/distributed/tensor/test_tensor_ops.py b/test/distributed/tensor/test_tensor_ops.py index 9be582952f367..9140d2f5aae13 100644 --- a/test/distributed/tensor/test_tensor_ops.py +++ b/test/distributed/tensor/test_tensor_ops.py @@ -53,6 +53,38 @@ def test_clone(self): self.assertFalse(cloned_mat is mat) self.assertEqual(cloned_mat.to_local(), mat.to_local()) + @with_comms + def test_copy_(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + src_specs = [[Replicate()], [Shard(0)]] + src_tensor = torch.randn((12, 12)) + + dst_tensor = torch.zeros(12, 12) + dst_specs = [[Replicate()], [Shard(0)]] + for dst_spec, src_spec in zip(dst_specs, src_specs): + src_dtensor = distribute_tensor(src_tensor, device_mesh, dst_spec) + dst_dtensor = distribute_tensor(dst_tensor, device_mesh, src_spec) + dst_dtensor.copy_(src_dtensor) + dst_tensor.copy_(src_tensor) + self.assertEqual(dst_dtensor.full_tensor(), dst_tensor) + + # @pytest.mark.xfail + # @with_comms + # def test_copy_broadcast(self): + # device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + # src_specs = [[Replicate()], [Shard(0)]] + # src_tensor = torch.randn((12,)) + + # dst_tensor = torch.zeros(12, 12) + # dst_specs = [[Replicate()], [Shard(1)]] + # for dst_spec, src_spec in zip(dst_specs, src_specs): + # src_dtensor = distribute_tensor(src_tensor, device_mesh, dst_spec) + # dst_dtensor = distribute_tensor(dst_tensor, device_mesh, src_spec) + # # perform a broadcasted copy from Shard(0) to Shard(1) for the worst case + # dst_dtensor.copy_(src_dtensor) + # dst_tensor.copy_(src_tensor) + # self.assertEqual(dst_dtensor.full_tensor(), dst_tensor) + @with_comms def test_contiguous(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) diff --git a/torch/distributed/tensor/_ops/_tensor_ops.py b/torch/distributed/tensor/_ops/_tensor_ops.py index 9bdfc90d145d4..262631eef8e4e 100644 --- a/torch/distributed/tensor/_ops/_tensor_ops.py +++ b/torch/distributed/tensor/_ops/_tensor_ops.py @@ -39,55 +39,98 @@ aten = torch.ops.aten -def default_strategy(op_schema: OpSchema) -> StrategyType: - # Default strategy by default just propagate the first input strategy - select_strategy = op_schema.args_schema[0] - assert isinstance(select_strategy, OpStrategy) - # we create new DTensorSpecs even for default strategy to assure that - # the tensor metas are distinct between the arguments and outputs - input_specs = [] - redistribute_cost = [] - for i in op_schema.args_schema: - input_specs.append( - DTensorSpec( - mesh=select_strategy.mesh, - placements=select_strategy.strategies[0].output_spec.placements, - tensor_meta=select_strategy.strategies[0].output_spec.tensor_meta, +def propagate_single_input_strategy(op_schema: OpSchema) -> StrategyType: + # For ops with a single tensor input, we perform a 1:1 mapping such that + # for each strategy that the input supports, we create a corresponding strategy. + # Note: this may be a complete waste of work, because it should be equivalent to + # `return first_input_strategy` (unless creating a deep copy is important for some reason) + assert len([s for s in op_schema.args_schema if isinstance(s, OpStrategy)]) == 1, ( + "propagate_single_input_strategy only works for single-tensor-input ops" + ) + first_input_strategy = op_schema.args_schema[0] + assert isinstance(first_input_strategy, OpStrategy) + return OpStrategy( + [ + OpSpec( + output_specs=DTensorSpec( + mesh=first_input_strategy.mesh, + placements=strategy.output_spec.placements, + tensor_meta=strategy.output_spec.tensor_meta, + ), + input_specs=[ + DTensorSpec( + mesh=first_input_strategy.mesh, + placements=strategy.output_spec.placements, + tensor_meta=strategy.output_spec.tensor_meta, + ) + ], + redistribute_cost=[ + generate_redistribute_costs( + first_input_strategy, strategy.output_spec + ) + ], ) - ) - redistribute_cost.append([0.0] * len(select_strategy.strategies)) - - default_strategy = [ - OpSpec( - output_specs=DTensorSpec( - mesh=select_strategy.mesh, - placements=strategy.output_spec.placements, - tensor_meta=strategy.output_spec.tensor_meta, - ), - input_specs=input_specs, - redistribute_cost=redistribute_cost, - ) - for strategy in select_strategy.strategies - ] - return OpStrategy(default_strategy) + for strategy in first_input_strategy.strategies + ] + ) register_op_strategy( [ aten.clone.default, aten.contiguous.default, - aten.copy_.default, aten.detach.default, aten.fill_.Scalar, aten.view.dtype, aten.zero_.default, ] -)(default_strategy) +)(propagate_single_input_strategy) register_op_strategy( aten._to_copy.default, schema_info=RuntimeSchemaInfo(static_kwargkey=["dtype"]) -)(default_strategy) +)(propagate_single_input_strategy) + + +@register_op_strategy(aten.copy_.default) +def copy_strategy(op_schema: OpSchema) -> StrategyType: + # TODO: this strategy is incorrect for copy_ in the case that src tensor + # is smaller rank than self tensor. It is possible to select a strategy from self tensor + # that is invalid for dst tensor. + # It is also problematic to assume that shard(0) on src maps to shard(0) on self, since we + # may broadcast a new dim to the left or right of 0 when copying. + # + # For now, I just keep copy working essentially the way it was before this PR, + # but split it out so it can be handled separately in the future. + num_tensor_args = 2 + first_input_strategy = op_schema.args_schema[0] + assert isinstance(first_input_strategy, OpStrategy) + return OpStrategy( + [ + OpSpec( + output_specs=DTensorSpec( + mesh=first_input_strategy.mesh, + placements=strategy.output_spec.placements, + tensor_meta=strategy.output_spec.tensor_meta, + ), + input_specs=[ + DTensorSpec( + mesh=first_input_strategy.mesh, + placements=strategy.output_spec.placements, + tensor_meta=strategy.output_spec.tensor_meta, + ) + for _ in range(num_tensor_args) + ], + redistribute_cost=[ + generate_redistribute_costs( + first_input_strategy, strategy.output_spec + ) + for _ in range(num_tensor_args) + ], + ) + for strategy in first_input_strategy.strategies + ] + ) @register_op_strategy( From a3aacd6cb2ee581b7bdb45cdbe6f70695c04f4ab Mon Sep 17 00:00:00 2001 From: Will Constable Date: Fri, 18 Jul 2025 10:37:19 -0700 Subject: [PATCH 277/457] [DTensor] fix copy_ strategy (#158538) The previous strategy directly used 'self' input strategy for 'src' input. The fixed strategy correctly maps the self dim to src dim so that it works even if the src input is broadcast. E.g. for this program, broadcasting will occur on dims 0,1,3 of self. ``` self = torch.ones((2,3,4,5)) src = torch.ones((4,1)) self.copy_(src) ``` These are the correct sharding combinations: | self | src | |-------|------| | Shard(0) | Replicate() | | Shard(1) | Replicate() | | Shard(2) | Shard(0) | | Shard(3) | Shard(1) | Pull Request resolved: https://github.com/pytorch/pytorch/pull/158538 Approved by: https://github.com/zpcore, https://github.com/XilunWu, https://github.com/wanchaol ghstack dependencies: #158490 --- test/distributed/tensor/test_tensor_ops.py | 44 ++++++++------- torch/distributed/tensor/_ops/_tensor_ops.py | 56 ++++++-------------- 2 files changed, 42 insertions(+), 58 deletions(-) diff --git a/test/distributed/tensor/test_tensor_ops.py b/test/distributed/tensor/test_tensor_ops.py index 9140d2f5aae13..d62da27d43393 100644 --- a/test/distributed/tensor/test_tensor_ops.py +++ b/test/distributed/tensor/test_tensor_ops.py @@ -56,10 +56,11 @@ def test_clone(self): @with_comms def test_copy_(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) - src_specs = [[Replicate()], [Shard(0)]] - src_tensor = torch.randn((12, 12)) + # basic test + src_tensor = torch.randn((12, 12)) dst_tensor = torch.zeros(12, 12) + src_specs = [[Replicate()], [Shard(0)]] dst_specs = [[Replicate()], [Shard(0)]] for dst_spec, src_spec in zip(dst_specs, src_specs): src_dtensor = distribute_tensor(src_tensor, device_mesh, dst_spec) @@ -68,22 +69,29 @@ def test_copy_(self): dst_tensor.copy_(src_tensor) self.assertEqual(dst_dtensor.full_tensor(), dst_tensor) - # @pytest.mark.xfail - # @with_comms - # def test_copy_broadcast(self): - # device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) - # src_specs = [[Replicate()], [Shard(0)]] - # src_tensor = torch.randn((12,)) - - # dst_tensor = torch.zeros(12, 12) - # dst_specs = [[Replicate()], [Shard(1)]] - # for dst_spec, src_spec in zip(dst_specs, src_specs): - # src_dtensor = distribute_tensor(src_tensor, device_mesh, dst_spec) - # dst_dtensor = distribute_tensor(dst_tensor, device_mesh, src_spec) - # # perform a broadcasted copy from Shard(0) to Shard(1) for the worst case - # dst_dtensor.copy_(src_dtensor) - # dst_tensor.copy_(src_tensor) - # self.assertEqual(dst_dtensor.full_tensor(), dst_tensor) + # simple broadcasting + src_tensor = torch.randn((128,)) + dst_tensor = torch.zeros(128, 128) + src_specs = [[Replicate()], [Shard(0)]] + dst_specs = [[Replicate()], [Shard(1)]] + for dst_spec, src_spec in zip(dst_specs, src_specs): + src_dtensor = distribute_tensor(src_tensor, device_mesh, src_spec) + dst_dtensor = distribute_tensor(dst_tensor, device_mesh, dst_spec) + dst_dtensor.copy_(src_dtensor) + dst_tensor.copy_(src_tensor) + self.assertEqual(dst_dtensor.full_tensor(), dst_tensor) + + # The src specs in this case are designed to not be compatible with the dst_specs, redistribute should happen + src_tensor = torch.randn((64, 1)) + dst_tensor = torch.zeros(16, 32, 64, 128) + src_specs = [[Shard(1)], [Shard(1)], [Shard(1)], [Shard(1)]] + dst_specs = [[Replicate()], [Shard(0)], [Shard(1)], [Shard(2)]] + for dst_spec, src_spec in zip(dst_specs, src_specs): + src_dtensor = distribute_tensor(src_tensor, device_mesh, src_spec) + dst_dtensor = distribute_tensor(dst_tensor, device_mesh, dst_spec) + dst_dtensor.copy_(src_dtensor) + dst_tensor.copy_(src_tensor) + self.assertEqual(dst_dtensor.full_tensor(), dst_tensor) @with_comms def test_contiguous(self): diff --git a/torch/distributed/tensor/_ops/_tensor_ops.py b/torch/distributed/tensor/_ops/_tensor_ops.py index 262631eef8e4e..3b6b8c33cdbd8 100644 --- a/torch/distributed/tensor/_ops/_tensor_ops.py +++ b/torch/distributed/tensor/_ops/_tensor_ops.py @@ -35,6 +35,8 @@ Shard, ) +from ._pointwise_ops import pointwise_strategy + aten = torch.ops.aten @@ -91,46 +93,20 @@ def propagate_single_input_strategy(op_schema: OpSchema) -> StrategyType: aten._to_copy.default, schema_info=RuntimeSchemaInfo(static_kwargkey=["dtype"]) )(propagate_single_input_strategy) - -@register_op_strategy(aten.copy_.default) -def copy_strategy(op_schema: OpSchema) -> StrategyType: - # TODO: this strategy is incorrect for copy_ in the case that src tensor - # is smaller rank than self tensor. It is possible to select a strategy from self tensor - # that is invalid for dst tensor. - # It is also problematic to assume that shard(0) on src maps to shard(0) on self, since we - # may broadcast a new dim to the left or right of 0 when copying. - # - # For now, I just keep copy working essentially the way it was before this PR, - # but split it out so it can be handled separately in the future. - num_tensor_args = 2 - first_input_strategy = op_schema.args_schema[0] - assert isinstance(first_input_strategy, OpStrategy) - return OpStrategy( - [ - OpSpec( - output_specs=DTensorSpec( - mesh=first_input_strategy.mesh, - placements=strategy.output_spec.placements, - tensor_meta=strategy.output_spec.tensor_meta, - ), - input_specs=[ - DTensorSpec( - mesh=first_input_strategy.mesh, - placements=strategy.output_spec.placements, - tensor_meta=strategy.output_spec.tensor_meta, - ) - for _ in range(num_tensor_args) - ], - redistribute_cost=[ - generate_redistribute_costs( - first_input_strategy, strategy.output_spec - ) - for _ in range(num_tensor_args) - ], - ) - for strategy in first_input_strategy.strategies - ] - ) +# copy_ is actually a pointwise op with broadcasting, so reuse the pointwise strategy, which takes care of these +# requirements. +# +# Following torch broadcasting semantics (https://docs.pytorch.org/docs/stable/notes/broadcasting.html) +# - self can not change shape as a result of broadcasting since this is an inplace op +# - src can broadcast, but when it does it always does so from the trailing end +# e.g. the last dim of 'src' must match up with the last dim of 'self' +# +# DTensor semantics for inplace ops also dictates that we may NOT redistribute our 'self' input. +# In practice, what this means is +# - our output strategies should map 1:1 to our 'self' input strategies +# - our 'src' input may be redistributed to match up with the 'self' input, with the caveat of adjusting for +# broadcasting dim +register_op_strategy(aten.copy_.default)(pointwise_strategy) @register_op_strategy( From d42c40976727fed4c9908d4194f26917d0a3da66 Mon Sep 17 00:00:00 2001 From: Xu Han Date: Sat, 19 Jul 2025 00:06:36 +0000 Subject: [PATCH 278/457] [AOTI] windows package load dev (#158671) changes: 1. add extract file fail handler for Windows develop. 2. normalize more file paths. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158671 Approved by: https://github.com/angelayi --- .../inductor/aoti_package/model_package_loader.cpp | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/torch/csrc/inductor/aoti_package/model_package_loader.cpp b/torch/csrc/inductor/aoti_package/model_package_loader.cpp index 629dc8cb2ae80..127969c0318ff 100644 --- a/torch/csrc/inductor/aoti_package/model_package_loader.cpp +++ b/torch/csrc/inductor/aoti_package/model_package_loader.cpp @@ -471,7 +471,7 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( << found_filenames[1]; } - temp_dir_ = create_temp_dir(); + temp_dir_ = normalize_path_separator(create_temp_dir()); std::string so_filename; std::string cpp_filename; @@ -504,6 +504,8 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( .append(filename); } + output_path_str = normalize_path_separator(output_path_str); + LOG(INFO) << "Extract file: " << filename_str << " to " << output_path_str; @@ -522,8 +524,12 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( } // Extracts file to the temp directory - mz_zip_reader_extract_file_to_file( + mz_bool b_extract = mz_zip_reader_extract_file_to_file( &zip_archive, filename_str.c_str(), output_path_str.c_str(), 0); + if (b_extract == MZ_FALSE) { + throw std::runtime_error(fmt::format( + "Failed to extract file {} to {}", filename_str, output_path_str)); + } // Save the file for bookkeeping size_t extension_idx = output_path_str.find_last_of('.'); From 5b40f6581eac8a2e92af8dd986df7c22ad4584ce Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Sat, 19 Jul 2025 00:32:04 +0000 Subject: [PATCH 279/457] Revert "Add warning about removed sm50 and sm60 arches (#158301)" This reverts commit fb731fe371cb1b5bf95de84b19c213590526acb2. Reverted https://github.com/pytorch/pytorch/pull/158301 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/158301#issuecomment-3091307023)) --- torch/cuda/__init__.py | 56 +++++++++++------------------------------- 1 file changed, 15 insertions(+), 41 deletions(-) diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py index 6a8fc7dfb12ef..fd88e199a7a15 100644 --- a/torch/cuda/__init__.py +++ b/torch/cuda/__init__.py @@ -244,25 +244,21 @@ def _extract_arch_version(arch_string: str) -> int: def _check_capability(): - incompatible_gpu_warn = """ + incorrect_binary_warn = """ + Found GPU%d %s which requires CUDA_VERSION >= %d to + work properly, but your PyTorch was compiled + with CUDA_VERSION %d. Please install the correct PyTorch binary + using instructions from https://pytorch.org + """ # noqa: F841 + + old_gpu_warn = """ Found GPU%d %s which is of cuda capability %d.%d. - Minimum and Maximum cuda capability supported by this version of PyTorch is - (%d.%d) - (%d.%d) + PyTorch no longer supports this GPU because it is too old. + The minimum cuda capability supported by this library is %d.%d. """ - matched_cuda_warn = """ - Please install PyTorch with a following CUDA - configurations: {} following instructions at - https://pytorch.org/get-started/locally/ - """ - - # Binary CUDA_ARCHES SUPPORTED by PyTorch - CUDA_ARCHES_SUPPORTED = { - "12.6": {"min": 50, "max": 90}, - "12.8": {"min": 70, "max": 120}, - "12.9": {"min": 70, "max": 120}, - } if torch.version.cuda is not None: # on ROCm we don't want this check + CUDA_VERSION = torch._C._cuda_getCompiledVersion() # noqa: F841 for d in range(device_count()): capability = get_device_capability(d) major = capability[0] @@ -271,35 +267,13 @@ def _check_capability(): current_arch = major * 10 + minor min_arch = min( (_extract_arch_version(arch) for arch in torch.cuda.get_arch_list()), - default=50, + default=35, ) - max_arch = max( - (_extract_arch_version(arch) for arch in torch.cuda.get_arch_list()), - default=50, - ) - if current_arch < min_arch or current_arch > max_arch: + if current_arch < min_arch: warnings.warn( - incompatible_gpu_warn - % ( - d, - name, - major, - minor, - min_arch // 10, - min_arch % 10, - max_arch // 10, - max_arch % 10, - ) + old_gpu_warn + % (d, name, major, minor, min_arch // 10, min_arch % 10) ) - matched_arches = "" - for arch, arch_info in CUDA_ARCHES_SUPPORTED.items(): - if ( - current_arch >= arch_info["min"] - and current_arch <= arch_info["max"] - ): - matched_arches += f" {arch}" - if matched_arches != "": - warnings.warn(matched_cuda_warn.format(matched_arches)) def _check_cubins(): From c2c88846a9c75185660f3b2a8b72c3aa2f8ae3dc Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Sat, 19 Jul 2025 00:45:31 +0000 Subject: [PATCH 280/457] Revert "[Easy] Show some clear error when torch.ops.load_library fails. (#157524)" This reverts commit 555f3562541992b66a550eca8e8740884b1247f8. Reverted https://github.com/pytorch/pytorch/pull/157524 on behalf of https://github.com/wdvr due to reverting for now to reopen the discussion ([comment](https://github.com/pytorch/pytorch/pull/157524#issuecomment-3091317252)) --- torch/_ops.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/torch/_ops.py b/torch/_ops.py index e51343cff972c..4ccb9d7be6550 100644 --- a/torch/_ops.py +++ b/torch/_ops.py @@ -1483,10 +1483,7 @@ def load_library(self, path): # Import the shared library into the process, thus running its # static (global) initialization code in order to register custom # operators with the JIT. - try: - ctypes.CDLL(path) - except Exception as e: - raise RuntimeError(f"Could not load this library: {path}") from e + ctypes.CDLL(path) self.loaded_libraries.add(path) From 2c16eb9f3db0ba68520e5832d8bb6d3d875bdaeb Mon Sep 17 00:00:00 2001 From: Ryan Guo Date: Fri, 18 Jul 2025 09:56:29 -0700 Subject: [PATCH 281/457] [dynamo] Support more basic output types for `nonstrict_trace` (#157969) Fixes #157397 and improves the user-facing error message for remaining unsupported cases. Pull Request resolved: https://github.com/pytorch/pytorch/pull/157969 Approved by: https://github.com/zou3519 --- test/dynamo/test_decorators.py | 45 +++++++++++++++++++++++++ torch/_dynamo/graph_break_registry.json | 10 ++++++ torch/_dynamo/variables/builder.py | 7 ++++ torch/_dynamo/variables/torch.py | 27 +++++++++++---- 4 files changed, 83 insertions(+), 6 deletions(-) diff --git a/test/dynamo/test_decorators.py b/test/dynamo/test_decorators.py index 70e1946c30969..3b29e5e961192 100644 --- a/test/dynamo/test_decorators.py +++ b/test/dynamo/test_decorators.py @@ -514,6 +514,23 @@ def fn(x, s): fn(x, State(41)) self.assertEqual(cnts.frame_count, 2) + def test_nonstrict_trace_int_and_float_output(self): + @torch._dynamo.nonstrict_trace + def trace_me(x): + torch._dynamo.graph_break() + return len(x.shape), 0.42 + + def fn(x): + n1, n2 = trace_me(x) + return x * n1 + n2 + + x = torch.randn(10) + opt_fn = torch.compile(fn, fullgraph=True, backend="aot_eager") + + ref = fn(x) + res = opt_fn(x) + self.assertEqual(ref, res) + def test_nonstrict_trace_tuple_and_sym_int_output(self): @torch._dynamo.nonstrict_trace def trace_me(x): @@ -719,6 +736,34 @@ def fn(x, y): except torch._dynamo.exc.Unsupported as e: self.assertIn("Invalid input type for nonstrict_trace-ed function", str(e)) + def test_nonstrict_trace_custom_class_output_error(self): + class Point: + x: torch.Tensor + y: torch.Tensor + + def __init__(self, x, y): + self.x = x + self.y = y + + @torch._dynamo.nonstrict_trace + def trace_me(x): + torch._dynamo.graph_break() + return Point(x, x + 1) + + @torch.compile(fullgraph=True, backend="aot_eager") + def fn(x): + p = trace_me(x) + return p.x * p.y + + try: + x = torch.ones(10) + fn(x) + self.assertFalse(True) # must raise error before this + except torch._dynamo.exc.Unsupported as e: + self.assertIn( + "Unsupported output type for nonstrict_trace-ed function", str(e) + ) + def test_nonstrict_newly_constructed_trace_register_constant_type_error(self): class State: def __init__(self, n): diff --git a/torch/_dynamo/graph_break_registry.json b/torch/_dynamo/graph_break_registry.json index 58085696b78be..0bbdd91b6ae2e 100644 --- a/torch/_dynamo/graph_break_registry.json +++ b/torch/_dynamo/graph_break_registry.json @@ -2499,5 +2499,15 @@ "This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround." ] } + ], + "GB0251": [ + { + "Gb_type": "Unsupported output type for nonstrict_trace-ed function", + "Context": "Function: {fn.__name__}", + "Explanation": "For `nonstrict_trace`-ed functions, only basic types (e.g., torch.Tensor, int, list) are allowed as output. The result of this call contains an unsupported type.", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } ] } \ No newline at end of file diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index c6b061f42b1bf..9c13267c25bf3 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -51,6 +51,7 @@ set_feature_use, ) from torch._guards import TracingContext +from torch._higher_order_ops.flat_apply import flat_apply from torch._higher_order_ops.torchbind import call_torchbind from torch._ops import HigherOrderOperator from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode @@ -3002,6 +3003,12 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe ): set_example_value(proxy.node, example_value) return ConstantVariable.create(example_value, **options) + elif ( + isinstance(example_value, (int, float, bool)) + and proxy.node.target is flat_apply + ): + set_example_value(proxy.node, example_value) + return ConstantVariable.create(example_value, **options) elif isinstance(example_value, float) or proxy.node.target in ["hex", "__round__"]: set_example_value(proxy.node, example_value) return ConstantVariable.create(example_value, **options) diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 72b2e3dc132f4..fc1f9646ffdf4 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -1359,12 +1359,27 @@ def patched_fn(*args, **kwargs): # 2. Create a proxy call to `flat_apply`, then fake-tensor propagate # the call and wrap output into a VariableTracker. proxy = tx.output.create_proxy("call_function", flat_apply, all_args, {}) - out_vt = wrap_fx_proxy(tx, proxy) - # TODO support more output types - # Q: flat_apply will likely pytree_flatten the output for this, then - # how do we intercept the output before flatten, and wrap those? - # - Maybe we can have `flat_apply` return the output spec, so that - # Dynamo can unflatten and wrap the result. + try: + # TODO support more output types once `flat_apply` supports + # pytree-able output types. We can have Dynamo trace through an + # unflatten call (just like we traced through a flatten above) + # to rebuild the actual output VT. + out_vt = wrap_fx_proxy(tx, proxy) + except ( + # From `handle_traced_output`. + torch._dynamo.exc.Unsupported, + # From `flat_apply` assert on output type. + torch._dynamo.exc.TorchRuntimeError, + ): + unimplemented_v2( + gb_type="Unsupported output type for nonstrict_trace-ed function", + context=f"Function: {fn.__name__}", + explanation=( + "For `nonstrict_trace`-ed functions, only basic types (e.g., torch.Tensor, int, list)" + " are allowed as output. The result of this call contains an unsupported type." + ), + hints=[*graph_break_hints.SUPPORTABLE], + ) return out_vt From 2955acaed6a0f93f1f0913df3f840912392bc2ff Mon Sep 17 00:00:00 2001 From: Huy Do Date: Fri, 18 Jul 2025 09:02:40 -0700 Subject: [PATCH 282/457] Clean up some unused build env variables (#158599) * Parameter build-with-debug isn't needed, it isn't even passed into Docker. Debug build is detected via the build environment name * AWS_DEFAULT_REGION is a leftover from ARC and isn't used anywhere in .ci/pytorch nor .github Signed-off-by: Huy Do Pull Request resolved: https://github.com/pytorch/pytorch/pull/158599 Approved by: https://github.com/cyyever, https://github.com/ZainRizvi ghstack dependencies: #158598 --- .github/workflows/_linux-build.yml | 11 ----------- .github/workflows/periodic.yml | 2 -- 2 files changed, 13 deletions(-) diff --git a/.github/workflows/_linux-build.yml b/.github/workflows/_linux-build.yml index 1f1146fcde1be..7ce9741528e5c 100644 --- a/.github/workflows/_linux-build.yml +++ b/.github/workflows/_linux-build.yml @@ -16,11 +16,6 @@ on: type: boolean default: true description: If set, upload generated build artifacts. - build-with-debug: - required: false - type: boolean - default: false - description: If set, build in debug mode. sync-tag: required: false type: string @@ -87,7 +82,6 @@ on: required: false type: number default: 1 - allow-reuse-old-whl: description: | If set, the build try to pull an old wheel from s3 that was built on a @@ -106,7 +100,6 @@ on: description: | FB app token to write to scribe endpoint - outputs: docker-image: value: ${{ jobs.build.outputs.docker-image }} @@ -247,8 +240,6 @@ jobs: env: BUILD_ENVIRONMENT: ${{ inputs.build-environment }} BRANCH: ${{ steps.parse-ref.outputs.branch }} - # TODO duplicated - AWS_DEFAULT_REGION: us-east-1 PR_NUMBER: ${{ github.event.pull_request.number }} SHA1: ${{ github.event.pull_request.head.sha || github.sha }} # Do not set SCCACHE_S3_KEY_PREFIX to share the cache between all build jobs @@ -260,7 +251,6 @@ jobs: DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }} DOCKER_IMAGE_S390X: ${{ inputs.docker-image-name }} XLA_CUDA: ${{ contains(inputs.build-environment, 'xla') && '0' || '' }} - DEBUG: ${{ inputs.build-with-debug && '1' || '0' }} OUR_GITHUB_JOB_ID: ${{ steps.get-job-id.outputs.job-id }} HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} SCRIBE_GRAPHQL_ACCESS_TOKEN: ${{ secrets.SCRIBE_GRAPHQL_ACCESS_TOKEN }} @@ -295,7 +285,6 @@ jobs: container_name=$(docker run \ -e BUILD_ENVIRONMENT \ -e MAX_JOBS="$(nproc --ignore=2)" \ - -e AWS_DEFAULT_REGION \ -e PR_NUMBER \ -e SHA1 \ -e BRANCH \ diff --git a/.github/workflows/periodic.yml b/.github/workflows/periodic.yml index 643d40e4d381b..976fb241c99f9 100644 --- a/.github/workflows/periodic.yml +++ b/.github/workflows/periodic.yml @@ -157,7 +157,6 @@ jobs: { config: "multigpu", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.12xlarge.nvidia.gpu", owners: ["oncall:distributed"] }, { config: "multigpu", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.12xlarge.nvidia.gpu", owners: ["oncall:distributed"] }, ]} - build-with-debug: false secrets: inherit linux-jammy-cuda12_8-py3_9-gcc9-test: @@ -178,7 +177,6 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-jammy-cuda12.8-py3.10-gcc9-debug docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9 - build-with-debug: true test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] }, From a74109415943d56a0b681cbd4cf772ca07208818 Mon Sep 17 00:00:00 2001 From: Huy Do Date: Fri, 18 Jul 2025 09:02:41 -0700 Subject: [PATCH 283/457] Build domain libraries on the build job (#158600) By setting the name of the domain libraries to build via `BUILD_ADDITIONAL_PACKAGES` environment variable, the build job will build them and make them available as artifacts in the same way as the PyTorch CI wheel. To ensure that this doesn't break CI, the test job will still build them as usual if the wheels are not there. Building dependencies like FBGEMM on the test job is bad, especially for GPU jobs, because it leave the GPU resource idle Fixes https://github.com/pytorch/pytorch/issues/152024 Signed-off-by: Huy Do Pull Request resolved: https://github.com/pytorch/pytorch/pull/158600 Approved by: https://github.com/yangw-dev ghstack dependencies: #158598, #158599 --- .ci/pytorch/build.sh | 18 ++++++ .ci/pytorch/common_utils.sh | 94 ++++++++++++++++++++++-------- .ci/pytorch/test.sh | 12 ++-- .github/workflows/_linux-build.yml | 9 +++ 4 files changed, 100 insertions(+), 33 deletions(-) diff --git a/.ci/pytorch/build.sh b/.ci/pytorch/build.sh index 07bf2037f430d..9e5d2c4675eed 100755 --- a/.ci/pytorch/build.sh +++ b/.ci/pytorch/build.sh @@ -309,6 +309,24 @@ else fi pip_install_whl "$(echo dist/*.whl)" + if [[ -n "${BUILD_ADDITIONAL_PACKAGES}" ]]; then + if [[ "${BUILD_ADDITIONAL_PACKAGES}" == *vision* ]]; then + install_torchvision + fi + + if [[ "${BUILD_ADDITIONAL_PACKAGES}" == *audio* ]]; then + install_torchaudio + fi + + if [[ "${BUILD_ADDITIONAL_PACKAGES}" == *torchrec* || "${BUILD_ADDITIONAL_PACKAGES}" == *fbgemm* ]]; then + install_torchrec_and_fbgemm + fi + + if [[ "${BUILD_ADDITIONAL_PACKAGES}" == *torchao* ]]; then + install_torchao + fi + fi + if [[ "$BUILD_ENVIRONMENT" == *xpu* ]]; then echo "Checking that xpu is compiled" pushd dist/ diff --git a/.ci/pytorch/common_utils.sh b/.ci/pytorch/common_utils.sh index 3dbc2ece9e70b..69a5b7ad37951 100644 --- a/.ci/pytorch/common_utils.sh +++ b/.ci/pytorch/common_utils.sh @@ -78,6 +78,34 @@ function pip_install_whl() { fi } +function pip_build_and_install() { + local build_target=$1 + local wheel_dir=$2 + + local found_whl=0 + for file in "${wheel_dir}"/*.whl + do + if [[ -f "${file}" ]]; then + found_whl=1 + break + fi + done + + # Build the wheel if it doesn't exist + if [ "${found_whl}" == "0" ]; then + python3 -m pip wheel \ + --no-build-isolation \ + --no-deps \ + --no-use-pep517 \ + -w "${wheel_dir}" \ + "${build_target}" + fi + + for file in "${wheel_dir}"/*.whl + do + pip_install_whl "${file}" + done +} function pip_install() { # retry 3 times @@ -124,14 +152,7 @@ function get_pinned_commit() { function install_torchaudio() { local commit commit=$(get_pinned_commit audio) - if [[ "$1" == "cuda" ]]; then - # TODO: This is better to be passed as a parameter from _linux-test workflow - # so that it can be consistent with what is set in build - TORCH_CUDA_ARCH_LIST="8.0;8.6" pip_install --no-use-pep517 "git+https://github.com/pytorch/audio.git@${commit}" - else - pip_install --no-use-pep517 "git+https://github.com/pytorch/audio.git@${commit}" - fi - + pip_build_and_install "git+https://github.com/pytorch/audio.git@${commit}" dist/audio } function install_torchtext() { @@ -139,8 +160,8 @@ function install_torchtext() { local text_commit data_commit=$(get_pinned_commit data) text_commit=$(get_pinned_commit text) - pip_install --no-use-pep517 "git+https://github.com/pytorch/data.git@${data_commit}" - pip_install --no-use-pep517 "git+https://github.com/pytorch/text.git@${text_commit}" + pip_build_and_install "git+https://github.com/pytorch/data.git@${data_commit}" dist/data + pip_build_and_install "git+https://github.com/pytorch/text.git@${text_commit}" dist/text } function install_torchvision() { @@ -153,7 +174,7 @@ function install_torchvision() { echo 'char* dlerror(void) { return "";}'|gcc -fpic -shared -o "${HOME}/dlerror.so" -x c - LD_PRELOAD=${orig_preload}:${HOME}/dlerror.so fi - pip_install --no-use-pep517 "git+https://github.com/pytorch/vision.git@${commit}" + pip_build_and_install "git+https://github.com/pytorch/vision.git@${commit}" dist/vision if [ -n "${LD_PRELOAD}" ]; then LD_PRELOAD=${orig_preload} fi @@ -173,25 +194,48 @@ function install_torchrec_and_fbgemm() { if [[ "$BUILD_ENVIRONMENT" == *rocm* ]] ; then # install torchrec first because it installs fbgemm nightly on top of rocm fbgemm - pip_install --no-use-pep517 "git+https://github.com/pytorch/torchrec.git@${torchrec_commit}" + pip_build_and_install "git+https://github.com/pytorch/torchrec.git@${torchrec_commit}" dist/torchrec pip_uninstall fbgemm-gpu-nightly pip_install tabulate # needed for newer fbgemm pip_install patchelf # needed for rocm fbgemm - git clone --recursive https://github.com/pytorch/fbgemm - pushd fbgemm/fbgemm_gpu - git checkout "${fbgemm_commit}" - python setup.py install \ - --package_variant=rocm \ - -DHIP_ROOT_DIR="${ROCM_PATH}" \ - -DCMAKE_C_FLAGS="-DTORCH_USE_HIP_DSA" \ - -DCMAKE_CXX_FLAGS="-DTORCH_USE_HIP_DSA" - popd + + local wheel_dir=dist/fbgemm_gpu + local found_whl=0 + for file in "${wheel_dir}"/*.whl + do + if [[ -f "${file}" ]]; then + found_whl=1 + break + fi + done + + # Build the wheel if it doesn't exist + if [ "${found_whl}" == "0" ]; then + git clone --recursive https://github.com/pytorch/fbgemm + pushd fbgemm/fbgemm_gpu + git checkout "${fbgemm_commit}" + python setup.py bdist_wheel \ + --package_variant=rocm \ + -DHIP_ROOT_DIR="${ROCM_PATH}" \ + -DCMAKE_C_FLAGS="-DTORCH_USE_HIP_DSA" \ + -DCMAKE_CXX_FLAGS="-DTORCH_USE_HIP_DSA" + popd + + # Save the wheel before cleaning up + mkdir -p dist/fbgemm_gpu + cp fbgemm/fbgemm_gpu/dist/*.whl dist/fbgemm_gpu + fi + + for file in "${wheel_dir}"/*.whl + do + pip_install_whl "${file}" + done + rm -rf fbgemm else - # See https://github.com/pytorch/pytorch/issues/106971 - CUDA_PATH=/usr/local/cuda-12.1 pip_install --no-use-pep517 "git+https://github.com/pytorch/FBGEMM.git@${fbgemm_commit}#egg=fbgemm-gpu&subdirectory=fbgemm_gpu" - pip_install --no-use-pep517 "git+https://github.com/pytorch/torchrec.git@${torchrec_commit}" + pip_build_and_install "git+https://github.com/pytorch/torchrec.git@${torchrec_commit}" dist/torchrec + pip_build_and_install "git+https://github.com/pytorch/FBGEMM.git@${fbgemm_commit}#subdirectory=fbgemm_gpu" dist/fbgemm_gpu fi } @@ -234,7 +278,7 @@ function checkout_install_torchbench() { function install_torchao() { local commit commit=$(get_pinned_commit torchao) - pip_install --no-use-pep517 "git+https://github.com/pytorch/ao.git@${commit}" + pip_build_and_install "git+https://github.com/pytorch/ao.git@${commit}" dist/ao } function print_sccache_stats() { diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index d7d5947d2ce2c..ad6a48b2528e4 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -1666,23 +1666,19 @@ elif [[ "${TEST_CONFIG}" == *timm* ]]; then id=$((SHARD_NUMBER-1)) test_dynamo_benchmark timm_models "$id" elif [[ "${TEST_CONFIG}" == cachebench ]]; then - install_torchaudio cuda + install_torchaudio install_torchvision checkout_install_torchbench nanogpt BERT_pytorch resnet50 hf_T5 llama moco PYTHONPATH=$(pwd)/torchbench test_cachebench elif [[ "${TEST_CONFIG}" == verify_cachebench ]]; then - install_torchaudio cpu + install_torchaudio install_torchvision checkout_install_torchbench nanogpt PYTHONPATH=$(pwd)/torchbench test_verify_cachebench elif [[ "${TEST_CONFIG}" == *torchbench* ]]; then - if [[ "${TEST_CONFIG}" == *cpu* ]]; then - install_torchaudio cpu - else - install_torchaudio cuda - fi + install_torchaudio install_torchvision - TORCH_CUDA_ARCH_LIST="8.0;8.6" install_torchao + install_torchao id=$((SHARD_NUMBER-1)) # https://github.com/opencv/opencv-python/issues/885 pip_install opencv-python==4.8.0.74 diff --git a/.github/workflows/_linux-build.yml b/.github/workflows/_linux-build.yml index 7ce9741528e5c..5173425009f69 100644 --- a/.github/workflows/_linux-build.yml +++ b/.github/workflows/_linux-build.yml @@ -89,6 +89,13 @@ on: required: false type: boolean default: true + build-additional-packages: + description: | + If set, the build job will also builds these packages and saves their + wheels as artifacts + required: false + type: string + default: "" secrets: HUGGING_FACE_HUB_TOKEN: @@ -254,6 +261,7 @@ jobs: OUR_GITHUB_JOB_ID: ${{ steps.get-job-id.outputs.job-id }} HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} SCRIBE_GRAPHQL_ACCESS_TOKEN: ${{ secrets.SCRIBE_GRAPHQL_ACCESS_TOKEN }} + BUILD_ADDITIONAL_PACKAGES: ${{ inputs.build-additional-packages }} run: | START_TIME=$(date +%s) if [[ ${BUILD_ENVIRONMENT} == *"s390x"* ]]; then @@ -299,6 +307,7 @@ jobs: -e HUGGING_FACE_HUB_TOKEN \ -e SCRIBE_GRAPHQL_ACCESS_TOKEN \ -e USE_SPLIT_BUILD \ + -e BUILD_ADDITIONAL_PACKAGES \ --memory="${TOTAL_AVAILABLE_MEMORY_IN_GB%.*}g" \ --memory-swap="${TOTAL_MEMORY_WITH_SWAP}g" \ --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \ From 90b082e207bff79dd09d89cfef9be49de5c2ad83 Mon Sep 17 00:00:00 2001 From: Laith Sakka Date: Thu, 17 Jul 2025 12:46:00 -0700 Subject: [PATCH 284/457] enable_caching_generated_triton_templates=True by default (#158592) Got some risk, but good to catch issues if there is any, easy to revert single flag flip. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158592 Approved by: https://github.com/eellison --- torch/_inductor/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index f1edeb21b4062..ef7de961149e5 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -995,7 +995,7 @@ def decide_compile_threads() -> int: annotate_training: bool = os.environ.get("TORCHINDUCTOR_ANNOTATE_TRAINING", "0") == "1" # Enable caching codegen of triton templates. -enable_caching_generated_triton_templates: bool = False +enable_caching_generated_triton_templates: bool = True # Lookup table for overriding autotune configs based on hash of Triton source code autotune_lookup_table: dict[str, dict[str, Any]] = {} From ab557421a473993b6c7c841cc0d2ff490c718ea4 Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Sat, 19 Jul 2025 02:51:24 +0000 Subject: [PATCH 285/457] [cca] [c10d] Refactor CUDAEventCache into separate files (#158616) Summary: Refactored CUDAEventCache from ProcessGroupNCCL.hpp/.cpp into dedicated header and implementation files for better code organization and maintainability. Split out CUDAEventCache into: - New header file: CUDAEventCache.hpp - New implementation file: CUDAEventCache.cpp - Updated build_variables.bzl to include the new file This change improves code maintainability, readability, and follows better code organization practices. --- > Generated by [Confucius Code Assist (CCA)](https://www.internalfb.com/wiki/Confucius/Analect/Shared_Analects/Confucius_Code_Assist_(CCA)/) [Session](https://www.internalfb.com/confucius?session_id=61b9029a-636b-11f0-9d9a-f1bcc55be1ce&tab=Chat), [Trace](https://www.internalfb.com/confucius?session_id=61b9029a-636b-11f0-9d9a-f1bcc55be1ce&tab=Trace) Test Plan: Verified build with: ``` buck build //caffe2/test/distributed:c10d ``` --- > Generated by [Confucius Code Assist (CCA)](https://www.internalfb.com/wiki/Confucius/Analect/Shared_Analects/Confucius_Code_Assist_(CCA)/) [Session](https://www.internalfb.com/confucius?session_id=61b9029a-636b-11f0-9d9a-f1bcc55be1ce&tab=Chat), [Trace](https://www.internalfb.com/confucius?session_id=61b9029a-636b-11f0-9d9a-f1bcc55be1ce&tab=Trace) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158616 Approved by: https://github.com/fduwjj --- build_variables.bzl | 1 + test/cpp/c10d/ProcessGroupNCCLTest.cpp | 16 ++--- .../distributed/c10d/ProcessGroupNCCL.cpp | 61 +------------------ .../distributed/c10d/ProcessGroupNCCL.hpp | 18 +----- .../distributed/c10d/cuda/CUDAEventCache.cpp | 58 ++++++++++++++++++ .../distributed/c10d/cuda/CUDAEventCache.hpp | 29 +++++++++ 6 files changed, 99 insertions(+), 84 deletions(-) create mode 100644 torch/csrc/distributed/c10d/cuda/CUDAEventCache.cpp create mode 100644 torch/csrc/distributed/c10d/cuda/CUDAEventCache.hpp diff --git a/build_variables.bzl b/build_variables.bzl index 776b1f433fbd0..d633a29c5b634 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -738,6 +738,7 @@ libtorch_cuda_distributed_extra_sources = [ "torch/csrc/distributed/c10d/UCCTracing.cpp", "torch/csrc/distributed/c10d/UCCUtils.cpp", "torch/csrc/distributed/c10d/cuda/AsyncMM.cu", + "torch/csrc/distributed/c10d/cuda/CUDAEventCache.cpp", "torch/csrc/distributed/c10d/cuda/utils.cpp", "torch/csrc/distributed/c10d/cuda/StreamBlock.cu", "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu", diff --git a/test/cpp/c10d/ProcessGroupNCCLTest.cpp b/test/cpp/c10d/ProcessGroupNCCLTest.cpp index a1360c8dd40fd..ac4ba4da01577 100644 --- a/test/cpp/c10d/ProcessGroupNCCLTest.cpp +++ b/test/cpp/c10d/ProcessGroupNCCLTest.cpp @@ -767,8 +767,8 @@ TEST_F(ProcessGroupNCCLTest, CUDAEventCache) { } // Test that the CUDAEventCache can be used to create CUDA events and reuse. - auto event1 = c10d::ProcessGroupNCCL::CUDAEventCache::get(1)->create(true); - auto event2 = c10d::ProcessGroupNCCL::CUDAEventCache::get(1)->create(false); + auto event1 = c10d::CUDAEventCache::get(1)->create(true); + auto event2 = c10d::CUDAEventCache::get(1)->create(false); auto event1_ptr = event1.get(); auto event2_ptr = event2.get(); @@ -777,14 +777,14 @@ TEST_F(ProcessGroupNCCLTest, CUDAEventCache) { event2 = nullptr; // Test that the CUDAEventCache is indeed reused. - auto event3 = c10d::ProcessGroupNCCL::CUDAEventCache::get(2)->create(true); - auto event4 = c10d::ProcessGroupNCCL::CUDAEventCache::get(2)->create(false); + auto event3 = c10d::CUDAEventCache::get(2)->create(true); + auto event4 = c10d::CUDAEventCache::get(2)->create(false); // The cache has been used up, new events should be created. - auto event5 = c10d::ProcessGroupNCCL::CUDAEventCache::get(1)->create(true); - auto event6 = c10d::ProcessGroupNCCL::CUDAEventCache::get(1)->create(false); + auto event5 = c10d::CUDAEventCache::get(1)->create(true); + auto event6 = c10d::CUDAEventCache::get(1)->create(false); // The cache has been used up, new events should be created. - auto event7 = c10d::ProcessGroupNCCL::CUDAEventCache::get(1)->create(true); - auto event8 = c10d::ProcessGroupNCCL::CUDAEventCache::get(1)->create(false); + auto event7 = c10d::CUDAEventCache::get(1)->create(true); + auto event8 = c10d::CUDAEventCache::get(1)->create(false); EXPECT_NE(event1_ptr, event3.get()); EXPECT_NE(event2_ptr, event4.get()); EXPECT_EQ(event1_ptr, event5.get()); diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index a0c546a405f59..c0c98326690be 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -519,11 +519,9 @@ ProcessGroupNCCL::WorkNCCL::WorkNCCL( // DEFAULT_FLAGS = cudaEventDisableTiming. if (cudaEventCacheEnabled) { ncclStartEvent_ = enableTiming - ? ProcessGroupNCCL::CUDAEventCache::get(device.index()) - ->create(enableTiming) + ? CUDAEventCache::get(device.index())->create(enableTiming) : nullptr; - ncclEndEvent_ = ProcessGroupNCCL::CUDAEventCache::get(device.index()) - ->create(enableTiming); + ncclEndEvent_ = CUDAEventCache::get(device.index())->create(enableTiming); } else { ncclStartEvent_ = enableTiming ? std::make_shared(cudaEventDefault) @@ -860,61 +858,6 @@ void ProcessGroupNCCL::WorkNCCL::abort() { } } -ProcessGroupNCCL::CUDAEventCache::CUDAEventCache() = default; - -// CUDA event is used to record the start/end of one Work. -// Instead of let the CUDA event gets destroyed, we now reuse it after the Work -// has been erased from workMetaList_. -// This is to avoid the potential deadlock caused by CudaEventDestroy. -std::shared_ptr ProcessGroupNCCL::CUDAEventCache::create( - bool timing) { - // Register the deleter as a callback when the WorkNCCL object is destroyed. - // Each deleter keeps a ref count to the cache object, so that even when - // the thread that creates the cache is gone, the cache object won't be - // destroyed until all the events in the cache are destroyed (ref number drops - // to zero). - auto deleter = [cache = shared_from_this(), - timing](at::cuda::CUDAEvent* event) { - std::lock_guard lock(cache->cacheMutex_); - // We put the event back to the cache deque once the WorkNCCL object is - // destroyed. - cache->eventsArray_[timing ? 1 : 0].push_back(event); - }; - at::cuda::CUDAEvent* event = nullptr; - { - std::lock_guard lock(cacheMutex_); - auto& events = eventsArray_[timing ? 1 : 0]; - // If we still have events in the cache, we reuse it. Otherwise, we create a - // new one. - if (!events.empty()) { - event = events.front(); - events.pop_front(); - } else { - event = new at::cuda::CUDAEvent( - timing ? cudaEventDefault : cudaEventDisableTiming); - } - } - return std::shared_ptr(event, std::move(deleter)); -} - -std::shared_ptr ProcessGroupNCCL:: - CUDAEventCache::get(at::DeviceIndex device) { - // A per-thread singleton of device-to-CUDAEventCache map. - // Map is needed because events cannot be reused across devices. - // Per-thread ownership is needed to support multi-threaded case (instead of - // multi-process case). - static thread_local std:: - map> - cacheDeviceMap; - // Check if device has already been in the map, if not, add a new entry - auto it = cacheDeviceMap.find(device); - if (it == cacheDeviceMap.end()) { - cacheDeviceMap.emplace( - device, std::make_shared()); - } - return cacheDeviceMap[device]; -} - static std::atomic process_group_id = 0; constexpr const char* MULTI_DEVICE_ERROR_MSG = diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index 810f8db9fd7d8..9d72207a4b79a 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -23,6 +23,7 @@ #include #include #include +#include #include #include @@ -503,23 +504,6 @@ class TORCH_API ProcessGroupNCCL : public Backend { friend class ProcessGroupNCCL; }; - class CUDAEventCache - : public std::enable_shared_from_this { - public: - CUDAEventCache(); - std::shared_ptr create(bool timing); - static std::shared_ptr get( - at::DeviceIndex device); - - private: - std::mutex cacheMutex_; - // NOTE: We intentionally store raw pointers so that - // we do not attempt to destroy the event objects on process exit, - // because cuda may be gone. - std::array, 2> - eventsArray_; // 0 for timing=false, 1 for timing=true - }; - struct Options : Backend::Options { // NOTE: timeout in ProcessGroupNCCL::Options denote the timeout for // operations. This is only used when blockingWait_ is enabled. diff --git a/torch/csrc/distributed/c10d/cuda/CUDAEventCache.cpp b/torch/csrc/distributed/c10d/cuda/CUDAEventCache.cpp new file mode 100644 index 0000000000000..75208e92b4081 --- /dev/null +++ b/torch/csrc/distributed/c10d/cuda/CUDAEventCache.cpp @@ -0,0 +1,58 @@ +#include +#include +#include + +namespace c10d { + +CUDAEventCache::CUDAEventCache() = default; + +// CUDA event is used to record the start/end of one Work. +// Instead of let the CUDA event gets destroyed, we now reuse it after the Work +// has been erased from workMetaList_. +// This is to avoid the potential deadlock caused by CudaEventDestroy. +std::shared_ptr CUDAEventCache::create(bool timing) { + // Register the deleter as a callback when the WorkNCCL object is destroyed. + // Each deleter keeps a ref count to the cache object, so that even when + // the thread that creates the cache is gone, the cache object won't be + // destroyed until all the events in the cache are destroyed (ref number drops + // to zero). + auto deleter = [cache = shared_from_this(), + timing](at::cuda::CUDAEvent* event) { + std::lock_guard lock(cache->cacheMutex_); + // We put the event back to the cache deque once the WorkNCCL object is + // destroyed. + cache->eventsArray_[timing ? 1 : 0].push_back(event); + }; + at::cuda::CUDAEvent* event = nullptr; + { + std::lock_guard lock(cacheMutex_); + auto& events = eventsArray_[timing ? 1 : 0]; + // If we still have events in the cache, we reuse it. Otherwise, we create a + // new one. + if (!events.empty()) { + event = events.front(); + events.pop_front(); + } else { + event = new at::cuda::CUDAEvent( + timing ? cudaEventDefault : cudaEventDisableTiming); + } + } + return std::shared_ptr(event, std::move(deleter)); +} + +std::shared_ptr CUDAEventCache::get(at::DeviceIndex device) { + // A per-thread singleton of device-to-CUDAEventCache map. + // Map is needed because events cannot be reused across devices. + // Per-thread ownership is needed to support multi-threaded case (instead of + // multi-process case). + static thread_local std::map> + cacheDeviceMap; + // Check if device has already been in the map, if not, add a new entry + auto it = cacheDeviceMap.find(device); + if (it == cacheDeviceMap.end()) { + cacheDeviceMap.emplace(device, std::make_shared()); + } + return cacheDeviceMap[device]; +} + +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/cuda/CUDAEventCache.hpp b/torch/csrc/distributed/c10d/cuda/CUDAEventCache.hpp new file mode 100644 index 0000000000000..5639c1f04fd76 --- /dev/null +++ b/torch/csrc/distributed/c10d/cuda/CUDAEventCache.hpp @@ -0,0 +1,29 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include + +namespace c10d { + +class TORCH_API CUDAEventCache + : public std::enable_shared_from_this { + public: + CUDAEventCache(); + std::shared_ptr create(bool timing); + static std::shared_ptr get(at::DeviceIndex device); + + private: + std::mutex cacheMutex_; + // NOTE: We intentionally store raw pointers so that + // we do not attempt to destroy the event objects on process exit, + // because cuda may be gone. + std::array, 2> + eventsArray_; // 0 for timing=false, 1 for timing=true +}; + +} // namespace c10d From 64dabb2cf5c4112c7c169fb76dabe9ab905c8e7c Mon Sep 17 00:00:00 2001 From: Laith Sakka Date: Fri, 18 Jul 2025 15:06:02 -0700 Subject: [PATCH 286/457] only fail regressions>10% on pr_time benchmarks (#158577) Moving to a new framework, maintaitning the pr_time benchmark test right now is hard and often breaking. 1. only fail PRs >10% regressions. 2. post monitor with pr_time benchmarks dashboard (oncall), and update expected results (frequently or on big changes) (supposed to already be doing https://www.internalfb.com/unidash/dashboard/pt2_diff_time_metrics) 3. setting up some one detections detectors warnings that would be triggered at regressions and notify internally post land https://www.internalfb.com/monitoring/detector/1140915271179237 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158577 Approved by: https://github.com/xmfan, https://github.com/janeyx99 --- .../pr_time_benchmarks/expected_results.csv | 46 +++++++++---------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv index edc9d0f73d161..c0d676f885109 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv +++ b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv @@ -1,89 +1,89 @@ -add_loop_eager,compile_time_instruction_count,3017000000,0.015 +add_loop_eager,compile_time_instruction_count,3070000000,0.10 -add_loop_eager_dynamic,compile_time_instruction_count,4352000000,0.025 +add_loop_eager_dynamic,compile_time_instruction_count,4432000000,0.10 -add_loop_inductor,compile_time_instruction_count,29490000000,0.015 +add_loop_inductor,compile_time_instruction_count,30280000000,0.10 -add_loop_inductor_dynamic_gpu,compile_time_instruction_count,38760000000,0.025 +add_loop_inductor_dynamic_gpu,compile_time_instruction_count,39910000000,0.10 -add_loop_inductor_gpu,compile_time_instruction_count,26000000000,0.015 +add_loop_inductor_gpu,compile_time_instruction_count,26800000000,0.10 -basic_modules_ListOfLinears_eager,compile_time_instruction_count,947600000,0.015 +basic_modules_ListOfLinears_eager,compile_time_instruction_count,969100000,0.10 -basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18490000000,0.015 +basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18030000000,0.10 -basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17020000000,0.015 +basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17020000000,0.10 -basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,10297683772,0.2 +basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,11090000000,0.2 -update_hint_regression,compile_time_instruction_count,1673000000,0.02 +update_hint_regression,compile_time_instruction_count,1719000000,0.10 -sum_floordiv_regression,compile_time_instruction_count,986800000,0.015 +sum_floordiv_regression,compile_time_instruction_count,966100000,0.10 -symint_sum,compile_time_instruction_count,3166000000,0.015 +symint_sum,compile_time_instruction_count,3237000000,0.10 -symint_sum_loop,compile_time_instruction_count,4202000000,0.015 +symint_sum_loop,compile_time_instruction_count,4299000000,0.10 -aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2103000000,0.015 +aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2151000000,0.10 -aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,6004000000,0.015 +aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,6124000000,0.10 -aotdispatcher_partitioner_cpu,compile_time_instruction_count,8783000000,0.015 +aotdispatcher_partitioner_cpu,compile_time_instruction_count,9005000000,0.10 -aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1940000000,0.015 +aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1989000000,0.10 -aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3885000000,0.015 +aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3959000000,0.10 -aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10470000000,0.015 +aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10650000000,0.10 -mm_loop_inductor_gpu,compile_time_instruction_count,4324000000,0.015 +mm_loop_inductor_gpu,compile_time_instruction_count,4461000000,0.10 -mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,8116000000,0.015 +mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,8417000000,0.10 -basic_NestedModule_eager,compile_time_instruction_count,8152524390,0.015 +basic_NestedModule_eager,compile_time_instruction_count,8348000000,0.10 -basic_InlineMod_eager,compile_time_instruction_count,7255000000,0.015 +basic_InlineMod_eager,compile_time_instruction_count,7464000000,0.10 From fac0be7b9c80f20bbff1e813225dcbced7ff4d31 Mon Sep 17 00:00:00 2001 From: Luca Wehrstedt Date: Fri, 18 Jul 2025 18:46:35 +0000 Subject: [PATCH 287/457] [async-TP] Turn asserts back into silent skips (#158572) https://github.com/pytorch/pytorch/pull/149946 modified some checks that verify whether async-TP is "applicable" to a given collective operation in a graph. Before, the pattern-mathcing+replacement would just be skipped, but now these are asserts that fail and raise. This is causing concrete issues in some graphs where 2-dimensional device meshes are being used (e.g., TP + CP) but only one dimension has symm-mem enabled. See #158569. This PR is turning these asserts back into harmless early-exits. Note that this only needed to be done for reduce-scatters, as it was already the case for all-gathers. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158572 Approved by: https://github.com/danielvegamyhre, https://github.com/atalman --- .../tensor/parallel/test_micro_pipeline_tp.py | 50 +++++++++++++++++++ .../_inductor/fx_passes/micro_pipeline_tp.py | 13 ++--- .../distributed/_symmetric_memory/__init__.py | 11 +++- 3 files changed, 66 insertions(+), 8 deletions(-) diff --git a/test/distributed/tensor/parallel/test_micro_pipeline_tp.py b/test/distributed/tensor/parallel/test_micro_pipeline_tp.py index 906b7d1a4a52b..df3e2ffb38858 100644 --- a/test/distributed/tensor/parallel/test_micro_pipeline_tp.py +++ b/test/distributed/tensor/parallel/test_micro_pipeline_tp.py @@ -494,5 +494,55 @@ def test_dtensor_seq_par(self, shard_dim: int): self.assertNotIn("reduce_scatter_tensor", code) +@instantiate_parametrized_tests +class MicroPipelineTP4GPUTest(TestCase): + def setUp(self): + torch._inductor.config._micro_pipeline_tp = True + + self.rank = 0 + self.world_size = 4 + torch.cuda.set_device("cuda:0") + + store = FakeStore() + dist.init_process_group( + backend="fake", + world_size=self.world_size, + rank=self.rank, + store=store, + ) + + def tearDown(self): + dist.destroy_process_group() + + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @fresh_cache() + def test_extra_collectives(self): + device_mesh = DeviceMesh( + "cuda", + torch.arange(0, self.world_size).view(2, -1), + mesh_dim_names=("tp", "other"), + ) + + def func(inp: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor) -> torch.Tensor: + hidden = all_gather_tensor(inp, 0, (device_mesh, 0)) @ w1.t() + full_hidden = all_gather_tensor(hidden, 0, (device_mesh, 1)) + full_hidden /= full_hidden.pow(2).sum().sqrt() + hidden = reduce_scatter_tensor(full_hidden, "avg", 0, (device_mesh, 1)) + return reduce_scatter_tensor(hidden @ w2.t(), "avg", 0, (device_mesh, 0)) + + inp = torch.rand(8, 10, device="cuda") + w1 = torch.rand(7, 10, device="cuda") + w2 = torch.rand(10, 7, device="cuda") + + with _test_mode(group_names={device_mesh["tp"].get_group().group_name}): + compiled = torch.compile(func) + code = run_and_get_triton_code(compiled, inp, w1, w2) + + self.assertIn("fused_all_gather_matmul", code) + self.assertIn("all_gather_into_tensor", code) + self.assertIn("fused_matmul_reduce_scatter", code) + self.assertIn("reduce_scatter_tensor", code) + + if __name__ == "__main__": run_tests() diff --git a/torch/_inductor/fx_passes/micro_pipeline_tp.py b/torch/_inductor/fx_passes/micro_pipeline_tp.py index af40d987f7d18..c4d935a4f8bb4 100644 --- a/torch/_inductor/fx_passes/micro_pipeline_tp.py +++ b/torch/_inductor/fx_passes/micro_pipeline_tp.py @@ -850,9 +850,11 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None: Returns boolean indicating if fusion was successful or not. """ - assert torch.distributed.is_available() and torch.distributed.is_nccl_available(), ( - "torch.distributed and NCCL must be available to use async tensor parallelism" - ) + if ( + not torch.distributed.is_available() + or not torch.distributed.is_nccl_available() + ): + return from torch.distributed._symmetric_memory import ( is_symm_mem_enabled_for_group, @@ -875,9 +877,8 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None: reduce_scatter.group_name, ) - assert is_symm_mem_enabled_for_group(group_name), ( - f"symmetric memory is not enabled for process group {group_name}, this is required for async TP" - ) + if not is_symm_mem_enabled_for_group(group_name): + return # Currently fused_matmul_reduce_scatter doesn't return the matmul result, # so we can't apply the fusion if the matmul result is used by multiple diff --git a/torch/distributed/_symmetric_memory/__init__.py b/torch/distributed/_symmetric_memory/__init__.py index 634e953aeb36b..b45b902406ea8 100644 --- a/torch/distributed/_symmetric_memory/__init__.py +++ b/torch/distributed/_symmetric_memory/__init__.py @@ -47,10 +47,11 @@ def enable_symm_mem_for_group(group_name: str) -> None: _is_test_mode: bool = False +_mocked_group_names: Optional[set[str]] = None @contextmanager -def _test_mode() -> Generator[None, None, None]: +def _test_mode(group_names: Optional[set[str]] = None) -> Generator[None, None, None]: """ Forces ``is_symm_mem_enabled_for_group()`` to return ``True`` and the ops defined in the ``symm_mem`` namespace to use fallback implementations. @@ -58,12 +59,16 @@ def _test_mode() -> Generator[None, None, None]: The context manager is not thread safe. """ global _is_test_mode + global _mocked_group_names prev = _is_test_mode + prev_group_names = _mocked_group_names try: _is_test_mode = True + _mocked_group_names = group_names yield finally: _is_test_mode = prev + _mocked_group_names = prev_group_names def is_symm_mem_enabled_for_group(group_name: str) -> bool: @@ -73,7 +78,9 @@ def is_symm_mem_enabled_for_group(group_name: str) -> bool: Args: group_name (str): the name of the process group. """ - return _is_test_mode or group_name in _group_name_to_store + if _is_test_mode: + return _mocked_group_names is None or group_name in _mocked_group_names + return group_name in _group_name_to_store _group_name_to_workspace_tensor: dict[str, Optional[torch.Tensor]] = {} From 5cde34473c33ed7f8df07489783a2b86058ebb3f Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Fri, 18 Jul 2025 16:04:11 -0700 Subject: [PATCH 288/457] Fix `MakeTensor::computeStorageSize()` (#158690) For tensor with non-zero offset, it must be multiplied by element size Add regression test by creating Tensor in array of 6 elements with offset 3, which before the fix crashed with ``` C++ exception with description "setStorage: sizes [3, 3], strides [0, 1], storage offset 3, and itemsize 4 requiring a storage size of 24 are out of bounds for storage of size 15 Exception raised from checkInBoundsForStorage at /Users/nshulga/git/pytorch/pytorch/aten/src/ATen/native/Resize.h:123 (most recent call first): frame #0: c10::Error::Error(c10::SourceLocation, std::__1::basic_string, std::__1::allocator>) + 56 (0x104a9cd44 in libc10.dylib) frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__1::basic_string, std::__1::allocator> const&) + 120 (0x104a9a05c in libc10.dylib) frame #2: void at::native::checkInBoundsForStorage(c10::ArrayRef, c10::ArrayRef, long long, caffe2::TypeMeta const&, c10::Storage const&) + 656 (0x111dbd314 in libtorch_cpu.dylib) frame #3: void at::native::setStrided(at::Tensor const&, c10::ArrayRef, c10::ArrayRef, long long) + 152 (0x111dcd22c in libtorch_cpu.dylib) frame #4: at::native::as_strided_tensorimpl(at::Tensor const&, c10::ArrayRef, c10::ArrayRef, std::__1::optional) + 312 (0x111dccf98 in libtorch_cpu.dylib) frame #5: c10::impl::wrap_kernel_functor_unboxed_, c10::ArrayRef, std::__1::optional), &at::(anonymous namespace)::(anonymous namespace)::wrapper_CPU__as_strided(at::Tensor const&, c10::ArrayRef, c10::ArrayRef, std::__1::optional)>, at::Tensor, c10::guts::typelist::typelist, c10::ArrayRef, std::__1::optional>>, at::Tensor (at::Tensor const&, c10::ArrayRef, c10::ArrayRef, std::__1::optional)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef, c10::ArrayRef, std::__1::optional) + 104 (0x1129a1e94 in libtorch_cpu.dylib) frame #6: at::_ops::as_strided::call(at::Tensor const&, c10::ArrayRef, c10::ArrayRef, std::__1::optional) + 476 (0x112200ad0 in libtorch_cpu.dylib) frame #7: at::Tensor::as_strided(c10::ArrayRef, c10::ArrayRef, std::__1::optional) const + 236 (0x1115db098 in libtorch_cpu.dylib) frame #8: at::native::expand(at::Tensor const&, c10::ArrayRef, bool) + 348 (0x111dcc0d4 in libtorch_cpu.dylib) frame #9: c10::impl::wrap_kernel_functor_unboxed_, bool), &torch::ADInplaceOrView::(anonymous namespace)::expand(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef, bool)>, at::Tensor, c10::guts::typelist::typelist, bool>>, at::Tensor (c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef, bool)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef, bool) + 116 (0x1157ac410 in libtorch_cpu.dylib) frame #10: c10::impl::wrap_kernel_functor_unboxed_, bool), &torch::autograd::VariableType::(anonymous namespace)::expand(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef, bool)>, at::Tensor, c10::guts::typelist::typelist, bool>>, at::Tensor (c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef, bool)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef, bool) + 992 (0x114e8b010 in libtorch_cpu.dylib) frame #11: at::_ops::expand::call(at::Tensor const&, c10::ArrayRef, bool) + 316 (0x112743c90 in libtorch_cpu.dylib) frame #12: at::expand_size(at::Tensor const&, c10::ArrayRef) + 164 (0x1047d82b4 in basic) frame #13: BasicTest_TestForBlobResizeCPU_Test::TestBody() + 284 (0x1047d8048 in basic) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/158690 Approved by: https://github.com/angelayi --- aten/src/ATen/templates/Functions.cpp | 2 +- aten/src/ATen/test/basic.cpp | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/templates/Functions.cpp b/aten/src/ATen/templates/Functions.cpp index e111a88b3309e..b3c2164f1707e 100644 --- a/aten/src/ATen/templates/Functions.cpp +++ b/aten/src/ATen/templates/Functions.cpp @@ -64,7 +64,7 @@ Tensor TensorMaker::make_tensor() { if (strides_) { auto storage_size = detail::computeStorageNbytes(sizes_, *strides_, itemsize); if (storage_offset_) { - storage_size += storage_offset_.value(); + storage_size += storage_offset_.value() * itemsize; } return storage_size; } diff --git a/aten/src/ATen/test/basic.cpp b/aten/src/ATen/test/basic.cpp index 8dd2e59ce2ddc..0e4f461cfd9a4 100644 --- a/aten/src/ATen/test/basic.cpp +++ b/aten/src/ATen/test/basic.cpp @@ -517,3 +517,12 @@ TEST(BasicTest, BasicStdTestCPU) { t3.join(); t4.join(); } + +TEST(BasicTest, TestForBlobResizeCPU) { + // Checks that for_blob can correctly create tensors with non-empty offset and resize them + std::array storage; + std::iota(storage.begin(), storage.end(), 1); + auto t = at::for_blob(storage.data(), {3,}).strides({1,}).storage_offset(3).options(c10::TensorOptions(kInt)).make_tensor(); + auto te = *at::expand_size(t, {3, 3}); + ASSERT_EQ(te[1][1].item(), 5); +} From 22d82222c6e2a2ef4badc6b816d233a4cec924c3 Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Sat, 19 Jul 2025 05:41:01 +0000 Subject: [PATCH 289/457] GenAI Layer Benchmark (#158536) This PR adds GenAI layer benchmark. It compares pytorch eager, pytorch compiler, liger, and quack. It covers all kernels supported by [quack](https://github.com/Dao-AILab/quack?tab=readme-ov-file#kernels-) (CrossEntropy Fwd/Bwd, Softmax Fwd/Bwd, RMSNorm Fwd/Bwd, LayerNorm Fwd) and LayerNormBwd. ## Motivations - Many OSS users asked how to properly benchmark torch.compile generated kernels. One common error is to compile a kernel/layer for one shape (e.g., batch size=1) and benchmark for another shape (e.g., batch size = 1024), which leads to bad performance. This provides an simple & clear example for proper benchmark. - We recently added GenAI model benchmark (based on [vLLM](https://hud.pytorch.org/benchmark/llms?repoName=vllm-project%2Fvllm)). But it's usually hard to optimize models directly due to complexity. Layer benchmarks are easier to reason and optimize. ## Key Settings - Avoid reusing a kernel specializing on 1 shape for benchmark on another shape. ```python torch._dynamo.config.automatic_dynamic_shapes = False # Needed since changing args to function causes recompiles torch._dynamo.config.recompile_limit = 1000000 ``` - For forward, people may mark batch size as dynamic to avoid runtime recompilation. We respect the setting in this kernel-level benchmark. ``` torch._dynamo.mark_dynamic(x, 0) ``` GPU: H100 (devvm006.dkl0) Results: [P1874246170](https://www.internalfb.com/phabricator/paste/view/P1874246170) Note: for numerical accuracy, we use the default tolerance of torch.testing.assert_close (i.e., for `torch.bfloat16`, use rtol `1.6e-2` and atol `1e-5`). It shows numerical issues for some backends and kernels. Next step is to add roofline analysis, add to ci for checking regression, cover more GenAI Kernels, and include GenAI Layers for common fusion patterns. CrossEntropyBackward_bench CrossEntropyForward_bench LayerNormBackward_bench LayerNormForward_bench RMSNormBackward_bench RMSNormForward_bench SoftmaxBackward_bench SoftmaxForward_bench Pull Request resolved: https://github.com/pytorch/pytorch/pull/158536 Approved by: https://github.com/yf225, https://github.com/eellison --- benchmarks/dynamo/genai_layers/README.md | 23 + benchmarks/dynamo/genai_layers/benchmark.py | 152 +++++ benchmarks/dynamo/genai_layers/kernels.py | 635 ++++++++++++++++++ .../dynamo/genai_layers/requirements.txt | 4 + benchmarks/dynamo/genai_layers/utils.py | 241 +++++++ 5 files changed, 1055 insertions(+) create mode 100644 benchmarks/dynamo/genai_layers/README.md create mode 100644 benchmarks/dynamo/genai_layers/benchmark.py create mode 100644 benchmarks/dynamo/genai_layers/kernels.py create mode 100644 benchmarks/dynamo/genai_layers/requirements.txt create mode 100644 benchmarks/dynamo/genai_layers/utils.py diff --git a/benchmarks/dynamo/genai_layers/README.md b/benchmarks/dynamo/genai_layers/README.md new file mode 100644 index 0000000000000..d2a11e0acc213 --- /dev/null +++ b/benchmarks/dynamo/genai_layers/README.md @@ -0,0 +1,23 @@ +# GenAI Kernel Benchmark + +This directory contains benchmarks for the GenAI kernels. It compares pytorch eager, pytorch compiler, quack, and liger. + + +## Setup + +Assuming pytorch is installed. + +``` +pip install -r requirements.txt +``` + +## Run + +``` + python benchmark.py --list # List all available benchmarks + python benchmark.py --all # Run all benchmarks + python benchmark.py cross_entropy_forward # Run specific benchmark + python benchmark.py softmax_forward softmax_backward # Run multiple benchmarks +``` + +Add `--visualize` to plot graph for the benchmark results. diff --git a/benchmarks/dynamo/genai_layers/benchmark.py b/benchmarks/dynamo/genai_layers/benchmark.py new file mode 100644 index 0000000000000..0378629670524 --- /dev/null +++ b/benchmarks/dynamo/genai_layers/benchmark.py @@ -0,0 +1,152 @@ +#!/usr/bin/env python3 +""" +Benchmark runner for various kernel implementations. + +This script provides a command-line interface to run benchmarks for different +kernel implementations including CrossEntropy, Softmax, RMSNorm, and LayerNorm +kernels in both forward and backward directions. +""" + +import argparse +import sys + +from kernels import ( + BenchmarkKernel, + CrossEntropyBackward, + CrossEntropyForward, + LayerNormBackward, + LayerNormForward, + RMSNormBackward, + RMSNormForward, + SoftmaxBackward, + SoftmaxForward, +) + +import torch + + +torch._dynamo.config.automatic_dynamic_shapes = False +# Needed since changing args to function causes recompiles +torch._dynamo.config.recompile_limit = 1000000 + + +# Registry of all available benchmarks +BENCHMARK_REGISTRY: dict[str, type[BenchmarkKernel]] = { + "cross_entropy_forward": CrossEntropyForward, + "cross_entropy_backward": CrossEntropyBackward, + "softmax_forward": SoftmaxForward, + "softmax_backward": SoftmaxBackward, + "rmsnorm_forward": RMSNormForward, + "rmsnorm_backward": RMSNormBackward, + "layernorm_forward": LayerNormForward, + "layernorm_backward": LayerNormBackward, +} + + +def show_environment_info(): + """Show environment information.""" + print("Environment information:") + print(f" Python version: {sys.version}") + print(f" PyTorch version: {torch.__version__}") + print(f" CUDA version: {torch.version.cuda}") + + +def list_benchmarks(): + """List all available benchmarks.""" + print(f"Available benchmarks: {list(BENCHMARK_REGISTRY.keys())}") + + +def run_benchmark(benchmark_name: str, should_visualize: bool = False): + """Run a specific benchmark.""" + if benchmark_name not in BENCHMARK_REGISTRY: + print(f"Error: Unknown benchmark '{benchmark_name}'") + print("Use --list to see available benchmarks") + return False + + print(f"Running benchmark: {benchmark_name}") + print("=" * 60) + + benchmark_class = BENCHMARK_REGISTRY[benchmark_name] + benchmark = benchmark_class() + benchmark.benchmark() + if should_visualize: + benchmark.visualize() + + return True + + +def run_all_benchmarks(should_visualize: bool = False): + """Run all available benchmarks.""" + print("Running all benchmarks...") + print("=" * 60) + + for name, cls in BENCHMARK_REGISTRY.items(): + print(f"\n{'=' * 20} {name.upper()} {'=' * 20}") + benchmark = cls() + benchmark.benchmark() + if should_visualize: + benchmark.visualize() + print() + + +def main(): + show_environment_info() + + parser = argparse.ArgumentParser( + description="Benchmark runner for kernel implementations", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python benchmark.py --list # List all available benchmarks + python benchmark.py --all # Run all benchmarks + python benchmark.py cross_entropy_forward # Run specific benchmark + python benchmark.py softmax_forward softmax_backward # Run multiple benchmarks + """, + ) + + parser.add_argument( + "benchmarks", + nargs="*", + help="Names of benchmarks to run (use --list to see available options)", + ) + + parser.add_argument( + "--list", action="store_true", help="List all available benchmarks" + ) + + parser.add_argument( + "--all", action="store_true", help="Run all available benchmarks" + ) + + parser.add_argument( + "--visualize", + action="store_true", + help="Visualize results after running benchmarks", + ) + + args = parser.parse_args() + + # Handle list option + if args.list: + list_benchmarks() + return + + # Handle all option + if args.all: + run_all_benchmarks(args.visualize) + return + + # Handle specific benchmarks + if not args.benchmarks: + print("Error: No benchmarks specified") + print("Use --list to see available benchmarks or --all to run all benchmarks") + parser.print_help() + sys.exit(1) + + for benchmark_name in args.benchmarks: + run_benchmark(benchmark_name, args.visualize) + print() # Add spacing between benchmarks + + +if __name__ == "__main__": + main() diff --git a/benchmarks/dynamo/genai_layers/kernels.py b/benchmarks/dynamo/genai_layers/kernels.py new file mode 100644 index 0000000000000..30a5f21eaef81 --- /dev/null +++ b/benchmarks/dynamo/genai_layers/kernels.py @@ -0,0 +1,635 @@ +from typing import Any + +import cutlass +import cutlass.torch as cutlass_torch +from utils import BenchmarkKernel + +import torch +import torch.nn.functional as F + + +class CrossEntropyForward(BenchmarkKernel): + def __init__(self): + super().__init__() + self.available_backends = ["eager", "compiled", "quack", "liger"] + + def get_shapes(self) -> tuple[tuple[int, ...], ...]: + return ( + (32768, 256), + (32768, 512), + (32768, 1024), + (32768, 2048), + (32768, 4096), + (32768, 8192), + (32768, 16384), + (32768, 32768), + (32768, 65536), + (16384, 131072), + (8192, 262144), + ) + + def get_memory_bytes(self, args, kwargs) -> int: + # Read x (M*N elements) + read target (M elements) + write loss (M elements) + x, target = args + M, N = x.shape + dtype = x.dtype + return (M * N + M + M) * dtype.itemsize + + def eager(self, args, kwargs=None) -> Any: + assert kwargs is None + x, target = args + return lambda: F.cross_entropy(x, target, reduction="none") + + def compiled(self, args, kwargs=None) -> Any: + assert kwargs is None + x, target = args + + # Mark batch size as dynamic for realistic workload + torch._dynamo.mark_dynamic(x, 0) + torch._dynamo.mark_dynamic(target, 0) + + # Need `lambda` otherwise torch.compile will not trace the function. + # More discussion: https://github.com/pytorch/pytorch/issues/158455 + compiled_cross_entropy = torch.compile( + lambda x, target: F.cross_entropy(x, target, reduction="none"), + mode="max-autotune-no-cudagraphs", + ) + return lambda: compiled_cross_entropy(x, target) + + def quack(self, args, kwargs=None) -> Any: + assert kwargs is None + x, target = args + from quack.cross_entropy import _cross_entropy + + return lambda: _cross_entropy(x, target) + + def liger(self, args, kwargs=None) -> Any: + assert kwargs is None + from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss + + x, target = args + cross_entropy = LigerCrossEntropyLoss(reduction="none") + return lambda: cross_entropy(x, target) + + def benchmark(self): + for M, N in self.get_shapes(): + print(f"\n Tensor dimensions: [{M}, {N}]") + # quack requires cutlass dtype + torch_dtype = cutlass_torch.dtype(cutlass.BFloat16) + x = 0.1 * torch.randn(M, N, device="cuda", dtype=torch_dtype) + target = torch.randint(0, N, (M,), device="cuda", dtype=torch.int64) + self.benchmark_single_shape((x, target), setting=f"shape: [{M}, {N}]") + + def check_accuracy(self, args, kwargs) -> None: + res = {} + for backend in self.available_backends: + args_ref, kwargs_ref = self.clone_inputs(args, kwargs) + res[backend] = getattr(self, backend)(args_ref, kwargs_ref)() + gold = res["eager"] + for backend in self.available_backends: + if backend == "eager": + continue + if backend == "quack": + # quack's cross_entropy only returns float32 loss output. + # Need to convert it to the same dtype as gold for comparison. + res[backend] = res[backend].to(gold.dtype) + try: + torch.testing.assert_close(res[backend], gold) + print( + f"Accuracy check \033[92m✓ succeed\033[0m for {backend} backend on {self.name} kernel" + ) + except Exception as e: + print( + f"Accuracy check \033[91m✗ failed\033[0m for {backend} backend on {self.name} kernel. Error {e}" + ) + + +class CrossEntropyBackward(BenchmarkKernel): + def __init__(self): + super().__init__() + self.available_backends = ["eager", "compiled", "quack", "liger"] + + def get_shapes(self) -> tuple[tuple[int, ...], ...]: + return ( + (32768, 256), + (32768, 512), + (32768, 1024), + (32768, 2048), + (32768, 4096), + (32768, 8192), + (32768, 16384), + (32768, 32768), + (32768, 65536), + (16384, 131072), + (8192, 262144), + ) + + def get_memory_bytes(self, args, kwargs) -> int: + # Read x (M*N elements) + read target (M elements) + read dloss (M elements) + write grad(M*N elements) + x, target, dloss = args + # Memory ba + M, N = x.shape + return ( + 2 * M * N * x.dtype.itemsize + + M * target.dtype.itemsize + + M * dloss.dtype.itemsize + ) + + def eager(self, args, kwargs=None) -> Any: + assert kwargs is None + x, target, dloss = args + loss = F.cross_entropy(x, target, reduction="none") + return lambda: torch.autograd.grad( + loss, x, grad_outputs=dloss, retain_graph=True + ) + + def compiled(self, args, kwargs=None) -> Any: + assert kwargs is None + x, target, dloss = args + + compiled_cross_entropy = torch.compile( + lambda x, target: F.cross_entropy(x, target, reduction="none"), + mode="max-autotune-no-cudagraphs", + ) + loss = compiled_cross_entropy(x, target) + return lambda: torch.autograd.grad( + loss, x, grad_outputs=dloss, retain_graph=True + ) + + def quack(self, args, kwargs=None) -> Any: + from quack.cross_entropy import cross_entropy + + assert kwargs is None + x, target, dloss = args + loss = cross_entropy(x, target) + return lambda: torch.autograd.grad( + loss, x, grad_outputs=dloss, retain_graph=True + ) + + def liger(self, args, kwargs=None) -> Any: + assert kwargs is None + from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss + + x, target, dloss = args + cross_entropy = LigerCrossEntropyLoss(reduction="none") + loss = cross_entropy(x, target) + return lambda: torch.autograd.grad( + loss, x, grad_outputs=dloss, retain_graph=True + ) + + def benchmark(self): + for M, N in self.get_shapes(): + print(f"Tensor dimensions: [{M}, {N}]") + torch_dtype = cutlass_torch.dtype(cutlass.BFloat16) + x = 0.1 * torch.randn( + M, N, device="cuda", dtype=torch_dtype, requires_grad=True + ) + target = torch.randint(0, N, (M,), device="cuda", dtype=torch.int64) + dloss = torch.randn(M, device="cuda", dtype=torch.float32) + self.benchmark_single_shape( + (x, target, dloss), setting=f"shape: [{M}, {N}]" + ) + + +class SoftmaxForward(BenchmarkKernel): + def __init__(self): + super().__init__() + self.available_backends = ["eager", "compiled", "quack", "liger"] + + def get_shapes(self) -> tuple[tuple[int, ...], ...]: + return ( + (32768, 256), + (32768, 512), + (32768, 1024), + (32768, 2048), + (32768, 4096), + (32768, 8192), + (32768, 16384), + (32768, 32768), + (32768, 65536), + (16384, 131072), + (8192, 262144), + ) + + def get_memory_bytes(self, args, kwargs) -> int: + (x,) = args + M, N = x.shape + return 2 * M * N * x.dtype.itemsize + + def eager(self, args, kwargs=None) -> Any: + assert kwargs is None + (x,) = args + return lambda: F.softmax(x, dim=-1) + + def compiled(self, args, kwargs=None) -> Any: + assert kwargs is None + (x,) = args + + # Mark batch size as dynamic for realistic workload + torch._dynamo.mark_dynamic(x, 0) + + compiled_softmax = torch.compile( + lambda x: F.softmax(x, dim=-1), mode="max-autotune-no-cudagraphs" + ) + return lambda: compiled_softmax(x) + + def quack(self, args, kwargs=None) -> Any: + from quack.softmax import softmax + + assert kwargs is None + (x,) = args + return lambda: softmax(x) + + def liger(self, args, kwargs=None) -> Any: + from liger_kernel.transformers.softmax import LigerSoftmax + + assert kwargs is None + (x,) = args + softmax = LigerSoftmax().to("cuda") + return lambda: softmax(x) + + def benchmark(self): + for M, N in self.get_shapes(): + print(f"Tensor dimensions: [{M}, {N}]") + torch_dtype = cutlass_torch.dtype(cutlass.BFloat16) + x = 0.1 * torch.randn(M, N, device="cuda", dtype=torch_dtype) + self.benchmark_single_shape((x,), setting=f"shape: [{M}, {N}]") + + +class SoftmaxBackward(BenchmarkKernel): + def __init__(self): + super().__init__() + self.available_backends = ["eager", "compiled", "quack", "liger"] + + def get_shapes(self) -> tuple[tuple[int, ...], ...]: + return ( + (32768, 256), + (32768, 512), + (32768, 1024), + (32768, 2048), + (32768, 4096), + (32768, 8192), + (32768, 16384), + (32768, 32768), + (32768, 65536), + (16384, 131072), + (8192, 262144), + ) + + def get_memory_bytes(self, args, kwargs) -> int: + # Memory: read dy and y, write ax backward + x, dy = args + M, N = x.shape + return 3 * M * N * x.dtype.itemsize + + def eager(self, args, kwargs=None) -> Any: + assert kwargs is None + x, dy = args + y = F.softmax(x, dim=-1) + return lambda: torch.autograd.grad(y, x, grad_outputs=dy, retain_graph=True) + + def compiled(self, args, kwargs=None) -> Any: + assert kwargs is None + x, dy = args + compiled_softmax = torch.compile( + lambda x: F.softmax(x, dim=-1), mode="max-autotune-no-cudagraphs" + ) + y = compiled_softmax(x) + return lambda: torch.autograd.grad(y, x, grad_outputs=dy, retain_graph=True) + + def quack(self, args, kwargs=None) -> Any: + from quack.softmax import softmax + + assert kwargs is None + x, dy = args + + y = softmax(x) + return lambda: torch.autograd.grad(y, x, grad_outputs=dy, retain_graph=True) + + def liger(self, args, kwargs=None) -> Any: + from liger_kernel.transformers.softmax import LigerSoftmax + + assert kwargs is None + x, dy = args + softmax = LigerSoftmax().to("cuda") + y = softmax(x) + return lambda: torch.autograd.grad(y, x, grad_outputs=dy, retain_graph=True) + + def benchmark(self): + for M, N in self.get_shapes(): + print(f"Tensor dimensions: [{M}, {N}]") + torch_dtype = cutlass_torch.dtype(cutlass.BFloat16) + x = 0.1 * torch.randn( + M, N, device="cuda", dtype=torch_dtype, requires_grad=True + ) + dy = torch.randn(M, N, device="cuda", dtype=torch_dtype) + self.benchmark_single_shape((x, dy), setting=f"shape: [{M}, {N}]") + + +class RMSNormForward(BenchmarkKernel): + def __init__(self): + super().__init__() + self.available_backends = ["eager", "compiled", "quack", "liger"] + + def get_shapes(self) -> tuple[tuple[int, ...], ...]: + return ( + (32768, 256), + (32768, 512), + (32768, 1024), + (32768, 2048), + (32768, 4096), + (32768, 8192), + (32768, 16384), + (32768, 32768), + (32768, 65536), + (16384, 131072), + (8192, 262144), + ) + + def get_memory_bytes(self, args, kwargs) -> int: + x, w = args + M, N = x.shape + return 2 * M * N * x.dtype.itemsize + N * w.dtype.itemsize + + def rms_norm_ref(self, x, w): + x_f32 = x.float() + return ( + x_f32 + * torch.rsqrt(torch.mean(x_f32.square(), dim=-1, keepdim=True) + 1e-6) + * w + ).to(x.dtype) + + def eager(self, args, kwargs=None) -> Any: + assert kwargs is None + x, w = args + return lambda: self.rms_norm_ref(x, w) + + def compiled(self, args, kwargs=None) -> Any: + assert kwargs is None + x, w = args + + # Mark batch size as dynamic for realistic workload + torch._dynamo.mark_dynamic(x, 0) + + compiled_rms_norm = torch.compile( + self.rms_norm_ref, mode="max-autotune-no-cudagraphs" + ) + return lambda: compiled_rms_norm(x, w) + + def quack(self, args, kwargs=None) -> Any: + # Note: only supper weight with float32 dtype + from quack.rmsnorm import _rmsnorm_fwd + + x, w = args + return lambda: _rmsnorm_fwd(x, w, eps=1e-6) + + def liger(self, args, kwargs) -> Any: + from liger_kernel.transformers.rms_norm import LigerRMSNorm + + x, w = args + M, N = x.shape + liger_rmsnorm = LigerRMSNorm(hidden_size=N, eps=1e-6).cuda() + liger_rmsnorm.weight.data.copy_(w) + return lambda: liger_rmsnorm(x) + + def benchmark(self): + for M, N in self.get_shapes(): + print(f"Tensor dimensions: [{M}, {N}]") + torch_dtype = cutlass_torch.dtype(cutlass.BFloat16) + x = torch.randn(M, N, device="cuda", dtype=torch_dtype) + w = torch.randn(N, device="cuda", dtype=torch.float32) + self.benchmark_single_shape((x, w), setting=f"shape: [{M}, {N}]") + + +class RMSNormBackward(BenchmarkKernel): + def __init__(self): + super().__init__() + self.available_backends = ["eager", "compiled", "quack", "liger"] + + def get_shapes(self) -> tuple[tuple[int, ...], ...]: + # TODO: OOM for (32768, 65536) on h100 + return ( + (32768, 256), + (32768, 512), + (32768, 1024), + (32768, 2048), + (32768, 4096), + (32768, 8192), + (32768, 16384), + (32768, 32768), + ) + + def get_memory_bytes(self, args, kwargs) -> int: + x, w, dy = args + # x, dy: [M, N], w: [N] + M, N = x.shape + # Read x, w, dy, write dx, dw + return 3 * M * N * x.dtype.itemsize + 2 * N * w.dtype.itemsize + + def rms_norm_ref(self, x, w): + x_f32 = x.float() + return ( + x_f32 + * torch.rsqrt(torch.mean(x_f32.square(), dim=-1, keepdim=True) + 1e-6) + * w + ).to(x.dtype) + + def eager(self, args, kwargs=None) -> Any: + assert kwargs is None + x, w, dy = args + y = self.rms_norm_ref(x, w) + return lambda: torch.autograd.grad( + y, [x, w], grad_outputs=dy, retain_graph=True + ) + + def compiled(self, args, kwargs=None) -> Any: + assert kwargs is None + x, w, dy = args + y = torch.compile(self.rms_norm_ref, mode="max-autotune-no-cudagraphs")(x, w) + return lambda: torch.autograd.grad( + y, [x, w], grad_outputs=dy, retain_graph=True + ) + + def quack(self, args, kwargs=None) -> Any: + from quack.rmsnorm import _rmsnorm_backward + + ( + x, + w, + dy, + ) = args + M, N = x.shape + rstd = torch.randn(M, device="cuda", dtype=torch.float32) + return lambda: _rmsnorm_backward(x, w, dy, rstd) + + def liger(self, args, kwargs=None) -> Any: + from liger_kernel.transformers.rms_norm import LigerRMSNorm + + x, w, dy = args + M, N = x.shape + liger_rmsnorm = LigerRMSNorm(hidden_size=N, eps=1e-6).cuda() + liger_rmsnorm.weight.data.copy_(w) + y = liger_rmsnorm(x) + return lambda: torch.autograd.grad( + y, [x, liger_rmsnorm.weight], grad_outputs=dy, retain_graph=True + ) + + def benchmark(self): + for M, N in self.get_shapes(): + print(f"Tensor dimensions: [{M}, {N}]") + torch_dtype = cutlass_torch.dtype(cutlass.BFloat16) + x = torch.randn(M, N, device="cuda", dtype=torch_dtype, requires_grad=True) + w = torch.randn(N, device="cuda", dtype=torch.float32, requires_grad=True) + dy = torch.randn(M, N, device="cuda", dtype=torch_dtype) + self.benchmark_single_shape((x, w, dy), setting=f"shape: [{M}, {N}]") + + +class LayerNormForward(BenchmarkKernel): + def __init__(self): + super().__init__() + self.available_backends = ["eager", "compiled", "quack", "liger"] + + def get_shapes(self) -> tuple[tuple[int, ...], ...]: + # OOM for (16384, 131072) on h100 + return ( + (32768, 256), + (32768, 512), + (32768, 1024), + (32768, 2048), + (32768, 4096), + (32768, 8192), + (32768, 16384), + (32768, 32768), + (32768, 65536), + ) + + def get_memory_bytes(self, args, kwargs) -> int: + x, w = args + M, N = x.shape + # Read x ([M, N]), w ([N]), write y ([M, N]) + return 2 * M * N * x.dtype.itemsize + N * w.dtype.itemsize + + def layernorm_ref(self, x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6): + x_f32 = x.float() + return F.layer_norm(x_f32, w.shape, w, None, eps).to(x.dtype) + + def eager(self, args, kwargs=None) -> Any: + assert kwargs is None + x, w = args + return lambda: self.layernorm_ref(x, w) + + def compiled(self, args, kwargs=None) -> Any: + assert kwargs is None + x, w = args + + # Mark batch size as dynamic for realistic workload + torch._dynamo.mark_dynamic(x, 0) + + compiled_layernorm = torch.compile( + self.layernorm_ref, mode="max-autotune-no-cudagraphs" + ) + return lambda: compiled_layernorm(x, w, eps=1e-6) + + def quack(self, args, kwargs) -> Any: + # Note: quack layernorm does not support bias + from quack.layernorm import layernorm + + x, w = args + return lambda: layernorm(x, w, eps=1e-6) + + def liger(self, args, kwargs) -> Any: + from liger_kernel.transformers.layer_norm import LigerLayerNorm + + x, w = args + M, N = x.shape + liger_layernorm = LigerLayerNorm(hidden_size=N, eps=1e-6).cuda() + liger_layernorm.weight.data.copy_(w) + liger_layernorm.bias.data.copy_( + torch.zeros(N, device="cuda", dtype=torch.float32) + ) + return lambda: liger_layernorm(x) + + def benchmark(self): + for M, N in self.get_shapes(): + print(f"Tensor dimensions: [{M}, {N}]") + torch_dtype = cutlass_torch.dtype(cutlass.BFloat16) + x = torch.randn(M, N, device="cuda", dtype=torch_dtype) + w = torch.randn(N, device="cuda", dtype=torch.float32) + self.benchmark_single_shape((x, w), setting=f"shape: [{M}, {N}]") + + +class LayerNormBackward(BenchmarkKernel): + def __init__(self): + super().__init__() + self.available_backends = ["eager", "compiled", "liger"] + + def get_shapes(self) -> tuple[tuple[int, ...], ...]: + # OOM for (16384, 131072), (8192, 262144) + return ( + (32768, 256), + (32768, 512), + (32768, 1024), + (32768, 2048), + (32768, 4096), + (32768, 8192), + (32768, 16384), + (32768, 32768), + (32768, 65536), + ) + + def get_memory_bytes(self, args, kwargs) -> int: + x, w, dy = args + M, N = x.shape + # Read x ([M, N]), w ([N]), dy ([M, N]), write dx ([M, N]), dw ([N]) + return ( + 2 * M * N * x.dtype.itemsize + + 2 * N * w.dtype.itemsize + + M * N * dy.dtype.itemsize + ) + + def layernorm_ref(self, x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6): + x_f32 = x.float() + return F.layer_norm(x_f32, w.shape, w, None, eps).to(x.dtype) + + def eager(self, args, kwargs=None) -> Any: + assert kwargs is None + x, w, dy = args + y = self.layernorm_ref(x, w) + return lambda: torch.autograd.grad( + y, [x, w], grad_outputs=dy, retain_graph=True + ) + + def compiled(self, args, kwargs=None) -> Any: + assert kwargs is None + x, w, dy = args + compiled_layernorm = torch.compile( + self.layernorm_ref, mode="max-autotune-no-cudagraphs" + ) + y = compiled_layernorm(x, w) + return lambda: torch.autograd.grad( + y, [x, w], grad_outputs=dy, retain_graph=True + ) + + def liger(self, args, kwargs) -> Any: + from liger_kernel.transformers.layer_norm import LigerLayerNorm + + x, w, dy = args + M, N = x.shape + liger_layernorm = LigerLayerNorm(hidden_size=N, eps=1e-6).cuda() + liger_layernorm.weight.data.copy_(w) + liger_layernorm.bias.data.copy_( + torch.zeros(N, device="cuda", dtype=torch.float32) + ) + y = liger_layernorm(x) + return lambda: torch.autograd.grad( + y, [x, liger_layernorm.weight], grad_outputs=dy, retain_graph=True + ) + + def benchmark(self): + for M, N in self.get_shapes(): + print(f"Tensor dimensions: [{M}, {N}]") + torch_dtype = cutlass_torch.dtype(cutlass.BFloat16) + x = torch.randn(M, N, device="cuda", dtype=torch_dtype, requires_grad=True) + w = torch.randn(N, device="cuda", dtype=torch.float32, requires_grad=True) + dy = torch.randn(M, N, device="cuda", dtype=torch_dtype) + self.benchmark_single_shape((x, w, dy), setting=f"shape: [{M}, {N}]") diff --git a/benchmarks/dynamo/genai_layers/requirements.txt b/benchmarks/dynamo/genai_layers/requirements.txt new file mode 100644 index 0000000000000..ddd1f01013495 --- /dev/null +++ b/benchmarks/dynamo/genai_layers/requirements.txt @@ -0,0 +1,4 @@ +quack-kernels +liger-kernel +nvidia-cutlass-dsl==4.1.0.dev0 +matplotlib diff --git a/benchmarks/dynamo/genai_layers/utils.py b/benchmarks/dynamo/genai_layers/utils.py new file mode 100644 index 0000000000000..9d3f97c0da749 --- /dev/null +++ b/benchmarks/dynamo/genai_layers/utils.py @@ -0,0 +1,241 @@ +import os +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, Callable, Optional + +import matplotlib.pyplot as plt + +import torch +from torch._inductor.runtime.benchmarking import benchmarker + + +def benchmark_kernel_in_milliseconds(func: Callable, *args, **kwargs) -> float: + # warmup + for _ in range(5): + func(*args, **kwargs) + return benchmarker.benchmark_gpu(lambda: func(*args, **kwargs)) + + +@dataclass +class Performance: + # Benchmark setting usually the shape of the input tensor + setting: str + + # Latency in milliseconds + latency: float + + # Number of memory access in bytes + memory_bytes: float + + # Memory bandwidth in GB/s + memory_bandwidth: float = 0.0 + + # Compute intensity in FLOPs/byte + compute_intensity: float = 0.0 + + def __post_init__(self): + self.memory_bandwidth = self.memory_bytes / (self.latency / 1000) / 1e9 + + def __str__(self): + return f"setting: {self.setting}, latency: {self.latency} ms, memory bandwidth: {self.memory_bandwidth} GB/s" + + +class BenchmarkKernel: + def __init__(self): + self.name = self.__class__.__name__ + self.available_backends: list[str] = [] + + # mapping from backend to list of performance results + self.profiling_results: defaultdict[str, list[Performance]] = defaultdict(list) + + def get_memory_bytes(self, args, kwargs) -> int: + # Get the necessary memory access in bytes for the kernelßß + raise NotImplementedError + + def get_shapes(self) -> tuple[tuple[int, ...], ...]: + # Get a list of input shapes to benchmark the kernel + raise NotImplementedError + + def eager(self, args, kwargs) -> Any: + raise NotImplementedError + + def compiled(self, args, kwargs) -> Any: + raise NotImplementedError + + def helion(self, args, kwargs) -> Any: + raise NotImplementedError + + def quack(self, args, kwargs) -> Any: + raise NotImplementedError + + def liger(self, args, kwargs) -> Any: + raise NotImplementedError + + def triton(self, args, kwargs) -> Any: + raise NotImplementedError + + def benchmark(self): + raise NotImplementedError + + def clone_inputs(self, args, kwargs) -> Any: + args_ref = [ + arg.clone().detach().requires_grad_(arg.requires_grad) for arg in args + ] + + kwargs_ref = ( + { + k: ( + v.clone().detach().requires_grad_(v.requires_grad) + if isinstance(v, torch.Tensor) + else v + ) + for k, v in kwargs.items() + } + if kwargs + else kwargs + ) + + return args_ref, kwargs_ref + + def check_accuracy(self, args, kwargs) -> None: + res = {} + for backend in self.available_backends: + args_ref, kwargs_ref = self.clone_inputs(args, kwargs) + res[backend] = getattr(self, backend)(args_ref, kwargs_ref)() + gold = res["eager"] + for backend in self.available_backends: + if backend == "eager": + continue + try: + torch.testing.assert_close(res[backend], gold) + for t, gold_t in zip(res[backend], gold): + if t.requires_grad: + torch.testing.assert_close(t.grad, gold_t.grad) + print( + f"Accuracy check \033[92m✓ succeed\033[0m for {backend} backend on {self.name} kernel" + ) + except Exception as e: + print( + f"Accuracy check \033[91m✗ failed\033[0m for {backend} backend on {self.name} kernel. Error {e}" + ) + + def benchmark_single_shape( + self, args, kwargs=None, should_check_accuracy=True, setting: str = "" + ): + for backend in self.available_backends: + args_ref, kwargs_ref = self.clone_inputs(args, kwargs) + try: + avg_time = benchmark_kernel_in_milliseconds( + getattr(self, backend)(args_ref, kwargs_ref) + ) + except Exception as e: + print( + f"Failed to run {backend} backend on {self.name} kernel for {setting} due to {e}" + ) + self.available_backends.remove(backend) + continue + mem_bytes = self.get_memory_bytes(args_ref, kwargs_ref) + perf = Performance(setting, avg_time, mem_bytes) + print(f"{self.name} kernel on {backend} backend. {perf}") + self.profiling_results[backend].append(perf) + + if should_check_accuracy: + self.check_accuracy(args, kwargs) + + def visualize(self) -> None: + visualize_comparison( + self.profiling_results, + title=f"{self.name}", + output_path=f"{self.name}_bench", + ) + return + + +def get_backend_colors() -> dict[str, str]: + """Get consistent color scheme for different backends.""" + return { + "eager": "#1f77b4", # blue + "compiled": "#ff7f0e", # orange + "quack": "#2ca02c", # green + "liger": "#d62728", # red + "helion": "#9467bd", # purple + "triton": "#8c564b", # brown + "cutlass": "#e377c2", # pink + "flash_attn": "#7f7f7f", # gray + "default": "#000000", # black + } + + +def visualize_comparison( + profiling_results: dict[str, list[Performance]], + title: Optional[str] = None, + output_path: Optional[str] = None, +) -> None: + """ + Create a single memory_bandwidth comparison plot from profiling results. + + Args: + profiling_results: Dict mapping backend names to lists of Performance objects + output_path: Path to save the plot (optional) + """ + # Get backend colors + backend_colors = get_backend_colors() + + # Extract settings from eager backend which runs all settings + all_settings = [] + for perf in profiling_results["eager"]: + all_settings.append(perf.setting) + + # Create single plot + fig, ax = plt.subplots(1, 1, figsize=(12, 8)) + + for backend in profiling_results: + backend_perfs = profiling_results[backend] + perf_dict = {perf.setting: perf for perf in backend_perfs} + + x_vals = [] + y_vals = [] + for i, setting in enumerate(all_settings): + if setting in perf_dict: + x_vals.append(i) + y_vals.append(perf_dict[setting].memory_bandwidth) + + if x_vals: # Only plot if we have data + color = backend_colors.get(backend, backend_colors["default"]) + ax.plot( + x_vals, + y_vals, + "o-", + label=backend, + color=color, + linewidth=2, + markersize=8, + alpha=0.8, + ) + + # Configure the plot + ax.set_title(title or "Memory Bandwidth Comparison", fontsize=16) + ax.set_xlabel("Shape", fontsize=12) + ax.set_ylabel("memory bandwidth (GB/s)", fontsize=12) + ax.set_xticks(range(len(all_settings))) + ax.set_xticklabels( + [ + s.replace("shape: ", "").replace("[", "").replace("]", "") + for s in all_settings + ], + rotation=45, + ha="right", + ) + ax.legend(fontsize=10) + ax.grid(True, alpha=0.3) + + plt.tight_layout() + + # Save the plot if output path is provided + if output_path: + # Save as PNG + os.makedirs("pics", exist_ok=True) + full_path = os.path.join("pics", output_path + ".png") + plt.savefig(full_path, dpi=300, bbox_inches="tight", facecolor="white") + + plt.close() From a9f84021fb5963019f3df895d7d3eeae4606cf79 Mon Sep 17 00:00:00 2001 From: AaronWang04 Date: Sat, 19 Jul 2025 06:51:53 +0000 Subject: [PATCH 290/457] [CI] Fixes CI for CUDA Version > 12.9 (#157385) Compute capabilities older than volta (inclusive) is no longer supported in CUDA Version > 12.9 Pull Request resolved: https://github.com/pytorch/pytorch/pull/157385 Approved by: https://github.com/eqy --- test/test_cpp_extensions_jit.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/test/test_cpp_extensions_jit.py b/test/test_cpp_extensions_jit.py index d671e3f874c96..06b681bee981f 100644 --- a/test/test_cpp_extensions_jit.py +++ b/test/test_cpp_extensions_jit.py @@ -322,12 +322,15 @@ def test_jit_cuda_archflags(self): [f"{capability[0]}{capability[1]}" for capability in capabilities], None, ), - "Maxwell+Tegra;6.1": (["53", "61"], None), - "Volta": (["70"], ["70"]), } archflags["7.5+PTX"] = (["75"], ["75"]) - archflags["5.0;6.0+PTX;7.0;7.5"] = (["50", "60", "70", "75"], ["60"]) - if int(torch.version.cuda.split(".")[0]) < 12: + major, minor = map(int, torch.version.cuda.split(".")[:2]) + if major < 12 or (major == 12 and minor <= 9): + # Compute capability <= 7.0 is only supported up to CUDA 12.9 + archflags["Maxwell+Tegra;6.1"] = (["53", "61"], None) + archflags["Volta"] = (["70"], ["70"]) + archflags["5.0;6.0+PTX;7.0;7.5"] = (["50", "60", "70", "75"], ["60"]) + if major < 12: # CUDA 12 drops compute capability < 5.0 archflags["Pascal 3.5"] = (["35", "60", "61"], None) From f73594164a3825dc4354ee2ba0fa231195f49bda Mon Sep 17 00:00:00 2001 From: Jane Xu Date: Fri, 18 Jul 2025 12:38:47 -0700 Subject: [PATCH 291/457] [BE] document Adadelta and Adagrad APIs properly (#158483) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158483 Approved by: https://github.com/albanD --- docs/source/conf.py | 12 ------------ docs/source/optim.aliases.md | 36 ++++++++++++++++++++++++++++++++++++ docs/source/optim.md | 10 +++++++--- torch/optim/adagrad.py | 3 ++- 4 files changed, 45 insertions(+), 16 deletions(-) create mode 100644 docs/source/optim.aliases.md diff --git a/docs/source/conf.py b/docs/source/conf.py index 2113411cd8afb..6e498b625da0d 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -1840,12 +1840,6 @@ "check_export_model_diff", "verify", "verify_aten_graph", - # torch.optim.adadelta - "adadelta", - # torch.optim.adagrad - "adagrad", - # torch.optim.adam - "adam", # torch.optim.adamax "adamax", # torch.optim.adamw @@ -3112,12 +3106,6 @@ # torch.onnx.verification "OnnxBackend", "OnnxTestCaseRepro", - # torch.optim.adadelta - "Adadelta", - # torch.optim.adagrad - "Adagrad", - # torch.optim.adam - "Adam", # torch.optim.adamax "Adamax", # torch.optim.adamw diff --git a/docs/source/optim.aliases.md b/docs/source/optim.aliases.md new file mode 100644 index 0000000000000..6b21e6f9e30aa --- /dev/null +++ b/docs/source/optim.aliases.md @@ -0,0 +1,36 @@ +# Aliases in torch.optim + +The following are aliases to their counterparts in ``torch.optim`` in the nested namespaces in which they are defined. For any of these APIs, feel free to use the top-level version in ``torch.optim`` like ``torch.optim.Adam`` or the nested version ``torch.optim.adam.Adam``. + +```{eval-rst} +.. automodule:: torch.optim.adadelta +.. currentmodule:: torch.optim.adadelta +.. autosummary:: + :toctree: generated + :nosignatures: + + Adadelta + adadelta +``` + +```{eval-rst} +.. automodule:: torch.optim.adagrad +.. currentmodule:: torch.optim.adagrad +.. autosummary:: + :toctree: generated + :nosignatures: + + Adagrad + adagrad +``` + +```{eval-rst} +.. automodule:: torch.optim.adam +.. currentmodule:: torch.optim.adam +.. autosummary:: + :toctree: generated + :nosignatures: + + Adam + adam +``` diff --git a/docs/source/optim.md b/docs/source/optim.md index 8a3f03468810d..b72d002723259 100644 --- a/docs/source/optim.md +++ b/docs/source/optim.md @@ -688,9 +688,6 @@ We train the model for a total of 300 epochs and start to collect EMA averages i ```{eval-rst} -.. py:module:: torch.optim.adadelta -.. py:module:: torch.optim.adagrad -.. py:module:: torch.optim.adam .. py:module:: torch.optim.adamax .. py:module:: torch.optim.adamw .. py:module:: torch.optim.asgd @@ -705,3 +702,10 @@ for tracking purposes --> .. py:module:: torch.optim.sparse_adam .. py:module:: torch.optim.swa_utils ``` + +```{eval-rst} +.. toctree:: + :hidden: + + optim.aliases.md +``` diff --git a/torch/optim/adagrad.py b/torch/optim/adagrad.py index 65e76634421a3..00742b8a4e075 100644 --- a/torch/optim/adagrad.py +++ b/torch/optim/adagrad.py @@ -21,7 +21,7 @@ ) -__all__ = ["Adagrad", "adagrad"] +__all__: list[str] = ["Adagrad", "adagrad"] class Adagrad(Optimizer): @@ -117,6 +117,7 @@ def __setstate__(self, state): ) def share_memory(self): + """Calls tensor.share_memory_() on the state sum tensors.""" for group in self.param_groups: for p in group["params"]: state = self.state[p] From 7cc5d03dfc0077bc670c39abd101c72a04b2737f Mon Sep 17 00:00:00 2001 From: Jane Xu Date: Fri, 18 Jul 2025 12:55:28 -0700 Subject: [PATCH 292/457] Document the rest of the specific optimizer module APIs (#158669) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158669 Approved by: https://github.com/albanD ghstack dependencies: #158483 --- docs/source/conf.py | 31 ---------- docs/source/optim.aliases.md | 108 +++++++++++++++++++++++++++++++++++ docs/source/optim.md | 10 ---- torch/optim/adagrad.py | 2 +- 4 files changed, 109 insertions(+), 42 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 6e498b625da0d..8b2112c165e8a 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -1840,25 +1840,9 @@ "check_export_model_diff", "verify", "verify_aten_graph", - # torch.optim.adamax - "adamax", - # torch.optim.adamw - "adamw", - # torch.optim.asgd - "asgd", - # torch.optim.nadam - "nadam", # torch.optim.optimizer "register_optimizer_step_post_hook", "register_optimizer_step_pre_hook", - # torch.optim.radam - "radam", - # torch.optim.rmsprop - "rmsprop", - # torch.optim.rprop - "rprop", - # torch.optim.sgd - "sgd", # torch.optim.swa_utils "get_ema_avg_fn", "get_ema_multi_avg_fn", @@ -3131,23 +3115,8 @@ "ReduceLROnPlateau", "SequentialLR", "StepLR", - # torch.optim.nadam - "NAdam", # torch.optim.optimizer "Optimizer", - # torch.optim.radam - "RAdam", - # torch.optim.rmsprop - "RMSprop", - # torch.optim.rprop - "Rprop", - # torch.optim.sgd - "SGD", - # torch.optim.sparse_adam - "SparseAdam", - # torch.optim.swa_utils - "AveragedModel", - "SWALR", # torch.overrides "BaseTorchFunctionMode", "TorchFunctionMode", diff --git a/docs/source/optim.aliases.md b/docs/source/optim.aliases.md index 6b21e6f9e30aa..09616aefe14ae 100644 --- a/docs/source/optim.aliases.md +++ b/docs/source/optim.aliases.md @@ -34,3 +34,111 @@ The following are aliases to their counterparts in ``torch.optim`` in the nested Adam adam ``` + +```{eval-rst} +.. automodule:: torch.optim.adamax +.. currentmodule:: torch.optim.adamax +.. autosummary:: + :toctree: generated + :nosignatures: + + Adamax + adamax +``` + +```{eval-rst} +.. automodule:: torch.optim.adamw +.. currentmodule:: torch.optim.adamw +.. autosummary:: + :toctree: generated + :nosignatures: + + AdamW + adamw +``` + +```{eval-rst} +.. automodule:: torch.optim.asgd +.. currentmodule:: torch.optim.asgd +.. autosummary:: + :toctree: generated + :nosignatures: + + ASGD + asgd +``` + +```{eval-rst} +.. automodule:: torch.optim.lbfgs +.. currentmodule:: torch.optim.lbfgs +.. autosummary:: + :toctree: generated + :nosignatures: + + LBFGS +``` + +```{eval-rst} +.. automodule:: torch.optim.nadam +.. currentmodule:: torch.optim.nadam +.. autosummary:: + :toctree: generated + :nosignatures: + + NAdam + nadam +``` + +```{eval-rst} +.. automodule:: torch.optim.radam +.. currentmodule:: torch.optim.radam +.. autosummary:: + :toctree: generated + :nosignatures: + + RAdam + radam +``` + +```{eval-rst} +.. automodule:: torch.optim.rmsprop +.. currentmodule:: torch.optim.rmsprop +.. autosummary:: + :toctree: generated + :nosignatures: + + RMSprop + rmsprop +``` + +```{eval-rst} +.. automodule:: torch.optim.rprop +.. currentmodule:: torch.optim.rprop +.. autosummary:: + :toctree: generated + :nosignatures: + + Rprop + rprop +``` + +```{eval-rst} +.. automodule:: torch.optim.sgd +.. currentmodule:: torch.optim.sgd +.. autosummary:: + :toctree: generated + :nosignatures: + + SGD + sgd +``` + +```{eval-rst} +.. automodule:: torch.optim.sparse_adam +.. currentmodule:: torch.optim.sparse_adam +.. autosummary:: + :toctree: generated + :nosignatures: + + SparseAdam +``` diff --git a/docs/source/optim.md b/docs/source/optim.md index b72d002723259..38587705ed216 100644 --- a/docs/source/optim.md +++ b/docs/source/optim.md @@ -688,18 +688,8 @@ We train the model for a total of 300 epochs and start to collect EMA averages i ```{eval-rst} -.. py:module:: torch.optim.adamax -.. py:module:: torch.optim.adamw -.. py:module:: torch.optim.asgd -.. py:module:: torch.optim.lbfgs .. py:module:: torch.optim.lr_scheduler -.. py:module:: torch.optim.nadam .. py:module:: torch.optim.optimizer -.. py:module:: torch.optim.radam -.. py:module:: torch.optim.rmsprop -.. py:module:: torch.optim.rprop -.. py:module:: torch.optim.sgd -.. py:module:: torch.optim.sparse_adam .. py:module:: torch.optim.swa_utils ``` diff --git a/torch/optim/adagrad.py b/torch/optim/adagrad.py index 00742b8a4e075..00b3c9c28774f 100644 --- a/torch/optim/adagrad.py +++ b/torch/optim/adagrad.py @@ -21,7 +21,7 @@ ) -__all__: list[str] = ["Adagrad", "adagrad"] +__all__ = ["Adagrad", "adagrad"] class Adagrad(Optimizer): From 7cc1a9546c135f8e7635e0d38aa2bba797f8907d Mon Sep 17 00:00:00 2001 From: Xu Han Date: Sat, 19 Jul 2025 08:58:42 +0000 Subject: [PATCH 293/457] [AOTI] fix extract file failed on Windows. (#158702) Changes: 1. rename zip index name, and keep it out of normalize path. 2. normalize output path for extract file. Extract files successful: image Pull Request resolved: https://github.com/pytorch/pytorch/pull/158702 Approved by: https://github.com/angelayi --- .../aoti_package/model_package_loader.cpp | 41 ++++++++++--------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/torch/csrc/inductor/aoti_package/model_package_loader.cpp b/torch/csrc/inductor/aoti_package/model_package_loader.cpp index 127969c0318ff..8e3a2d95fb9ec 100644 --- a/torch/csrc/inductor/aoti_package/model_package_loader.cpp +++ b/torch/csrc/inductor/aoti_package/model_package_loader.cpp @@ -435,19 +435,21 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( std::vector found_filenames; for (uint32_t i = 0; i < zip_archive.m_total_files; i++) { - uint32_t filename_len = + uint32_t zip_filename_len = mz_zip_reader_get_filename(&zip_archive, i, nullptr, 0); - if (filename_len == 0) { + if (zip_filename_len == 0) { throw std::runtime_error("Failed to read filename"); } - // filename_len returned by mz_zip_reader_get_filename includes the null - // terminator, so we need to subtract 1 here - std::string filename_str(filename_len - 1, '\0'); + // zip_filename_len returned by mz_zip_reader_get_filename includes the null + // terminator, so we need to subtract 1 here. + std::string zip_filename_str(zip_filename_len - 1, '\0'); + // zip_filename_str can't be normalize_path_separator, because it should be + // as index for mz_zip_reader_extract_file_to_file. if (!mz_zip_reader_get_filename( - &zip_archive, i, filename_str.data(), filename_len)) { + &zip_archive, i, zip_filename_str.data(), zip_filename_len)) { throw std::runtime_error("Failed to read filename"); } - found_filenames.push_back(normalize_path_separator(filename_str)); + found_filenames.push_back(zip_filename_str); } if (found_filenames.empty()) { @@ -504,18 +506,17 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( .append(filename); } - output_path_str = normalize_path_separator(output_path_str); - + std::string output_file_path = normalize_path_separator(output_path_str); LOG(INFO) << "Extract file: " << filename_str << " to " - << output_path_str; + << output_file_path; // Create the parent directory if it doesn't exist - size_t parent_path_idx = output_path_str.find_last_of(k_separator); + size_t parent_path_idx = output_file_path.find_last_of(k_separator); if (parent_path_idx == std::string::npos) { throw std::runtime_error( - "Failed to find parent path in " + output_path_str); + "Failed to find parent path in " + output_file_path); } - std::string parent_path = output_path_str.substr(0, parent_path_idx); + std::string parent_path = output_file_path.substr(0, parent_path_idx); if (!recursive_mkdir(parent_path)) { throw std::runtime_error(fmt::format( "Failed to create directory {}: {}", @@ -525,22 +526,22 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( // Extracts file to the temp directory mz_bool b_extract = mz_zip_reader_extract_file_to_file( - &zip_archive, filename_str.c_str(), output_path_str.c_str(), 0); + &zip_archive, filename_str.c_str(), output_file_path.c_str(), 0); if (b_extract == MZ_FALSE) { throw std::runtime_error(fmt::format( - "Failed to extract file {} to {}", filename_str, output_path_str)); + "Failed to extract file {} to {}", filename_str, output_file_path)); } // Save the file for bookkeeping - size_t extension_idx = output_path_str.find_last_of('.'); + size_t extension_idx = output_file_path.find_last_of('.'); if (extension_idx != std::string::npos) { - std::string filename_extension = output_path_str.substr(extension_idx); + std::string filename_extension = output_file_path.substr(extension_idx); if (filename_extension == ".cpp") { - cpp_filename = output_path_str; + cpp_filename = output_file_path; } else if (filename_extension == object_file_ext()) { - obj_filenames.push_back(output_path_str); + obj_filenames.push_back(output_file_path); } else if (filename_extension == extension_file_ext()) { - so_filename = output_path_str; + so_filename = output_file_path; } } } From d36afac83b01c3de214db91f3d4b3f447f9a77b7 Mon Sep 17 00:00:00 2001 From: Huy Do Date: Sat, 19 Jul 2025 10:48:06 -0700 Subject: [PATCH 294/457] Build domain libraries for all workflows with TorchBench config (#158601) They are expensive GPU runners and should not spend time building packages Signed-off-by: Huy Do Pull Request resolved: https://github.com/pytorch/pytorch/pull/158601 Approved by: https://github.com/ZainRizvi --- .ci/docker/requirements-ci.txt | 6 +++++ .ci/pytorch/build.sh | 24 +++++++++---------- .ci/pytorch/common_utils.sh | 7 ++++++ .github/workflows/inductor-nightly.yml | 1 + .github/workflows/inductor-perf-compare.yml | 1 + .../inductor-perf-test-nightly-aarch64.yml | 1 + .../inductor-perf-test-nightly-h100.yml | 1 + .../inductor-perf-test-nightly-x86.yml | 1 + .../workflows/inductor-perf-test-nightly.yml | 3 +++ .github/workflows/inductor-periodic.yml | 5 ++++ .github/workflows/inductor.yml | 2 ++ .github/workflows/test-h100.yml | 2 +- 12 files changed, 40 insertions(+), 14 deletions(-) diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt index 650c4e58c8ba6..facc633f6a7ad 100644 --- a/.ci/docker/requirements-ci.txt +++ b/.ci/docker/requirements-ci.txt @@ -390,3 +390,9 @@ tlparse==0.3.30 cuda-bindings>=12.0,<13.0 ; platform_machine != "s390x" #Description: required for testing CUDAGraph::raw_cuda_graph(). See https://nvidia.github.io/cuda-python/cuda-bindings/latest/support.html for how this version was chosen. Note "Any fix in the latest bindings would be backported to the prior major version" means that only the newest version of cuda-bindings will get fixes. Depending on the latest version of 12.x is okay because all 12.y versions will be supported via "CUDA minor version compatibility". Pytorch builds against 13.z versions of cuda toolkit work with 12.x versions of cuda-bindings as well because newer drivers work with old toolkits. #test that import: test_cuda.py + +setuptools-git-versioning==2.1.0 +scikit-build==0.18.1 +pyre-extensions==0.0.32 +tabulate==0.9.0 +#Description: These package are needed to build FBGEMM and torchrec on PyTorch CI diff --git a/.ci/pytorch/build.sh b/.ci/pytorch/build.sh index 9e5d2c4675eed..f2b8998a6f6cd 100755 --- a/.ci/pytorch/build.sh +++ b/.ci/pytorch/build.sh @@ -309,22 +309,20 @@ else fi pip_install_whl "$(echo dist/*.whl)" - if [[ -n "${BUILD_ADDITIONAL_PACKAGES}" ]]; then - if [[ "${BUILD_ADDITIONAL_PACKAGES}" == *vision* ]]; then - install_torchvision - fi + if [[ "${BUILD_ADDITIONAL_PACKAGES:-}" == *vision* ]]; then + install_torchvision + fi - if [[ "${BUILD_ADDITIONAL_PACKAGES}" == *audio* ]]; then - install_torchaudio - fi + if [[ "${BUILD_ADDITIONAL_PACKAGES:-}" == *audio* ]]; then + install_torchaudio + fi - if [[ "${BUILD_ADDITIONAL_PACKAGES}" == *torchrec* || "${BUILD_ADDITIONAL_PACKAGES}" == *fbgemm* ]]; then - install_torchrec_and_fbgemm - fi + if [[ "${BUILD_ADDITIONAL_PACKAGES:-}" == *torchrec* || "${BUILD_ADDITIONAL_PACKAGES:-}" == *fbgemm* ]]; then + install_torchrec_and_fbgemm + fi - if [[ "${BUILD_ADDITIONAL_PACKAGES}" == *torchao* ]]; then - install_torchao - fi + if [[ "${BUILD_ADDITIONAL_PACKAGES:-}" == *torchao* ]]; then + install_torchao fi if [[ "$BUILD_ENVIRONMENT" == *xpu* ]]; then diff --git a/.ci/pytorch/common_utils.sh b/.ci/pytorch/common_utils.sh index 69a5b7ad37951..9075fe5fb56f8 100644 --- a/.ci/pytorch/common_utils.sh +++ b/.ci/pytorch/common_utils.sh @@ -174,7 +174,14 @@ function install_torchvision() { echo 'char* dlerror(void) { return "";}'|gcc -fpic -shared -o "${HOME}/dlerror.so" -x c - LD_PRELOAD=${orig_preload}:${HOME}/dlerror.so fi + + if [[ "${BUILD_ENVIRONMENT}" == *cuda* ]]; then + # Not sure if both are needed, but why not + export FORCE_CUDA=1 + export WITH_CUDA=1 + fi pip_build_and_install "git+https://github.com/pytorch/vision.git@${commit}" dist/vision + if [ -n "${LD_PRELOAD}" ]; then LD_PRELOAD=${orig_preload} fi diff --git a/.github/workflows/inductor-nightly.yml b/.github/workflows/inductor-nightly.yml index d8dc7146fda13..c17a4ed6341aa 100644 --- a/.github/workflows/inductor-nightly.yml +++ b/.github/workflows/inductor-nightly.yml @@ -48,6 +48,7 @@ jobs: { config: "dynamic_cpu_max_autotune_inductor_amp_freezing_torchbench", shard: 1, num_shards: 2, runner: "linux.8xlarge.amx" }, { config: "dynamic_cpu_max_autotune_inductor_amp_freezing_torchbench", shard: 2, num_shards: 2, runner: "linux.8xlarge.amx" }, ]} + build-additional-packages: "vision audio torchao" secrets: inherit linux-jammy-cpu-py3_9-gcc11-nightly-dynamo-benchmarks-test: diff --git a/.github/workflows/inductor-perf-compare.yml b/.github/workflows/inductor-perf-compare.yml index 25191643b3599..628f624240127 100644 --- a/.github/workflows/inductor-perf-compare.yml +++ b/.github/workflows/inductor-perf-compare.yml @@ -43,6 +43,7 @@ jobs: { config: "inductor_timm_perf_compare", shard: 2, num_shards: 2, runner: "linux.aws.a100" }, { config: "inductor_torchbench_perf_compare", shard: 1, num_shards: 1, runner: "linux.aws.a100" }, ]} + build-additional-packages: "vision audio fbgemm torchao" secrets: inherit test: diff --git a/.github/workflows/inductor-perf-test-nightly-aarch64.yml b/.github/workflows/inductor-perf-test-nightly-aarch64.yml index ed04d88eb1277..e16c8be79130d 100644 --- a/.github/workflows/inductor-perf-test-nightly-aarch64.yml +++ b/.github/workflows/inductor-perf-test-nightly-aarch64.yml @@ -116,6 +116,7 @@ jobs: { config: "inductor_torchbench_perf_cpu_aarch64", shard: 15, num_shards: 15, runner: "linux.arm64.m7g.metal" }, ]} selected-test-configs: ${{ inputs.benchmark_configs }} + build-additional-packages: "vision audio torchao" secrets: inherit diff --git a/.github/workflows/inductor-perf-test-nightly-h100.yml b/.github/workflows/inductor-perf-test-nightly-h100.yml index 4807f4a29b08a..ab651e081b7cd 100644 --- a/.github/workflows/inductor-perf-test-nightly-h100.yml +++ b/.github/workflows/inductor-perf-test-nightly-h100.yml @@ -119,6 +119,7 @@ jobs: { config: "inductor_torchbench_perf_cuda_h100", shard: 9, num_shards: 9, runner: "linux.aws.h100" }, ]} selected-test-configs: ${{ inputs.benchmark_configs }} + build-additional-packages: "vision audio fbgemm torchao" secrets: inherit test-periodically: diff --git a/.github/workflows/inductor-perf-test-nightly-x86.yml b/.github/workflows/inductor-perf-test-nightly-x86.yml index 0466576658d45..62234e5f499a7 100644 --- a/.github/workflows/inductor-perf-test-nightly-x86.yml +++ b/.github/workflows/inductor-perf-test-nightly-x86.yml @@ -98,6 +98,7 @@ jobs: { config: "inductor_torchbench_perf_cpu_x86", shard: 4, num_shards: 4, runner: "linux.24xl.spr-metal" }, ]} selected-test-configs: ${{ inputs.benchmark_configs }} + build-additional-packages: "vision audio torchao" secrets: inherit linux-jammy-cpu-py3_9-gcc11-inductor-test-nightly-freezing: diff --git a/.github/workflows/inductor-perf-test-nightly.yml b/.github/workflows/inductor-perf-test-nightly.yml index 015204473339d..9fd81a5a05c9a 100644 --- a/.github/workflows/inductor-perf-test-nightly.yml +++ b/.github/workflows/inductor-perf-test-nightly.yml @@ -86,6 +86,8 @@ jobs: needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + # Every bit to make perf run faster helps + runner: linux.12xlarge.memory build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.0' @@ -112,6 +114,7 @@ jobs: { config: "cachebench", shard: 2, num_shards: 2, runner: "linux.aws.a100" }, ]} selected-test-configs: ${{ inputs.benchmark_configs }} + build-additional-packages: "vision audio fbgemm torchao" secrets: inherit test-nightly: diff --git a/.github/workflows/inductor-periodic.yml b/.github/workflows/inductor-periodic.yml index 2e16c2e403fb0..d3f1ff1f1dae9 100644 --- a/.github/workflows/inductor-periodic.yml +++ b/.github/workflows/inductor-periodic.yml @@ -58,6 +58,7 @@ jobs: { config: "dynamic_aot_eager_timm", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, { config: "dynamic_aot_eager_timm", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, ]} + build-additional-packages: "vision audio fbgemm torchao" secrets: inherit linux-jammy-cuda12_8-py3_10-gcc9-periodic-dynamo-benchmarks-test: @@ -125,6 +126,7 @@ jobs: { include: [ { config: "inductor_torchbench_smoketest_perf", shard: 1, num_shards: 1, runner: "linux.aws.a100" }, ]} + build-additional-packages: "vision audio fbgemm torchao" secrets: inherit linux-jammy-cuda12_8-py3_10-gcc9-inductor-smoke-test: @@ -159,6 +161,7 @@ jobs: { config: "cpu_inductor_freezing_avx2_timm", shard: 1, num_shards: 2, runner: "linux.10xlarge.avx2" }, { config: "cpu_inductor_freezing_avx2_timm", shard: 2, num_shards: 2, runner: "linux.10xlarge.avx2" }, ]} + build-additional-packages: "vision audio torchao" secrets: inherit linux-jammy-cpu-py3_9-gcc11-periodic-dynamo-benchmarks-test: @@ -195,6 +198,7 @@ jobs: { config: "aot_inductor_torchbench", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, { config: "aot_inductor_torchbench", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, ]} + build-additional-packages: "vision audio fbgemm torchao" secrets: inherit linux-jammy-cuda12_8-py3_10-gcc9-inductor-test: @@ -240,6 +244,7 @@ jobs: { config: "dynamic_cpu_aot_inductor_amp_freezing_torchbench", shard: 1, num_shards: 2, runner: "linux.8xlarge.amx" }, { config: "dynamic_cpu_aot_inductor_amp_freezing_torchbench", shard: 2, num_shards: 2, runner: "linux.8xlarge.amx" }, ]} + build-additional-packages: "vision audio torchao" secrets: inherit linux-jammy-cpu-py3_9-gcc11-inductor-test: diff --git a/.github/workflows/inductor.yml b/.github/workflows/inductor.yml index e6fc7aa65431a..721572f1807ba 100644 --- a/.github/workflows/inductor.yml +++ b/.github/workflows/inductor.yml @@ -62,6 +62,7 @@ jobs: { config: "inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, { config: "inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, ]} + build-additional-packages: "vision audio fbgemm torchao" secrets: inherit linux-jammy-cuda12_8-py3_10-gcc9-inductor-test: @@ -94,6 +95,7 @@ jobs: { config: "dynamic_cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" }, { config: "inductor_torchbench_cpu_smoketest_perf", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.24xl.spr-metal" }, ]} + build-additional-packages: "vision audio torchao" secrets: inherit linux-jammy-cpu-py3_9-gcc11-inductor-test: diff --git a/.github/workflows/test-h100.yml b/.github/workflows/test-h100.yml index 40eff83ba58df..7e4a818c3528d 100644 --- a/.github/workflows/test-h100.yml +++ b/.github/workflows/test-h100.yml @@ -37,7 +37,7 @@ jobs: needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runner: "linux.12xlarge" + runner: linux.12xlarge.memory build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90 docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 cuda-arch-list: '9.0' From a1cfe7f1df3ddf72f1cf361cef96c69fc4a91a2b Mon Sep 17 00:00:00 2001 From: dolpm <34420038+dolpm@users.noreply.github.com> Date: Sun, 20 Jul 2025 00:28:09 +0000 Subject: [PATCH 295/457] [nativert] benchmark util (#158678) Differential Revision: D78514241 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158678 Approved by: https://github.com/SherlockNoMad, https://github.com/georgiaphillips --- torch/nativert/executor/GraphExecutorBase.cpp | 11 +++++++++++ torch/nativert/executor/GraphExecutorBase.h | 3 +++ 2 files changed, 14 insertions(+) diff --git a/torch/nativert/executor/GraphExecutorBase.cpp b/torch/nativert/executor/GraphExecutorBase.cpp index 1c85e27253169..9a527cc8117bc 100644 --- a/torch/nativert/executor/GraphExecutorBase.cpp +++ b/torch/nativert/executor/GraphExecutorBase.cpp @@ -40,6 +40,13 @@ ProfileMetrics GraphExecutorBase::benchmarkIndividualNodes( ProfileMetrics results; const auto numNodes = static_cast(nodeKernels_.size()); + + results.percentPerNode.resize(numNodes, 0.0f); + results.nodeTypes.reserve(numNodes); + for (const auto& nodeKernel : nodeKernels_) { + results.nodeTypes.emplace_back(nodeKernel->node()->target()); + } + results.timePerNode.resize(numNodes, 0); if (inputsList.empty()) { auto i = 0; @@ -114,6 +121,10 @@ ProfileMetrics GraphExecutorBase::benchmarkIndividualNodes( const std::string& target = r.first; results.percentPerNodeType[target] = r.second * 100.0f / results.totalTime; } + for (const auto i : c10::irange(numNodes)) { + results.percentPerNode[i] = + results.timePerNode[i] * 100.0f / results.totalTime; + } return results; } diff --git a/torch/nativert/executor/GraphExecutorBase.h b/torch/nativert/executor/GraphExecutorBase.h index 8d659f1588c2b..dfe020ebae29e 100644 --- a/torch/nativert/executor/GraphExecutorBase.h +++ b/torch/nativert/executor/GraphExecutorBase.h @@ -14,12 +14,15 @@ struct ProfileMetrics { size_t staticDispatchNodesCount{0}; size_t totalNodesCount{0}; std::vector timePerNode; + std::vector nodeTypes; std::unordered_map timePerNodeType; std::unordered_map percentPerNodeType; + std::vector percentPerNode; std::unordered_map instancesPerNodeType; std::unordered_set staticDispatchNodes; std::unordered_set primNodes; float totalTime{0}; + std::string name; }; /** From b64f338da4c73a601254f97a9ac4be168423b925 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Sat, 19 Jul 2025 16:36:07 -0300 Subject: [PATCH 296/457] [DLPack] add NumPy exchange tests. (#150216) This PR resolves an old TODO that requested NumPy DLPack exchange tests once version 1.22 was required. Pull Request resolved: https://github.com/pytorch/pytorch/pull/150216 Approved by: https://github.com/msaroufim, https://github.com/albanD --- test/test_dlpack.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/test/test_dlpack.py b/test/test_dlpack.py index 389f63efa687f..36b8dcb7ca686 100644 --- a/test/test_dlpack.py +++ b/test/test_dlpack.py @@ -5,6 +5,7 @@ from torch.testing._internal.common_device_type import ( dtypes, instantiate_device_type_tests, + onlyCPU, onlyCUDA, onlyNativeDeviceTypes, skipCUDAIfRocm, @@ -317,6 +318,40 @@ def test(device, **kwargs): # Consumer should still be able to process a smaller version capsule. test(device, max_version=(2, 0)) + @skipMeta + @onlyCPU + @dtypes( + # Note: NumPy DLPack bool support only landed in 1.25. + *all_types_and_complex_and( + torch.half, + torch.uint16, + torch.uint32, + torch.uint64, + ) + ) + def test_numpy_dlpack_protocol_conversion(self, device, dtype): + import numpy as np + + t = make_tensor((5,), dtype=dtype, device=device) + + if hasattr(np, "from_dlpack"): + # DLPack support only available from NumPy 1.22 onwards. + # Here, we test having another framework (NumPy) calling our + # Tensor.__dlpack__ implementation. + arr = np.from_dlpack(t) + self.assertEqual(t, arr) + + # We can't use the array created above as input to from_dlpack. + # That's because DLPack imported NumPy arrays are read-only. + # Thus, we need to convert it to NumPy by using the numpy() method. + t_arr = t.numpy() + + # Transform the NumPy array back using DLPack. + res = from_dlpack(t_arr) + + self.assertEqual(t, res) + self.assertEqual(t.data_ptr(), res.data_ptr()) + instantiate_device_type_tests(TestTorchDlPack, globals()) From 1d526fe78fcb8e6185f25d4b4340599b35f6a6d9 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Sat, 19 Jul 2025 16:36:07 -0300 Subject: [PATCH 297/457] Fix DLPack stream logic. (#150217) This PR fixes the logic for dealing with CUDA and ROCm streams whenever we are trying to create a DLPack capsule from a tensor. In summary, this PR: - Uses the legacy default stream if `tensor.__dlpack__(stream=None)` is called for a CUDA tensor. - Errors if `tensor.__dlpack__(stream=2)` is called for a CUDA tensor: PyTorch doesn't support the per-thread default stream. - Errors if `tensor.__dlpack__(stream=stream)`, where `stream` is 1 or 2, is called for a CUDA tensor using ROCm. For more details, see [the documentation][1]. [1]: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack__.html Pull Request resolved: https://github.com/pytorch/pytorch/pull/150217 Approved by: https://github.com/msaroufim, https://github.com/albanD ghstack dependencies: #150216 --- test/test_dlpack.py | 58 +++++++++++++++++++++++++++++++++++++++++++++ torch/_tensor.py | 56 ++++++++++++++++++++++++++++++------------- 2 files changed, 97 insertions(+), 17 deletions(-) diff --git a/test/test_dlpack.py b/test/test_dlpack.py index 36b8dcb7ca686..3437f16fdc740 100644 --- a/test/test_dlpack.py +++ b/test/test_dlpack.py @@ -3,11 +3,13 @@ import torch from torch.testing import make_tensor from torch.testing._internal.common_device_type import ( + deviceCountAtLeast, dtypes, instantiate_device_type_tests, onlyCPU, onlyCUDA, onlyNativeDeviceTypes, + skipCUDAIfNotRocm, skipCUDAIfRocm, skipMeta, ) @@ -242,6 +244,62 @@ def test_dlpack_tensor_invalid_stream(self, device, dtype): x = make_tensor((5,), dtype=dtype, device=device) x.__dlpack__(stream=object()) + @skipMeta + @onlyCUDA + @skipCUDAIfRocm + def test_dlpack_cuda_per_thread_stream(self, device): + # Test whether we raise an error if we are trying to use per-thread default + # stream, which is currently not supported by PyTorch. + x = make_tensor((5,), dtype=torch.float32, device=device) + with self.assertRaisesRegex( + BufferError, "per-thread default stream is not supported" + ): + x.__dlpack__(stream=2) + + @skipMeta + @onlyCUDA + @skipCUDAIfNotRocm + def test_dlpack_invalid_rocm_streams(self, device): + # Test that we correctly raise errors on unsupported ROCm streams. + def test(x, stream): + with self.assertRaisesRegex( + AssertionError, r"unsupported stream on ROCm: \d" + ): + x.__dlpack__(stream=stream) + + x = make_tensor((5,), dtype=torch.float32, device=device) + test(x, stream=1) + test(x, stream=2) + + @skipMeta + @onlyCUDA + @skipCUDAIfRocm + def test_dlpack_invalid_cuda_streams(self, device): + x = make_tensor((5,), dtype=torch.float32, device=device) + with self.assertRaisesRegex(AssertionError, r"unsupported stream on CUDA: \d"): + x.__dlpack__(stream=0) + + @skipMeta + def test_dlpack_invalid_cpu_stream(self): + x = make_tensor((5,), dtype=torch.float32, device="cpu") + with self.assertRaisesRegex(AssertionError, r"stream should be None on cpu."): + x.__dlpack__(stream=0) + + @skipMeta + @onlyCUDA + @deviceCountAtLeast(2) + def test_dlpack_tensor_on_different_device(self, devices): + dev0, dev1 = devices[:2] + + with torch.device(dev0): + x = make_tensor((5,), dtype=torch.float32, device=dev0) + + with self.assertRaisesRegex( + BufferError, r"Can't export tensors on a different CUDA device" + ): + with torch.device(dev1): + x.__dlpack__() + # TODO: add interchange tests once NumPy 1.22 (dlpack support) is required @skipMeta def test_dlpack_export_requires_grad(self): diff --git a/torch/_tensor.py b/torch/_tensor.py index 652cd33a03538..3369b6602cfa4 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -1703,27 +1703,49 @@ def __dlpack__(self, *, stream=None, max_version=None): "Can't export tensors with layout other than torch.strided" ) + if ( + self.device.type == "cuda" + and self.device.index != torch.cuda.current_device() + ): + raise BufferError( + "Can't export tensors on a different CUDA device. " + f"Expected: {self.device}. " + f"Current device: {torch.cuda.current_device()}." + ) + if stream is not None and type(stream) is not int: # Stream pointers in CUDA/ROCm are uniquely numbered and can # be retrieved from their integer value. raise TypeError("stream must be ``int`` or ``none``") - elif stream is not None and stream != -1: - if self.device.type == "cuda": - # NB: This logic handles the special case values for default - # streams and must be kept in sync with from_dlpack in - # torch/utils/dlpack.py - if stream == 1 and torch.version.hip is None: - stream = torch.cuda.default_stream() - elif stream == 0 and torch.version.hip is not None: - stream = torch.cuda.default_stream() - else: - stream = torch.cuda.ExternalStream(stream) - # Only synchronize on different streams - sync_stream = torch.cuda.current_stream() - if stream != sync_stream: - event = torch.cuda.Event() - event.record(sync_stream) - stream.wait_event(event) + elif self.device.type == "cuda" and stream != -1: + # NB: This logic handles the special case values for default + # streams and must be kept in sync with from_dlpack in + # torch/utils/dlpack.py + is_rocm = torch.version.hip is not None + is_cuda = not is_rocm + + if stream is None or (is_rocm and stream == 0) or (is_cuda and stream == 1): + stream = torch.cuda.default_stream() + else: + if is_cuda and stream == 2: + raise BufferError("per-thread default stream is not supported.") + + device_str = "CUDA" if is_cuda else "ROCm" + assert (is_cuda and stream != 0) or ( + is_rocm and stream not in (1, 2) + ), f"unsupported stream on {device_str}: {stream}." + + stream = torch.cuda.ExternalStream(stream) + + # Only synchronize on different streams + current_stream = torch.cuda.current_stream() + if stream != current_stream: + event = torch.cuda.Event() + event.record(current_stream) + stream.wait_event(event) + elif self.device.type == "cpu": + assert stream is None, "stream should be None on cpu." + if self.device.type == "xla": import torch_xla import torch_xla.utils.dlpack as xla_dlpack From a10f15718d9f7518ce8530f6a1fd5bef83a63309 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Sat, 19 Jul 2025 16:36:07 -0300 Subject: [PATCH 298/457] [DLPack] Add support for missing keyword-arguments. (#150218) This PR introduces the rest of the keyword-arguments added in DLPack version 2023.12: `dl_device` and `copy`. In summary, we handle these arguments in the C++ implementation of `to_dlpack(...)` at _torch/csrc/Module.cpp_, by calling the `maybeCopyTensor` function at _aten/src/ATen/DLConvertor.cpp_. It also introduces the following changes: - Add a new Python API `torchDeviceToDLDevice()`, which is simply a refactoring of the `getDLDevice()` function at _aten/src/ATen/DLConvertor.cpp_. - Add both keyword-arguments to the `from_dlpack()` function at _torch/utils/dlpack.py_ and to the `Tensor.__dlpack__()` dunder method. Pull Request resolved: https://github.com/pytorch/pytorch/pull/150218 Approved by: https://github.com/albanD ghstack dependencies: #150216, #150217 --- aten/src/ATen/DLConvertor.cpp | 74 ++++++++++++++++++++++--------- aten/src/ATen/DLConvertor.h | 10 +++++ test/test_dlpack.py | 49 +++++++++++++++++++++ torch/_C/__init__.pyi.in | 15 ++++++- torch/_tensor.py | 38 +++++++++++----- torch/csrc/Module.cpp | 83 ++++++++++++++++++++++++++++++----- torch/overrides.py | 2 +- torch/utils/dlpack.py | 47 +++++++++++++++++--- 8 files changed, 266 insertions(+), 52 deletions(-) diff --git a/aten/src/ATen/DLConvertor.cpp b/aten/src/ATen/DLConvertor.cpp index f25e68001ff4d..e9a185dbfd37f 100644 --- a/aten/src/ATen/DLConvertor.cpp +++ b/aten/src/ATen/DLConvertor.cpp @@ -96,10 +96,14 @@ DLDataType getDLDataType(const Tensor& t) { return dtype; } -static DLDevice getDLDevice(const Tensor& tensor, c10::DeviceIndex device_id) { +DLDevice torchDeviceToDLDevice(at::Device device) { DLDevice ctx; - ctx.device_id = static_cast(static_cast(device_id)); - switch (tensor.device().type()) { + + ctx.device_id = (device.is_cuda() || device.is_privateuseone()) + ? static_cast(static_cast(device.index())) + : 0; + + switch (device.type()) { case DeviceType::CPU: ctx.device_type = DLDeviceType::kDLCPU; break; @@ -120,8 +124,7 @@ static DLDevice getDLDevice(const Tensor& tensor, c10::DeviceIndex device_id) { break; case DeviceType::XPU: ctx.device_type = DLDeviceType::kDLOneAPI; - ctx.device_id = - at::detail::getXPUHooks().getGlobalIdxFromDevice(tensor.device()); + ctx.device_id = at::detail::getXPUHooks().getGlobalIdxFromDevice(device); break; case DeviceType::MAIA: ctx.device_type = DLDeviceType::kDLMAIA; @@ -130,38 +133,40 @@ static DLDevice getDLDevice(const Tensor& tensor, c10::DeviceIndex device_id) { ctx.device_type = DLDeviceType::kDLExtDev; break; default: - TORCH_CHECK(false, "Cannot pack tensors on " + tensor.device().str()); + TORCH_CHECK(false, "Cannot pack tensors on " + device.str()); } + return ctx; } -static Device getATenDevice(const DLDevice& ctx, void* data) { - switch (ctx.device_type) { +static Device getATenDevice(DLDeviceType type, c10::DeviceIndex index, void* data = nullptr) { + switch (type) { case DLDeviceType::kDLCPU: return at::Device(DeviceType::CPU); #ifndef USE_ROCM // if we are compiled under HIP, we cannot do cuda case DLDeviceType::kDLCUDA: - return at::Device(DeviceType::CUDA, static_cast(ctx.device_id)); + return at::Device(DeviceType::CUDA, index); #endif case DLDeviceType::kDLOpenCL: - return at::Device(DeviceType::OPENCL, static_cast(ctx.device_id)); + return at::Device(DeviceType::OPENCL, index); case DLDeviceType::kDLROCM: #ifdef USE_ROCM // this looks funny, we need to return CUDA here to masquerade - return at::Device(DeviceType::CUDA, static_cast(ctx.device_id)); + return at::Device(DeviceType::CUDA, index); #else - return at::Device(DeviceType::HIP, static_cast(ctx.device_id)); + return at::Device(DeviceType::HIP, index); #endif case DLDeviceType::kDLOneAPI: + TORCH_CHECK(data != nullptr, "Can't get ATen device for XPU without XPU data."); return at::detail::getXPUHooks().getDeviceFromPtr(data); case DLDeviceType::kDLMAIA: - return at::Device(DeviceType::MAIA, static_cast(ctx.device_id)); + return at::Device(DeviceType::MAIA, index); case DLDeviceType::kDLExtDev: - return at::Device(DeviceType::PrivateUse1, static_cast(ctx.device_id)); + return at::Device(DeviceType::PrivateUse1, index); default: TORCH_CHECK( - false, "Unsupported device_type: ", std::to_string(ctx.device_type)); + false, "Unsupported device_type: ", std::to_string(type)); } } @@ -314,11 +319,7 @@ T* toDLPackImpl(const Tensor& src) { atDLMTensor->tensor.manager_ctx = atDLMTensor; atDLMTensor->tensor.deleter = &deleter; atDLMTensor->tensor.dl_tensor.data = view.data_ptr(); - c10::DeviceIndex device_id = 0; - if (src.is_cuda() || src.is_privateuseone()) { - device_id = src.get_device(); - } - atDLMTensor->tensor.dl_tensor.device = getDLDevice(src, device_id); + atDLMTensor->tensor.dl_tensor.device = torchDeviceToDLDevice(src.device()); atDLMTensor->tensor.dl_tensor.ndim = static_cast(src.dim()); atDLMTensor->tensor.dl_tensor.dtype = getDLDataType(src); atDLMTensor->tensor.dl_tensor.shape = view.sizes().data(); @@ -346,7 +347,7 @@ at::Tensor fromDLPackImpl(T* src, std::function deleter) { } DLTensor& dl_tensor = src->dl_tensor; - Device device = getATenDevice(dl_tensor.device, dl_tensor.data); + Device device = getATenDevice(dl_tensor.device.device_type, dl_tensor.device.device_id, dl_tensor.data); ScalarType stype = toScalarType(dl_tensor.dtype); if (!dl_tensor.strides) { @@ -388,4 +389,35 @@ Tensor fromDLPackVersioned(DLManagedTensorVersioned* src, std::function(src, std::move(deleter)); } +Tensor maybeCopyTensor( + const Tensor& data, + std::optional optional_dl_device, + std::optional copy) { + bool force_copy = copy.has_value() && *copy; + bool force_move = copy.has_value() && !*copy; + + if (optional_dl_device.has_value()) { + auto device = at::getATenDevice( + optional_dl_device->device_type, + static_cast(optional_dl_device->device_id)); + + if (device != data.device()) { + TORCH_CHECK_VALUE( + !force_move, + "cannot move (i.e. copy=False) tensor from ", + data.device(), + " to ", + device, + " without copying."); + return data.to(device); + } + } + + if (force_copy) { + return data.clone(); + } + + return data; +} + } // namespace at diff --git a/aten/src/ATen/DLConvertor.h b/aten/src/ATen/DLConvertor.h index e9cbd94dfd724..b1c2eaa2d6eae 100644 --- a/aten/src/ATen/DLConvertor.h +++ b/aten/src/ATen/DLConvertor.h @@ -21,6 +21,16 @@ TORCH_API Tensor fromDLPackVersioned( TORCH_API DLDataType getDLDataType(const Tensor& t); TORCH_API DLDevice getDLContext(const Tensor& tensor, const int64_t& device_id); +// Copies the Tensor if there's a device mismatch or copy is forced. +// This should be used before actually creating the DLPack capsule. +TORCH_API Tensor maybeCopyTensor( + const Tensor& data, + std::optional optional_dl_device, + std::optional copy); + +// Converts the given at::Device into a DLDevice. +TORCH_API DLDevice torchDeviceToDLDevice(at::Device device); + // This trait class is used for retrieving different attributes, such as the // PyCapsule names and conversion functions for both DLPack tensor classes: // `DLManagedTensor` and `DLManagedTensorVersioned`. diff --git a/test/test_dlpack.py b/test/test_dlpack.py index 3437f16fdc740..cb079dcb416d5 100644 --- a/test/test_dlpack.py +++ b/test/test_dlpack.py @@ -410,6 +410,55 @@ def test_numpy_dlpack_protocol_conversion(self, device, dtype): self.assertEqual(t, res) self.assertEqual(t.data_ptr(), res.data_ptr()) + def _test_from_dlpack(self, device, out_device=None, copy=None): + if isinstance(device, str): + device = torch.device(device) + + inp = make_tensor((5,), dtype=torch.float32, device=device) + out = torch.from_dlpack(inp, device=out_device, copy=copy) + + if out_device is None: + out_device = device + if isinstance(out_device, str): + out_device = torch.device(out_device) + + self.assertEqual(inp, out) + self.assertEqual(out.device, out_device) + + # They should be moved (i.e. not copied) only if: + # (a) we are forcing move, i.e. copy=False + # (b) the output device is the same as the input one AND copy is None + if copy is False or (copy is None and device == out_device): + self.assertEqual(inp.data_ptr(), out.data_ptr()) + else: + # Otherwise, inp should be copied. + self.assertNotEqual(inp.data_ptr(), out.data_ptr()) + + @skipMeta + @onlyCUDA + def test_copy(self, device): + # Force-copy same device tensor. + self._test_from_dlpack(device, copy=True) + self._test_from_dlpack(device, out_device=device, copy=True) + # Output should be in a different device, i.e. should have been copied. + self._test_from_dlpack(device, out_device="cpu") + self._test_from_dlpack(device, out_device="cpu", copy=True) + + @skipMeta + @onlyCUDA + def test_no_copy(self, device): + # No copy, since tensor lives in the same device. + self._test_from_dlpack(device) + self._test_from_dlpack(device, copy=False) + self._test_from_dlpack(device, out_device=device) + self._test_from_dlpack(device, out_device=device, copy=False) + + @skipMeta + @onlyCUDA + def test_needs_copy_error(self, device): + with self.assertRaisesRegex(ValueError, r"cannot move .* tensor from .*"): + self._test_from_dlpack(device, out_device="cpu", copy=False) + instantiate_device_type_tests(TestTorchDlPack, globals()) diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 1a785ef8f237a..7f88b86a7eaf2 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -1301,9 +1301,20 @@ def _initCrashHandler() -> None: ... # NB: There is no Capsule type in typing, see # https://github.com/python/cpython/issues/109562 -def _to_dlpack(data: Tensor) -> Any: ... # THPModule_toDLPack -def _to_dlpack_versioned(data: Tensor) -> Any: ... # THPModule_toDLPackVersioned +def _to_dlpack( + data: Tensor, + dl_device: tuple[IntEnum, _int] | None = None, + copy: _bool | None = None, +) -> Any: ... # THPModule_toDLPack +def _to_dlpack_versioned( + data: Tensor, + dl_device: tuple[IntEnum, _int] | None = None, + copy: _bool | None = None, +) -> Any: ... # THPModule_toDLPackVersioned def _from_dlpack(data: Any) -> Tensor: ... # THPModule_fromDLPack +def _torchDeviceToDLDevice( + device: torch.device, +) -> tuple[_int, _int]: ... # THPModule_torchDeviceToDLDevice def _get_cpp_backtrace( frames_to_skip: _int, maximum_number_of_frames: _int, diff --git a/torch/_tensor.py b/torch/_tensor.py index 3369b6602cfa4..2e0cd6abbb11d 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -1659,7 +1659,14 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): __torch_dispatch__ = _C._disabled_torch_dispatch_impl - def __dlpack__(self, *, stream=None, max_version=None): + def __dlpack__( + self, + *, + stream: Optional[Any] = None, + max_version: Optional[tuple[int, int]] = None, + dl_device: Optional[tuple[enum.IntEnum, int]] = None, + copy: Optional[bool] = None, + ): """ Creates a DLpack `capsule https://data-apis.org/array-api/latest/design_topics/data_interchange.html#data-interchange`_ of the current tensor to be exported to other libraries. @@ -1670,22 +1677,31 @@ def __dlpack__(self, *, stream=None, max_version=None): Args: stream (integer or None): An optional Python integer representing a - pointer to a CUDA stream. The current stream is synchronized with - this stream before the capsule is created, and since the capsule - shares its storage with the tensor this make it safe to access from - both streams. If None or -1 is passed then no synchronization is performed. - If 1 (on CUDA) or 0 (on ROCM) then the default stream is used for - synchronization. + pointer to a CUDA stream. The current stream is synchronized with + this stream before the capsule is created, and since the capsule + shares its storage with the tensor this make it safe to access from + both streams. If None or -1 is passed then no synchronization is performed. + If 1 (on CUDA) or 0 (on ROCM) then the default stream is used for + synchronization. max_version (tuple[int, int] or None): An optional Python tuple with - 2 integers, representing the maximum version the caller supports. If - None (default), PyTorch will fallback to DLPack 0.8. + 2 integers, representing the maximum version the caller supports. If + None (default), PyTorch will fallback to DLPack 0.8. + + dl_device (tuple[DLDeviceType, int] or None): An optional tuple specifying + in which device the exported DLPack capsule should be on. If None (default), + the exported DLPack capsule will be on the same device as ``self``. + + copy (bool or None): An optional boolean indicating whether or not to copy + ``self``. If None, PyTorch will copy only if necessary. """ if has_torch_function_unary(self): args = (self,) kwargs = { "stream": stream, "max_version": max_version, + "dl_device": dl_device, + "copy": copy, } return handle_torch_function(Tensor.__dlpack__, (self,), *args, **kwargs) @@ -1763,9 +1779,9 @@ def __dlpack__(self, *, stream=None, max_version=None): if max_version is None or max_version[0] < 1: # Fallback to the old, unversioned variant. - return torch.to_dlpack(self) + return _C._to_dlpack(self, dl_device=dl_device, copy=copy) - return _C._to_dlpack_versioned(self) + return _C._to_dlpack_versioned(self, dl_device=dl_device, copy=copy) def __dlpack_device__(self) -> tuple[enum.IntEnum, int]: if has_torch_function_unary(self): diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 8aabf24c4c1b3..9497296c1a4c0 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -607,25 +607,56 @@ void DLPack_Capsule_Destructor(PyObject* data) { } template -PyObject* THPModule_toDLPackImpl(PyObject* _unused, PyObject* data) { +PyObject* THPModule_toDLPackImpl( + PyObject* self, + PyObject* args, + PyObject* kwargs) { HANDLE_TH_ERRORS - TORCH_CHECK(THPVariable_Check(data), "data must be a Tensor"); - auto tensor = at::DLPackTraits::toDLPack(THPVariable_Unpack(data)); + static torch::PythonArgParser parser( + {"_to_dlpack(Tensor data, *, IntArrayRef? dl_device=None, bool? copy=None)"}); + torch::ParsedArgs<3> parsed_args{}; + auto r = parser.parse(args, kwargs, parsed_args); + + TORCH_INTERNAL_ASSERT(r.idx == 0); + + auto data = r.tensor(0); + auto dl_device = r.intlist(1); + auto copy = r.toBoolOptional(2); + + // Parse the int list into a tuple. + std::optional optional_dl_device; + + if (!dl_device.empty()) { + TORCH_CHECK( + dl_device.size() == 2, + "dl_device must be either None or a tuple of ints"); + optional_dl_device = DLDevice{ + static_cast(dl_device[0]), + static_cast(dl_device[1])}; + } + + auto tensor = at::DLPackTraits::toDLPack( + at::maybeCopyTensor(data, optional_dl_device, copy)); return PyCapsule_New( tensor, at::DLPackTraits::capsule, DLPack_Capsule_Destructor); + END_HANDLE_TH_ERRORS } } // namespace -static PyObject* THPModule_toDLPack(PyObject* _unused, PyObject* data) { - return THPModule_toDLPackImpl(_unused, data); +static PyObject* THPModule_toDLPack( + PyObject* self, + PyObject* args, + PyObject* kwargs) { + return THPModule_toDLPackImpl(self, args, kwargs); } static PyObject* THPModule_toDLPackVersioned( - PyObject* _unused, - PyObject* data) { - return THPModule_toDLPackImpl(_unused, data); + PyObject* self, + PyObject* args, + PyObject* kwargs) { + return THPModule_toDLPackImpl(self, args, kwargs); } static PyObject* THPModule_fromDLPack(PyObject* _unused, PyObject* data) { @@ -636,6 +667,28 @@ static PyObject* THPModule_fromDLPack(PyObject* _unused, PyObject* data) { END_HANDLE_TH_ERRORS } +static PyObject* THPModule_torchDeviceToDLDevice( + PyObject* _unused, + PyObject* data) { + HANDLE_TH_ERRORS + TORCH_CHECK( + THPDevice_Check(data), + "torchDeviceToDLDevice: expected torch.device argument."); + auto device = reinterpret_cast(data)->device; + auto dl_device = at::torchDeviceToDLDevice(device); + + auto tuple = PyTuple_New(2); + if (!tuple) { + throw python_error(); + } + + PyTuple_SET_ITEM(tuple, 0, THPUtils_packInt64(dl_device.device_type)); + PyTuple_SET_ITEM(tuple, 1, THPUtils_packInt64(dl_device.device_id)); + + return tuple; + END_HANDLE_TH_ERRORS +} + static PyObject* THModule_getCppBacktrace(PyObject* _unused, PyObject* args) { HANDLE_TH_ERRORS size_t frames_to_skip = 0; @@ -1687,9 +1740,19 @@ static std::initializer_list TorchMethods = { THPModule_are_vmap_fallback_warnings_enabled, METH_NOARGS, nullptr}, - {"_to_dlpack", THPModule_toDLPack, METH_O, nullptr}, - {"_to_dlpack_versioned", THPModule_toDLPackVersioned, METH_O, nullptr}, + {"_to_dlpack", + castPyCFunctionWithKeywords(THPModule_toDLPack), + METH_VARARGS | METH_KEYWORDS, + nullptr}, + {"_to_dlpack_versioned", + castPyCFunctionWithKeywords(THPModule_toDLPackVersioned), + METH_VARARGS | METH_KEYWORDS, + nullptr}, {"_from_dlpack", THPModule_fromDLPack, METH_O, nullptr}, + {"_torchDeviceToDLDevice", + THPModule_torchDeviceToDLDevice, + METH_O, + nullptr}, {"_get_cpp_backtrace", THModule_getCppBacktrace, METH_VARARGS, nullptr}, {"_rename_privateuse1_backend", THModule_rename_privateuse1_backend, diff --git a/torch/overrides.py b/torch/overrides.py index 2e696b2d96e4d..046171ef6c5c6 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -1512,7 +1512,7 @@ def get_testing_overrides() -> dict[Callable, Callable]: Tensor.view: lambda self, shape: -1, Tensor.view_as: lambda self, other: -1, Tensor.zero_: lambda self: -1, - Tensor.__dlpack__: lambda self, stream=None, max_version=None: -1, + Tensor.__dlpack__: lambda self, stream=None, max_version=None, dl_device=None, copy=None: -1, Tensor.__dlpack_device__: lambda self: -1, torch.linalg.lstsq: lambda self, b, cond=None, driver=None: -1, } # fmt: skip diff --git a/torch/utils/dlpack.py b/torch/utils/dlpack.py index 9a53ff9e84ac6..e7aeae1ba3c81 100644 --- a/torch/utils/dlpack.py +++ b/torch/utils/dlpack.py @@ -1,9 +1,10 @@ -from typing import Any +from typing import Any, Optional import torch import enum from torch._C import _to_dlpack as to_dlpack +from torch.types import Device as _Device __all__ = [ "DLDeviceType", @@ -54,7 +55,12 @@ class DLDeviceType(enum.IntEnum): # TODO: add a typing.Protocol to be able to tell Mypy that only objects with # __dlpack__ and __dlpack_device__ methods are accepted. -def from_dlpack(ext_tensor: Any) -> 'torch.Tensor': +def from_dlpack( + ext_tensor: Any, + *, + device: Optional[_Device] = None, + copy: Optional[bool] = None +) -> 'torch.Tensor': """from_dlpack(ext_tensor) -> Tensor Converts a tensor from an external library into a ``torch.Tensor``. @@ -76,6 +82,13 @@ def from_dlpack(ext_tensor: Any) -> 'torch.Tensor': an opaque ``PyCapsule`` instance, typically produced by a ``to_dlpack`` function or method. + device (torch.device or str or None): An optional PyTorch device + specifying where to place the new tensor. If None (default), the + new tensor will be on the same device as ``ext_tensor``. + + copy (bool or None): An optional boolean indicating whether or not to copy + ``self``. If None, PyTorch will copy only if necessary. + Examples:: >>> import torch.utils.dlpack @@ -106,20 +119,36 @@ def from_dlpack(ext_tensor: Any) -> 'torch.Tensor': """ if hasattr(ext_tensor, '__dlpack__'): + # Only populate kwargs if any of the optional arguments are, in fact, not None. Otherwise, + # leave them out, since we might end up falling back to no-extra-kwargs __dlpack__ call. kwargs: dict[str, Any] = {} kwargs["max_version"] = (1, 0) - device = ext_tensor.__dlpack_device__() - # device is either CUDA or ROCm, we need to pass the current + if copy is not None: + kwargs["copy"] = copy + + # Parse the device parameter. + # At this moment, it can either be a torch.device or a str representing + # a torch.device, e.g. "cpu", "cuda", etc. + if device is not None: + if isinstance(device, str): + device = torch.device(device) + assert isinstance(device, torch.device), ( + f"from_dlpack: unsupported device type: {type(device)}" + ) + kwargs["dl_device"] = torch._C._torchDeviceToDLDevice(device) + + ext_device = ext_tensor.__dlpack_device__() + # ext_device is either CUDA or ROCm, we need to pass the current # stream - if device[0] in (DLDeviceType.kDLCUDA, DLDeviceType.kDLROCM): - stream = torch.cuda.current_stream(f'cuda:{device[1]}') + if ext_device[0] in (DLDeviceType.kDLCUDA, DLDeviceType.kDLROCM): + stream = torch.cuda.current_stream(f'cuda:{ext_device[1]}') # cuda_stream is the pointer to the stream and it is a public # attribute, but it is not documented # The array API specify that the default legacy stream must be passed # with a value of 1 for CUDA # https://data-apis.org/array-api/latest/API_specification/array_object.html?dlpack-self-stream-none#dlpack-self-stream-none - is_cuda = device[0] == DLDeviceType.kDLCUDA + is_cuda = ext_device[0] == DLDeviceType.kDLCUDA # Since pytorch is not using PTDS by default, lets directly pass # the legacy stream stream_ptr = 1 if is_cuda and stream.cuda_stream == 0 else stream.cuda_stream @@ -134,6 +163,10 @@ def from_dlpack(ext_tensor: Any) -> 'torch.Tensor': dlpack = ext_tensor.__dlpack__(**kwargs) else: + assert device is None and copy is None, ( + "device and copy kwargs not supported when ext_tensor is " + "already a DLPack capsule." + ) # Old versions just call the converter dlpack = ext_tensor return torch._C._from_dlpack(dlpack) From b4abf414254ed4d8779bad291dd0141097f019e7 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Sat, 19 Jul 2025 16:36:08 -0300 Subject: [PATCH 299/457] Raise `BufferError` for DLPack buffer-related errors. (#150691) This PR addresses the Array API documentation for [`__dlpack__`][1] and [`from_dlpack`][2] by making some buffer-related errors `BufferError` instead of `RuntimeError`, e.g. incompatible dtype, strides, or device. [1]: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack__.html [2]: https://data-apis.org/array-api/latest/API_specification/generated/array_api.from_dlpack.html#from-dlpack Pull Request resolved: https://github.com/pytorch/pytorch/pull/150691 Approved by: https://github.com/Skylion007, https://github.com/albanD ghstack dependencies: #150216, #150217, #150218 --- aten/src/ATen/DLConvertor.cpp | 32 ++++++++++++++++---------------- c10/util/Exception.h | 11 +++++++++++ test/test_dlpack.py | 31 +++++++++++++++++++++++++++---- torch/_tensor.py | 10 +++++----- torch/csrc/Exceptions.h | 1 + 5 files changed, 60 insertions(+), 25 deletions(-) diff --git a/aten/src/ATen/DLConvertor.cpp b/aten/src/ATen/DLConvertor.cpp index e9a185dbfd37f..bdb5cae907cd0 100644 --- a/aten/src/ATen/DLConvertor.cpp +++ b/aten/src/ATen/DLConvertor.cpp @@ -69,29 +69,29 @@ DLDataType getDLDataType(const Tensor& t) { case ScalarType::Float8_e4m3fn: case ScalarType::Float8_e4m3fnuz: case ScalarType::Float8_e8m0fnu: - TORCH_CHECK(false, "float8 types are not supported by dlpack"); + TORCH_CHECK_BUFFER(false, "float8 types are not supported by dlpack"); break; case ScalarType::Float4_e2m1fn_x2: - TORCH_CHECK(false, "float4 types are not supported by dlpack"); + TORCH_CHECK_BUFFER(false, "float4 types are not supported by dlpack"); break; case ScalarType::QInt8: case ScalarType::QUInt8: case ScalarType::QInt32: case ScalarType::QUInt4x2: case ScalarType::QUInt2x4: - TORCH_CHECK(false, "QUInt/QInt types are not supported by dlpack"); + TORCH_CHECK_BUFFER(false, "QUInt/QInt types are not supported by dlpack"); break; case ScalarType::Bits1x8: case ScalarType::Bits2x4: case ScalarType::Bits4x2: case ScalarType::Bits8: case ScalarType::Bits16: - TORCH_CHECK(false, "Bit types are not supported by dlpack"); + TORCH_CHECK_BUFFER(false, "Bit types are not supported by dlpack"); break; case ScalarType::Undefined: - TORCH_CHECK(false, "Undefined is not a valid ScalarType"); + TORCH_CHECK_BUFFER(false, "Undefined is not a valid ScalarType"); case ScalarType::NumOptions: - TORCH_CHECK(false, "NumOptions is not a valid ScalarType"); + TORCH_CHECK_BUFFER(false, "NumOptions is not a valid ScalarType"); } return dtype; } @@ -133,7 +133,7 @@ DLDevice torchDeviceToDLDevice(at::Device device) { ctx.device_type = DLDeviceType::kDLExtDev; break; default: - TORCH_CHECK(false, "Cannot pack tensors on " + device.str()); + TORCH_CHECK_BUFFER(false, "Cannot pack tensors on " + device.str()); } return ctx; @@ -165,14 +165,14 @@ static Device getATenDevice(DLDeviceType type, c10::DeviceIndex index, void* dat case DLDeviceType::kDLExtDev: return at::Device(DeviceType::PrivateUse1, index); default: - TORCH_CHECK( + TORCH_CHECK_BUFFER( false, "Unsupported device_type: ", std::to_string(type)); } } ScalarType toScalarType(const DLDataType& dtype) { ScalarType stype = ScalarType::Undefined; - TORCH_CHECK(dtype.lanes == 1, "ATen does not support lanes != 1"); + TORCH_CHECK_BUFFER(dtype.lanes == 1, "ATen does not support lanes != 1"); switch (dtype.code) { case DLDataTypeCode::kDLUInt: switch (dtype.bits) { @@ -189,7 +189,7 @@ ScalarType toScalarType(const DLDataType& dtype) { stype = ScalarType::UInt64; break; default: - TORCH_CHECK( + TORCH_CHECK_BUFFER( false, "Unsupported kUInt bits ", std::to_string(dtype.bits)); } break; @@ -208,7 +208,7 @@ ScalarType toScalarType(const DLDataType& dtype) { stype = ScalarType::Long; break; default: - TORCH_CHECK( + TORCH_CHECK_BUFFER( false, "Unsupported kInt bits ", std::to_string(dtype.bits)); } break; @@ -224,7 +224,7 @@ ScalarType toScalarType(const DLDataType& dtype) { stype = ScalarType::Double; break; default: - TORCH_CHECK( + TORCH_CHECK_BUFFER( false, "Unsupported kFloat bits ", std::to_string(dtype.bits)); } break; @@ -234,7 +234,7 @@ ScalarType toScalarType(const DLDataType& dtype) { stype = ScalarType::BFloat16; break; default: - TORCH_CHECK( + TORCH_CHECK_BUFFER( false, "Unsupported kFloat bits ", std::to_string(dtype.bits)); } break; @@ -250,7 +250,7 @@ ScalarType toScalarType(const DLDataType& dtype) { stype = ScalarType::ComplexDouble; break; default: - TORCH_CHECK( + TORCH_CHECK_BUFFER( false, "Unsupported kFloat bits ", std::to_string(dtype.bits)); } break; @@ -260,12 +260,12 @@ ScalarType toScalarType(const DLDataType& dtype) { stype = ScalarType::Bool; break; default: - TORCH_CHECK( + TORCH_CHECK_BUFFER( false, "Unsupported kDLBool bits ", std::to_string(dtype.bits)); } break; default: - TORCH_CHECK(false, "Unsupported code ", std::to_string(dtype.code)); + TORCH_CHECK_BUFFER(false, "Unsupported code ", std::to_string(dtype.code)); } return stype; } diff --git a/c10/util/Exception.h b/c10/util/Exception.h index 8136896d07f88..545cef5351380 100644 --- a/c10/util/Exception.h +++ b/c10/util/Exception.h @@ -267,6 +267,13 @@ class C10_API NotImplementedError : public Error { using Error::Error; }; +// Used in ATen for buffer-related errors, e.g. trying to create a DLPack of +// an unsupported device. These turn into BufferError when they cross to +// Python. +class C10_API BufferError : public Error { + using Error::Error; +}; + // Used in ATen for non finite indices. These turn into // ExitException when they cross to Python. class C10_API EnforceFiniteError : public Error { @@ -635,6 +642,10 @@ namespace c10::detail { #define TORCH_CHECK_NOT_IMPLEMENTED(cond, ...) \ TORCH_CHECK_WITH_MSG(NotImplementedError, cond, "TYPE", __VA_ARGS__) +// Like TORCH_CHECK, but raises BufferError instead of Errors. +#define TORCH_CHECK_BUFFER(cond, ...) \ + TORCH_CHECK_WITH_MSG(BufferError, cond, "TYPE", __VA_ARGS__) + #define TORCH_CHECK_ALWAYS_SHOW_CPP_STACKTRACE(cond, ...) \ TORCH_CHECK_WITH_MSG( \ ErrorAlwaysShowCppStacktrace, cond, "TYPE", ##__VA_ARGS__) diff --git a/test/test_dlpack.py b/test/test_dlpack.py index cb079dcb416d5..f3272cc694768 100644 --- a/test/test_dlpack.py +++ b/test/test_dlpack.py @@ -20,7 +20,7 @@ skipIfTorchDynamo, TestCase, ) -from torch.utils.dlpack import from_dlpack, to_dlpack +from torch.utils.dlpack import DLDeviceType, from_dlpack, to_dlpack # Wraps a tensor, exposing only DLPack methods: @@ -304,21 +304,21 @@ def test_dlpack_tensor_on_different_device(self, devices): @skipMeta def test_dlpack_export_requires_grad(self): x = torch.zeros(10, dtype=torch.float32, requires_grad=True) - with self.assertRaisesRegex(RuntimeError, r"require gradient"): + with self.assertRaisesRegex(BufferError, r"require gradient"): x.__dlpack__() @skipMeta def test_dlpack_export_is_conj(self): x = torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j]) y = torch.conj(x) - with self.assertRaisesRegex(RuntimeError, r"conjugate bit"): + with self.assertRaisesRegex(BufferError, r"conjugate bit"): y.__dlpack__() @skipMeta def test_dlpack_export_non_strided(self): x = torch.sparse_coo_tensor([[0]], [1], size=(1,)) y = torch.conj(x) - with self.assertRaisesRegex(RuntimeError, r"strided"): + with self.assertRaisesRegex(BufferError, r"strided"): y.__dlpack__() @skipMeta @@ -459,6 +459,29 @@ def test_needs_copy_error(self, device): with self.assertRaisesRegex(ValueError, r"cannot move .* tensor from .*"): self._test_from_dlpack(device, out_device="cpu", copy=False) + @skipMeta + @onlyNativeDeviceTypes + def test_unsupported_device_error(self, device): + inp = make_tensor((5,), dtype=torch.float32, device=device) + dl_device_type = DLDeviceType.kDLHexagon + + with self.assertRaisesRegex( + BufferError, f"Unsupported device_type: {int(dl_device_type)}" + ): + inp.__dlpack__(max_version=(1, 0), dl_device=(dl_device_type, 0)) + + @skipMeta + @onlyCPU + def test_dlpack_unsupported_dtype_error(self, device): + inp = make_tensor((5,), dtype=torch.float32, device=device).to( + torch.float8_e4m3fn + ) + + with self.assertRaisesRegex( + BufferError, ".* types are not supported by dlpack" + ): + from_dlpack(inp) + instantiate_device_type_tests(TestTorchDlPack, globals()) diff --git a/torch/_tensor.py b/torch/_tensor.py index 2e0cd6abbb11d..dd9d987eea66e 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -1709,13 +1709,13 @@ def __dlpack__( # so we prohibit exporting tensors that would lose their properties like # requires_grad and having the conjugate bit set. if self.requires_grad: - raise RuntimeError( + raise BufferError( "Can't export tensors that require gradient, use tensor.detach()" ) if self.is_conj(): - raise RuntimeError("Can't export tensors with the conjugate bit set") + raise BufferError("Can't export tensors with the conjugate bit set") if self.layout != torch.strided: - raise RuntimeError( + raise BufferError( "Can't export tensors with layout other than torch.strided" ) @@ -1724,8 +1724,8 @@ def __dlpack__( and self.device.index != torch.cuda.current_device() ): raise BufferError( - "Can't export tensors on a different CUDA device. " - f"Expected: {self.device}. " + "Can't export tensors on a different CUDA device index. " + f"Expected: {self.device.index}. " f"Current device: {torch.cuda.current_device()}." ) diff --git a/torch/csrc/Exceptions.h b/torch/csrc/Exceptions.h index 8df6ea24a4bb1..60a7bb644df01 100644 --- a/torch/csrc/Exceptions.h +++ b/torch/csrc/Exceptions.h @@ -74,6 +74,7 @@ inline void PyErr_SetString(PyObject* type, const std::string& message) { _CATCH_GENERIC_ERROR(TypeError, PyExc_TypeError, retstmnt) \ _CATCH_GENERIC_ERROR( \ NotImplementedError, PyExc_NotImplementedError, retstmnt) \ + _CATCH_GENERIC_ERROR(BufferError, PyExc_BufferError, retstmnt) \ _CATCH_GENERIC_ERROR(SyntaxError, PyExc_SyntaxError, retstmnt) \ _CATCH_GENERIC_ERROR(LinAlgError, THPException_LinAlgError, retstmnt) \ _CATCH_GENERIC_ERROR( \ From 4869f7117009fb99a57482fce56b00c6163fbce6 Mon Sep 17 00:00:00 2001 From: Natalia Gimelshein Date: Sun, 20 Jul 2025 01:36:23 +0000 Subject: [PATCH 300/457] don't set CUDA_MODULE_LOADING (#158712) If needed, it'll be set in `_C._cuda_init()`. setenv is not threadsafe, so this can cause segfaults due to getenv/setenv races. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158712 Approved by: https://github.com/eqy --- test/test_cuda.py | 5 ----- torch/cuda/__init__.py | 2 -- 2 files changed, 7 deletions(-) diff --git a/test/test_cuda.py b/test/test_cuda.py index 581c11c85ec10..e4b5cf51b6f79 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -6484,11 +6484,6 @@ def test_cuda_autocast_deprecated_warning(self): with torch.cuda.amp.autocast(): _ = torch.ones(10) - def test_cuda_module_loading_env(self): - torch.cuda.init() - val = os.environ.get("CUDA_MODULE_LOADING", "") - self.assertEqual(val, "LAZY") - @unittest.skipIf( os.environ.get("USE_LEGACY_DRIVER", None) == "1", "Doesn't work with older driver" diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py index fd88e199a7a15..4bc4c0ac3f183 100644 --- a/torch/cuda/__init__.py +++ b/torch/cuda/__init__.py @@ -379,8 +379,6 @@ def _lazy_init(): ) # This function throws if there's a driver initialization error, no GPUs # are found or any other error occurs - if "CUDA_MODULE_LOADING" not in os.environ: - os.environ["CUDA_MODULE_LOADING"] = "LAZY" torch._C._cuda_init() # Some of the queued calls may reentrantly call _lazy_init(); # we need to just return without initializing in that case. From badf0020144a6c00ebe7a1cdbeb74f716d48968a Mon Sep 17 00:00:00 2001 From: Andrey Talman Date: Sun, 20 Jul 2025 14:57:46 +0000 Subject: [PATCH 301/457] [Reland] Add warning about removed sm50 and sm60 arches (#158700) Related to https://github.com/pytorch/pytorch/issues/157517 Detect when users are executing torch build with cuda 12.8/12.9 and running on Maxwell or Pascal architectures. We would like to include reference to the issue: https://github.com/pytorch/pytorch/issues/157517 as well as ask people to install CUDA 12.6 builds if they are running on sm50 or sm60 architectures. Test: ``` >>> torch.cuda.get_arch_list() ['sm_70', 'sm_75', 'sm_80', 'sm_86', 'sm_90', 'sm_100', 'sm_120', 'compute_120'] >>> torch.cuda.init() /home/atalman/.conda/envs/py312/lib/python3.12/site-packages/torch/cuda/__init__.py:263: UserWarning: Found which is of cuda capability 5.0. PyTorch no longer supports this GPU because it is too old. The minimum cuda capability supported by this library is 7.0. warnings.warn( /home/atalman/.conda/envs/py312/lib/python3.12/site-packages/torch/cuda/__init__.py:268: UserWarning: Support for Maxwell and Pascal architectures is removed for CUDA 12.8+ builds. Please see https://github.com/pytorch/pytorch/issues/157517 Please install CUDA 12.6 builds if you require Maxwell or Pascal support. ``` Please note I reverted original PR https://github.com/pytorch/pytorch/pull/158301 because it broke internal users. This is a reland, added added check for non empty torch.cuda.get_arch_list() Pull Request resolved: https://github.com/pytorch/pytorch/pull/158700 Approved by: https://github.com/huydhn, https://github.com/Skylion007, https://github.com/eqy --- torch/cuda/__init__.py | 60 +++++++++++++++++++++++++++++++----------- 1 file changed, 44 insertions(+), 16 deletions(-) diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py index 4bc4c0ac3f183..01bc4d73a4595 100644 --- a/torch/cuda/__init__.py +++ b/torch/cuda/__init__.py @@ -244,21 +244,27 @@ def _extract_arch_version(arch_string: str) -> int: def _check_capability(): - incorrect_binary_warn = """ - Found GPU%d %s which requires CUDA_VERSION >= %d to - work properly, but your PyTorch was compiled - with CUDA_VERSION %d. Please install the correct PyTorch binary - using instructions from https://pytorch.org - """ # noqa: F841 - - old_gpu_warn = """ + incompatible_gpu_warn = """ Found GPU%d %s which is of cuda capability %d.%d. - PyTorch no longer supports this GPU because it is too old. - The minimum cuda capability supported by this library is %d.%d. + Minimum and Maximum cuda capability supported by this version of PyTorch is + (%d.%d) - (%d.%d) """ + matched_cuda_warn = """ + Please install PyTorch with a following CUDA + configurations: {} following instructions at + https://pytorch.org/get-started/locally/ + """ + + # Binary CUDA_ARCHES SUPPORTED by PyTorch + CUDA_ARCHES_SUPPORTED = { + "12.6": {"min": 50, "max": 90}, + "12.8": {"min": 70, "max": 120}, + "12.9": {"min": 70, "max": 120}, + } - if torch.version.cuda is not None: # on ROCm we don't want this check - CUDA_VERSION = torch._C._cuda_getCompiledVersion() # noqa: F841 + if ( + torch.version.cuda is not None and torch.cuda.get_arch_list() + ): # on ROCm we don't want this check for d in range(device_count()): capability = get_device_capability(d) major = capability[0] @@ -267,13 +273,35 @@ def _check_capability(): current_arch = major * 10 + minor min_arch = min( (_extract_arch_version(arch) for arch in torch.cuda.get_arch_list()), - default=35, + default=50, ) - if current_arch < min_arch: + max_arch = max( + (_extract_arch_version(arch) for arch in torch.cuda.get_arch_list()), + default=50, + ) + if current_arch < min_arch or current_arch > max_arch: warnings.warn( - old_gpu_warn - % (d, name, major, minor, min_arch // 10, min_arch % 10) + incompatible_gpu_warn + % ( + d, + name, + major, + minor, + min_arch // 10, + min_arch % 10, + max_arch // 10, + max_arch % 10, + ) ) + matched_arches = "" + for arch, arch_info in CUDA_ARCHES_SUPPORTED.items(): + if ( + current_arch >= arch_info["min"] + and current_arch <= arch_info["max"] + ): + matched_arches += f" {arch}" + if matched_arches != "": + warnings.warn(matched_cuda_warn.format(matched_arches)) def _check_cubins(): From 5e149a64822fb3fdb2f1e28b947a056a64d306c5 Mon Sep 17 00:00:00 2001 From: "Tugsbayasgalan (Tugsuu) Manlaibaatar" Date: Sun, 20 Jul 2025 17:02:01 +0000 Subject: [PATCH 302/457] Add deprecation warning (#158203) Summary: export_for_training exist because we couldn't migrate internal usages of export to the final IR. Now that we have completed the migration, we should deprecate and delete this API. Test Plan: CI Rollback Plan: Differential Revision: D78240836 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158203 Approved by: https://github.com/JacobSzwejbka --- torch/export/__init__.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/torch/export/__init__.py b/torch/export/__init__.py index 6c3c2b6f93778..c36056d22dd58 100644 --- a/torch/export/__init__.py +++ b/torch/export/__init__.py @@ -10,6 +10,7 @@ from collections.abc import Iterator from enum import auto, Enum from typing import Any, Callable, Optional, TYPE_CHECKING, Union +from typing_extensions import deprecated import torch import torch.utils._pytree as pytree @@ -73,6 +74,11 @@ PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]] +@deprecated( + "`torch.export.export_for_training` is deprecated and will be removed in PyTorch 2.10. " + "Please use `torch.export.export` instead, which is functionally equivalent.", + category=FutureWarning, +) def export_for_training( mod: torch.nn.Module, args: tuple[Any, ...], From 2e038793ef90567cc46e10ff2ca25c4a379428ab Mon Sep 17 00:00:00 2001 From: Mwiza Kunda Date: Sun, 20 Jul 2025 22:07:32 +0000 Subject: [PATCH 303/457] [inductor][templates] Finalize all registered hooks (#157270) This refactor ensures all registered template hooks have been finalised before accessing the code object of the template. In `simd.SimdScheduling.codegen_template` the template hooks are finalised manually with `template.finalize_hook(hook_name)` calls, so it is the responsibility of the caller to finalise all the template hooks. This PR adds: - `RenderPartial.finalize_remaining` a function that can be called at the end to finalise the remaining active hooks after a selection of hooks have been finalised manually. - A test with a custom template implementation that registers custom hooks that the scheduler needs to finalise. This test should fail if the scheduler does not finalise the registered custom hook. Pull Request resolved: https://github.com/pytorch/pytorch/pull/157270 Approved by: https://github.com/eellison --- test/inductor/test_select_algorithm.py | 159 ++++++++++++++++++++++++- torch/_inductor/codegen/simd.py | 13 +- torch/_inductor/select_algorithm.py | 39 +++++- 3 files changed, 199 insertions(+), 12 deletions(-) diff --git a/test/inductor/test_select_algorithm.py b/test/inductor/test_select_algorithm.py index 66781b4d7622b..d2cd77fe5cd29 100644 --- a/test/inductor/test_select_algorithm.py +++ b/test/inductor/test_select_algorithm.py @@ -1,5 +1,8 @@ # Owner(s): ["module: inductor"] +import contextlib import functools +import unittest.mock +from typing import Callable from unittest.mock import patch import torch @@ -9,11 +12,25 @@ import torch.nn.functional as F from torch._dynamo.testing import expectedFailureDynamicWrapper from torch._dynamo.utils import counters +from torch._inductor import config from torch._inductor.autotune_process import TritonBenchmarkRequest +from torch._inductor.ir import FixedLayout +from torch._inductor.select_algorithm import ( + autotune_select_algorithm, + PartialRender, + TritonTemplate, + TritonTemplateKernel, +) from torch._inductor.test_case import run_tests, TestCase -from torch._inductor.utils import is_big_gpu +from torch._inductor.utils import is_big_gpu, run_and_get_kernels +from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND from torch.testing._internal.common_utils import IS_LINUX, skipIfRocm, skipIfXpu -from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU +from torch.testing._internal.inductor_utils import ( + GPU_TYPE, + HAS_GPU, + requires_gpu, + requires_triton, +) aten = torch.ops.aten @@ -402,6 +419,144 @@ def test_TritonTemplateCaller_str(self): self.assertEqual(caller_str, f"TritonTemplateCaller({module_path}, extra)") +@contextlib.contextmanager +def patch_lowering(lowering_overrides) -> Callable[[], None]: + import torch._inductor.lowering as inductor_lowering + + with unittest.mock.patch.dict(inductor_lowering.lowerings): + for fn, ( + decomp_fn, + broadcast, + type_promotion_kind, + convert_input_to_bool, + ) in lowering_overrides.items(): + inductor_lowering._register_lowering( + fn, + decomp_fn, + broadcast=broadcast, + type_promotion_kind=type_promotion_kind, + convert_input_to_bool=convert_input_to_bool, + lowering_dict=inductor_lowering.lowerings, + ) + + yield + + +class TestTemplateRender(TestCase): + @requires_gpu() + @requires_triton() + @config.patch(cuda_backend="triton") + def test_finalized_subclass_hooks(self): + """ + Tests that all registered triton template hooks have been finalized, + especially in the case that the hooks are finalized manually by the + caller i.e. by calling template.finalize_hook(hook_name) + """ + hook_identifier = "# CUSTOM_HOOK" + + class ExtensionTritonTemplateKernel(TritonTemplateKernel): + def custom_hook(self) -> str: + """ + Custom hook that just returns a test string for + validation + """ + + def hook() -> str: + return hook_identifier + + assert "" not in self.render_hooks + self.render_hooks[""] = hook + return "" + + def render( + self, template, kwargs, record_input_dependent_tracked_event=False + ): + if record_input_dependent_tracked_event: + self.cached_replay_events = [] + + template_env = { + fn.__name__: self.record_input_dependent_tracked_event()(fn) + if record_input_dependent_tracked_event + else fn + for fn in [ + self.def_kernel, + self.size, + self.stride, + self.store_output, + self.load_input, + self.make_load, + self.modification, + self.gen_argdefs, + self.gen_defines, + # This function registers a hook that the scheduler does + # not directly finalize + self.custom_hook, + ] + } + return PartialRender( + template.render(**template_env, **kwargs), + self.render_hooks, + ) + + class ExtensionTritonTemplate(TritonTemplate): + kernel_type = ExtensionTritonTemplateKernel + + add_template = ExtensionTritonTemplate( + name="add", + grid=lambda *args, **kwargs: (1, 1, 1), + source=( + r""" +{{def_kernel("A", "B")}} + {{custom_hook()}} + xoffset = tl.program_id(0) + xindex = xoffset + tl.arange(0, XBLOCK) + xmask = tl.full([XBLOCK], True, tl.int1) + tmp0 = tl.load(A + xindex) + tmp1 = tl.load(B + xindex) + tmp2 = tmp0 + tmp1 + {{store_output(("xindex",), "tmp2", mask="xmask")}} + """ + ), + ) + + XBLOCK = 32 + + def add_override(a, b, alpha=None): + layout = FixedLayout(a.get_device(), a.get_dtype(), a.get_size()) + choices = [] + add_template.maybe_append_choice( + choices, + input_nodes=(a, b), + layout=layout, + num_stages=1, + num_warps=2, + XBLOCK=XBLOCK, + ) + return autotune_select_algorithm("add", choices, [a, b], layout) + + with patch_lowering( + { + torch.ops.aten.add.Tensor: ( + add_override, + True, + ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + False, + ) + } + ): + + @torch.compile + def add(a, b): + return a + b + + a = torch.zeros((XBLOCK,), device=GPU_TYPE) + b = torch.zeros((XBLOCK,), device=GPU_TYPE) + + _result, kernels = run_and_get_kernels(add, a, b) + assert len(kernels) == 1 + assert hook_identifier in kernels[0] + + if __name__ == "__main__": if IS_LINUX and HAS_GPU and is_big_gpu(): run_tests() diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py index 42c9a9d89eb99..7ac967bbe0b03 100644 --- a/torch/_inductor/codegen/simd.py +++ b/torch/_inductor/codegen/simd.py @@ -1639,11 +1639,16 @@ def _codegen_single_template( partial_code.finalize_hook(subgraph_name, strict=False) with kernel.set_subgraph_body(""): - if isinstance(partial_code, str): - src_code = partial_code - else: + if not isinstance(partial_code, str): partial_code.finalize_hook("") - src_code = partial_code.code + + if isinstance(partial_code, str): + src_code = partial_code + else: + # Ensure all hooks are finalized before the kernel is defined. + # Note: some of these hooks may have been registered by a kernel subclass + src_code = partial_code.finalize_remaining() + node_schedule = [*prologue_nodes, template_node, *epilogue_nodes] if config.benchmark_kernel: diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index c7c49333d1aea..c316f0d4bc7ef 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -167,11 +167,25 @@ class PartialRender: of replacements after the initial render. """ + FINALIZED_HOOK: object = object() + def __init__(self, code, replacement_hooks) -> None: super().__init__() - self.code = code + self._code = code self.replacement_hooks = replacement_hooks + @property + def code(self): + remaining_active_hooks = [ + key + for key, fn in self.replacement_hooks.items() + if fn is not self.FINALIZED_HOOK + ] + assert len(remaining_active_hooks) == 0, ( + f"The following hooks have not yet been finalized:\n {remaining_active_hooks=}" + ) + return self._code + def finalize_hook(self, hook_key: str, strict=True) -> None: if hook_key not in self.replacement_hooks: if strict: @@ -180,15 +194,28 @@ def finalize_hook(self, hook_key: str, strict=True) -> None: ) else: return - assert self.replacement_hooks[hook_key] is not None, ( + assert self.replacement_hooks[hook_key] is not self.FINALIZED_HOOK, ( "hook_key can only be called once" ) - self.code = self.code.replace(hook_key, self.replacement_hooks[hook_key]()) - self.replacement_hooks[hook_key] = None + self._code = self._code.replace(hook_key, self.replacement_hooks[hook_key]()) + self.replacement_hooks[hook_key] = self.FINALIZED_HOOK - def finalize_all(self) -> str: + def finalize_remaining(self) -> str: + """ + Finalize the remaining active hooks. This function can be used in cases + where the caller uses `finalize_hook` rather than `finalize_all`. + Note: `finalize_all` errors if a hook that has already been finalized + is attempted to be called again. This function only attempts to + finalize active hooks. + """ for key, fn in self.replacement_hooks.items(): - self.code = self.code.replace(key, fn()) + if fn is not self.FINALIZED_HOOK: + self.finalize_hook(key) + return self.code + + def finalize_all(self) -> str: + for key in self.replacement_hooks: + self.finalize_hook(key) return self.code From 4b02bd76d3e9a74609d6fcf7a749801ad253916d Mon Sep 17 00:00:00 2001 From: Ankita George Date: Sun, 20 Jul 2025 22:52:54 +0000 Subject: [PATCH 304/457] DCP safetensors test fix (#158685) https://github.com/pytorch/pytorch/pull/158069 removed the consolidated output path argument without updating the test. Reported by a user here https://github.com/pytorch/pytorch/pull/156705#issuecomment-3090748034. Adding back the logic from the original PR https://github.com/pytorch/pytorch/pull/158069 and fixing the test. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158685 Approved by: https://github.com/teja-rao --- .../checkpoint/test_consolidate_hf_safetensors.py | 2 +- .../checkpoint/test_hf_safetensor_e2e.py | 8 ++------ torch/distributed/checkpoint/_hf_utils.py | 2 ++ torch/distributed/checkpoint/hf_storage.py | 14 +++++++------- 4 files changed, 12 insertions(+), 14 deletions(-) diff --git a/test/distributed/checkpoint/test_consolidate_hf_safetensors.py b/test/distributed/checkpoint/test_consolidate_hf_safetensors.py index c1686142fd8e8..ba07c62728d71 100644 --- a/test/distributed/checkpoint/test_consolidate_hf_safetensors.py +++ b/test/distributed/checkpoint/test_consolidate_hf_safetensors.py @@ -62,7 +62,7 @@ def _create_d_tensors(self) -> None: dist_cp.save( state_dict=state_dict_to_save, storage_writer=dist_cp.HuggingFaceStorageWriter( - path=self.temp_dir, save_sharded=True + path=self.temp_dir, save_distributed=True ), ) dist.barrier() diff --git a/test/distributed/checkpoint/test_hf_safetensor_e2e.py b/test/distributed/checkpoint/test_hf_safetensor_e2e.py index 0220ae5138fc1..92f9b97237064 100644 --- a/test/distributed/checkpoint/test_hf_safetensor_e2e.py +++ b/test/distributed/checkpoint/test_hf_safetensor_e2e.py @@ -151,8 +151,6 @@ def test_consolidate_to_one_file(self) -> None: global_tensor = torch.arange(16, dtype=torch.float).view(4, 4) checkpoint_dir = self.temp_dir - consolidated_output_dir = os.path.join(checkpoint_dir, "consolidated") - os.makedirs(consolidated_output_dir, exist_ok=True) state_dict_to_save = {"dtensor": dtensor} dist_cp.save( @@ -160,15 +158,13 @@ def test_consolidate_to_one_file(self) -> None: storage_writer=dist_cp.HuggingFaceStorageWriter( path=checkpoint_dir, save_distributed=True, - consolidated_output_path=consolidated_output_dir, + enable_consolidation=True, ), ) dist.barrier() if self.rank == 0: - file_path = os.path.join( - consolidated_output_dir, "model-00001-of-00001.safetensors" - ) + file_path = os.path.join(checkpoint_dir, "model-00001-of-00001.safetensors") loaded_dict = safetensors.torch.load_file(file_path) self.assertEqual(loaded_dict.keys(), {"dtensor"}) self.assertTrue(torch.equal(loaded_dict["dtensor"], global_tensor)) diff --git a/torch/distributed/checkpoint/_hf_utils.py b/torch/distributed/checkpoint/_hf_utils.py index 84d4affe6c569..1a3f627fd69b5 100644 --- a/torch/distributed/checkpoint/_hf_utils.py +++ b/torch/distributed/checkpoint/_hf_utils.py @@ -43,6 +43,8 @@ NUM_BYTES_FOR_HEADER_LEN = 8 +SHARDED_DIR_NAME = "sharded" + @dataclass class _HFStorageInfo: diff --git a/torch/distributed/checkpoint/hf_storage.py b/torch/distributed/checkpoint/hf_storage.py index 4e97a3e02e328..13fd61910dd21 100644 --- a/torch/distributed/checkpoint/hf_storage.py +++ b/torch/distributed/checkpoint/hf_storage.py @@ -23,6 +23,7 @@ DTYPE_KEY, SAVED_OFFSETS_KEY, SHAPE_KEY, + SHARDED_DIR_NAME, SUFFIX, ) from torch.distributed.checkpoint.filesystem import SerializationFormat @@ -85,10 +86,8 @@ def __init__( token: The token to use to authenticate with huggingface hub. save_distributed: If True, save the checkpoint using distributed APIs where every rank saves its own shard. Default is False which assumes rank-0 checkpointing of the full state_dict. - enable_consolidation: If True, consolidate the sharded checkpoint after saving. Default to False. - consolidated_output_path: If provided, the output path where the consolidated files will be written in the finish step. - If enable_consolidation is True and this is not provided the consolidated files - will be written to `path`. + enable_consolidation: If True, consolidate the sharded checkpoint after saving. The sharded tensors will be + saved to path/sharded and the full tensors will be saved to path. Default to False. thread_count_consolidation: Number of threads to use for parallel processing of saving data to consolidated output files. Default to 1. """ @@ -109,9 +108,10 @@ def __init__( self.fqn_to_index_mapping: Optional[dict[str, int]] = fqn_to_index_mapping self.save_distributed: bool = save_distributed self.enable_consolidation: bool = enable_consolidation - self.consolidated_output_path: str = ( - consolidated_output_path if consolidated_output_path is not None else path - ) + self.consolidated_output_path: Optional[str] = None + if self.enable_consolidation: + self.consolidated_output_path = str(self.path) + self.path = self.fs.concat_path(self.path, SHARDED_DIR_NAME) self.thread_count_consolidation = thread_count_consolidation def prepare_global_plan(self, plans: list[SavePlan]) -> list[SavePlan]: From 2cdafab0bd1510e4bd286f33fd94807c59c7e691 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Sun, 20 Jul 2025 23:49:18 +0000 Subject: [PATCH 305/457] [BE] Raise ValueError from `torch.cat` meta func (#158249) Followup after https://github.com/pytorch/pytorch/pull/155460 From [Python documentation](https://docs.python.org/3/library/exceptions.html#ValueError): > Raised when an operation or function receives an argument that has the right type but an inappropriate value, and the situation is not described by a more precise exception such as IndexError. Raise [`TypeError`](https://docs.python.org/3/library/exceptions.html#TypeError) when input-output types are incompatible with each other > Raised when an operation or function is applied to an object of inappropriate type. The associated value is a string giving details about the type mismatch. > This exception may be raised by user code to indicate that an attempted operation on an object is not supported, and is not meant to be. If an object is meant to support a given operation but has not yet provided an implementation, [NotImplementedError](https://docs.python.org/3/library/exceptions.html#NotImplementedError) is the proper exception to raise. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158249 Approved by: https://github.com/jbschlosser, https://github.com/Skylion007, https://github.com/albanD --- aten/src/ATen/native/TensorShape.cpp | 6 +++--- test/test_ops.py | 17 ++++++++++++++++- test/test_type_promotion.py | 2 +- .../_internal/common_methods_invocations.py | 2 +- 4 files changed, 21 insertions(+), 6 deletions(-) diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index 340ee49bffa8f..c2d0856c3cd4c 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -247,7 +247,7 @@ TORCH_PRECOMPUTE_META_FUNC(cat)(const ITensorListRef& tensors, int64_t dim) { // Checking names before the actual dimensions. auto maybe_outnames = namedinference::compute_cat_outnames(materialized); - TORCH_CHECK( + TORCH_CHECK_VALUE( !materialized.empty(), "torch.cat(): expected a non-empty list of Tensors"); @@ -274,7 +274,7 @@ TORCH_PRECOMPUTE_META_FUNC(cat)(const ITensorListRef& tensors, int64_t dim) { // when computing the actual output dtype and the flags. if (is_out_defined) { // Check for type promotion, if the output tensor is defined. - TORCH_CHECK( + TORCH_CHECK_TYPE( canCast(out_dtype, result.scalar_type()), "torch.cat(): input types can't be cast to the desired output type ", result.scalar_type()); @@ -293,7 +293,7 @@ TORCH_PRECOMPUTE_META_FUNC(cat)(const ITensorListRef& tensors, int64_t dim) { // are compatible, i.e. we can execute `cat` on them. bool found_valid_tensor = valid < materialized.size(); if (found_valid_tensor) { - TORCH_CHECK( + TORCH_CHECK_INDEX( dim <= materialized[valid].get().dim(), "torch.cat(): dimension ", dim, diff --git a/test/test_ops.py b/test/test_ops.py index 26f8865b3a00f..201b0323a86fd 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1109,7 +1109,22 @@ def _case_four_transform(t): if op.is_factory_function and sample.kwargs.get("dtype", None) is None: op_out(out=out) else: - with self.assertRaises(RuntimeError, msg=msg_fail): + # TODO: Remove me when all ops will raise type error on mismatched types + exc_type = ( + TypeError + if op.name + in [ + "_chunk_cat", + "cat", + "column_stack", + "dstack", + "hstack", + "vstack", + "stack", + ] + else RuntimeError + ) + with self.assertRaises(exc_type, msg=msg_fail): op_out(out=out) @ops( diff --git a/test/test_type_promotion.py b/test/test_type_promotion.py index 88fcdd3a5dcaf..59d856ec4fc9f 100644 --- a/test/test_type_promotion.py +++ b/test/test_type_promotion.py @@ -1046,7 +1046,7 @@ def test_cat_out_different_dtypes(self, device): and not (out_dtype.is_floating_point or out_dtype.is_complex)) or ((x_dtype.is_complex or y_dtype.is_complex) and not out_dtype.is_complex)): # This combinations do not support type conversion to a different class out type - with self.assertRaises(RuntimeError): + with self.assertRaises(TypeError): torch.cat([x, y], out=out) else: torch.cat([x, y], out=out) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 92ae95bef8d0e..85a333a566012 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -2456,7 +2456,7 @@ def error_inputs_cat(op_info, device, **kwargs): # error inputs for empty tensors yield ErrorInput(SampleInput([], kwargs={'dim': 1}), - error_regex='non-empty list of Tensors') + error_regex='non-empty list of Tensors', error_type=ValueError) # error inputs for different sizes yield ErrorInput(SampleInput([make_arg((S, S, L, L)), make_arg((S, 0, L - 1, L))], kwargs={'dim': 1}), From ff0da08f4bc5ee135b495926cd58a36a1c0e1a5b Mon Sep 17 00:00:00 2001 From: Xu Han Date: Mon, 21 Jul 2025 01:08:59 +0000 Subject: [PATCH 306/457] [AOTI] normalize path and process model files. (#158705) Continued to https://github.com/pytorch/pytorch/pull/158702 , split `zip_filename_str` and real file path. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158705 Approved by: https://github.com/desertfire --- .../aoti_package/model_package_loader.cpp | 40 +++++++++++-------- 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/torch/csrc/inductor/aoti_package/model_package_loader.cpp b/torch/csrc/inductor/aoti_package/model_package_loader.cpp index 8e3a2d95fb9ec..66568025718af 100644 --- a/torch/csrc/inductor/aoti_package/model_package_loader.cpp +++ b/torch/csrc/inductor/aoti_package/model_package_loader.cpp @@ -478,27 +478,31 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( std::string so_filename; std::string cpp_filename; std::vector obj_filenames; - std::string model_directory = file_prefix + "data" + k_separator + - "aotinductor" + k_separator + model_name; - std::string const_directory = - file_prefix + "data" + k_separator + "constants"; - - for (const std::string& filename_str : found_filenames) { + std::string model_directory = normalize_path_separator( + file_prefix + "data" + k_separator + "aotinductor" + k_separator + + model_name); + std::string const_directory = normalize_path_separator( + file_prefix + "data" + k_separator + "constants"); + + // zip_filename_str can't be normalize_path_separator, because it should be + // as index for mz_zip_reader_extract_file_to_file. + for (auto zip_filename_str : found_filenames) { + auto cur_filename = normalize_path_separator(zip_filename_str); // Only compile files in the specified model directory - if (c10::starts_with(filename_str, model_directory) || - c10::starts_with(filename_str, const_directory)) { + if (c10::starts_with(cur_filename, model_directory) || + c10::starts_with(cur_filename, const_directory)) { std::string output_path_str = temp_dir_; - if (c10::starts_with(filename_str, model_directory)) { + if (c10::starts_with(cur_filename, model_directory)) { output_path_str += k_separator; - output_path_str += filename_str; - } else { // startsWith(filename_str, const_directory) + output_path_str += cur_filename; + } else { // startsWith(zip_filename_str, const_directory) // Extract constants to the same directory as the rest of the files // to be consistent with internal implementation - size_t lastSlash = filename_str.find_last_of(k_separator); - std::string filename = filename_str; + size_t lastSlash = cur_filename.find_last_of(k_separator); + std::string filename = cur_filename; if (lastSlash != std::string::npos) { - filename = filename_str.substr(lastSlash + 1); + filename = cur_filename.substr(lastSlash + 1); } output_path_str.append(k_separator) .append(model_directory) @@ -507,7 +511,7 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( } std::string output_file_path = normalize_path_separator(output_path_str); - LOG(INFO) << "Extract file: " << filename_str << " to " + LOG(INFO) << "Extract file: " << zip_filename_str << " to " << output_file_path; // Create the parent directory if it doesn't exist @@ -526,10 +530,12 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( // Extracts file to the temp directory mz_bool b_extract = mz_zip_reader_extract_file_to_file( - &zip_archive, filename_str.c_str(), output_file_path.c_str(), 0); + &zip_archive, zip_filename_str.c_str(), output_file_path.c_str(), 0); if (b_extract == MZ_FALSE) { throw std::runtime_error(fmt::format( - "Failed to extract file {} to {}", filename_str, output_file_path)); + "Failed to extract file {} to {}", + zip_filename_str, + output_file_path)); } // Save the file for bookkeeping From 5e1232871b641762aa1fdd84ba441a8fc9e34043 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 21 Jul 2025 02:24:11 +0000 Subject: [PATCH 307/457] Revert "[build] pin `setuptools>=77` to enable PEP 639 (#158104)" This reverts commit a4ec381302f8acd279033707b182bed30ffd2091. Reverted https://github.com/pytorch/pytorch/pull/158104 on behalf of https://github.com/malfet due to This break inductor-perf-nighly-macos by failing to build torchvision, see https://github.com/pytorch/pytorch/issues/158728 ([comment](https://github.com/pytorch/pytorch/pull/158104#issuecomment-3095048940)) --- .ci/docker/manywheel/Dockerfile_2_28 | 2 +- .ci/docker/manywheel/Dockerfile_s390x | 5 +++-- .ci/docker/requirements-ci.txt | 7 +++---- .ci/pytorch/build.sh | 3 --- .ci/pytorch/win-test-helpers/build_pytorch.bat | 5 ----- .ci/pytorch/win-test.sh | 2 +- .ci/pytorch/windows/internal/install_python.bat | 2 +- .ci/pytorch/windows/setup_build.bat | 5 +---- .ci/wheel/build_wheel.sh | 14 +++++++------- .github/requirements/pip-requirements-macOS.txt | 6 +++--- .github/scripts/lintrunner.sh | 2 +- .github/scripts/windows/build_triton.bat | 2 +- pyproject.toml | 11 ++++++++--- requirements-build.txt | 4 ++-- test/dynamo/test_exc.py | 16 ++++++++-------- 15 files changed, 40 insertions(+), 46 deletions(-) diff --git a/.ci/docker/manywheel/Dockerfile_2_28 b/.ci/docker/manywheel/Dockerfile_2_28 index 7f279a1c1a735..b150423e99544 100644 --- a/.ci/docker/manywheel/Dockerfile_2_28 +++ b/.ci/docker/manywheel/Dockerfile_2_28 @@ -128,7 +128,7 @@ ENV PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH ENV LD_LIBRARY_PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/lib64:/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/lib:$LD_LIBRARY_PATH # Install setuptools and wheel for python 3.12/3.13 RUN for cpython_version in "cp312-cp312" "cp313-cp313" "cp313-cp313t"; do \ - /opt/python/${cpython_version}/bin/python -m pip install "setuptools>=77.0.0" "packaging>=24.2" wheel; \ + /opt/python/${cpython_version}/bin/python -m pip install setuptools wheel; \ done; diff --git a/.ci/docker/manywheel/Dockerfile_s390x b/.ci/docker/manywheel/Dockerfile_s390x index 335488b88f122..46ec7f77ae8ba 100644 --- a/.ci/docker/manywheel/Dockerfile_s390x +++ b/.ci/docker/manywheel/Dockerfile_s390x @@ -124,9 +124,10 @@ RUN python3 -mpip install cmake==3.28.0 # install newest flatbuffers version first: # for some reason old version is getting pulled in otherwise. # packaging package is required for onnxruntime wheel build. -RUN pip3 install 'setuptools>=77.0' 'packaging>=24.2' && \ - pip3 install flatbuffers cython 'pkgconfig>=1.5.5' 'numpy<2.3.0' && \ +RUN pip3 install flatbuffers && \ + pip3 install cython 'pkgconfig>=1.5.5' 'setuptools>=77' 'numpy<2.3.0' && \ pip3 install --no-build-isolation h5py==3.11.0 && \ + pip3 install packaging && \ git clone https://github.com/microsoft/onnxruntime && \ cd onnxruntime && git checkout v1.21.0 && \ git submodule update --init --recursive && \ diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt index facc633f6a7ad..fb773ff324af8 100644 --- a/.ci/docker/requirements-ci.txt +++ b/.ci/docker/requirements-ci.txt @@ -50,7 +50,7 @@ flatbuffers==24.12.23 hypothesis==5.35.1 # Pin hypothesis to avoid flakiness: https://github.com/pytorch/pytorch/issues/31136 #Description: advanced library for generating parametrized tests -#Pinned versions: 5.35.1 +#Pinned versions: 3.44.6, 4.53.2 #test that import: test_xnnpack_integration.py, test_pruning_op.py, test_nn.py junitparser==2.1.1 @@ -307,7 +307,7 @@ pytest-cpp==2.3.0 #Pinned versions: 2.3.0 #test that import: -z3-solver==4.15.1.0 +z3-solver==4.12.6.0 #Description: The Z3 Theorem Prover Project #Pinned versions: #test that import: @@ -363,10 +363,9 @@ pwlf==2.2.1 # To build PyTorch itself -packaging>=24.2 pyyaml pyzstd -setuptools>=77.0.0 +setuptools>=70.1.0 six scons==4.5.2 ; platform_machine == "aarch64" diff --git a/.ci/pytorch/build.sh b/.ci/pytorch/build.sh index f2b8998a6f6cd..58454bcb108a7 100755 --- a/.ci/pytorch/build.sh +++ b/.ci/pytorch/build.sh @@ -269,9 +269,6 @@ if [[ "$BUILD_ENVIRONMENT" == *-bazel-* ]]; then tools/bazel build --config=no-tty "${BAZEL_MEM_LIMIT}" "${BAZEL_CPU_LIMIT}" //... fi else - # install build-system requirements before running setup.py commands - python -m pip install -r requirements-build.txt - # check that setup.py would fail with bad arguments echo "The next three invocations are expected to fail with invalid command error messages." ( ! get_exit_code python setup.py bad_argument ) diff --git a/.ci/pytorch/win-test-helpers/build_pytorch.bat b/.ci/pytorch/win-test-helpers/build_pytorch.bat index 74c9183f2abb0..7ceb425ce2d1a 100644 --- a/.ci/pytorch/win-test-helpers/build_pytorch.bat +++ b/.ci/pytorch/win-test-helpers/build_pytorch.bat @@ -126,11 +126,6 @@ if "%USE_CUDA%"=="1" ( set CMAKE_CUDA_COMPILER_LAUNCHER=%TMP_DIR%/bin/randomtemp.exe;%TMP_DIR%\bin\sccache.exe ) -:: Install build-system requirements before running setup.py commands -python -m pip install -r requirements-build.txt -if errorlevel 1 goto fail -if not errorlevel 0 goto fail - :: Print all existing environment variable for debugging set diff --git a/.ci/pytorch/win-test.sh b/.ci/pytorch/win-test.sh index be7f3e4bb35cc..b61dd06ef562c 100755 --- a/.ci/pytorch/win-test.sh +++ b/.ci/pytorch/win-test.sh @@ -41,7 +41,7 @@ fi python -m pip install pytest-rerunfailures==10.3 pytest-cpp==2.3.0 tensorboard==2.13.0 protobuf==5.29.4 pytest-subtests==0.13.1 # Install Z3 optional dependency for Windows builds. -python -m pip install z3-solver==4.15.1.0 +python -m pip install z3-solver==4.12.2.0 # Install tlparse for test\dynamo\test_structured_trace.py UTs. python -m pip install tlparse==0.3.30 diff --git a/.ci/pytorch/windows/internal/install_python.bat b/.ci/pytorch/windows/internal/install_python.bat index 65405a875b6b8..73622bd736edd 100644 --- a/.ci/pytorch/windows/internal/install_python.bat +++ b/.ci/pytorch/windows/internal/install_python.bat @@ -18,5 +18,5 @@ start /wait "" python-amd64.exe /quiet InstallAllUsers=1 PrependPath=0 Include_t if errorlevel 1 exit /b 1 set "PATH=%CD%\Python\Scripts;%CD%\Python;%PATH%" -%PYTHON_EXEC% -m pip install --upgrade pip "setuptools>=77.0.0" "packaging>=24.2" wheel +%PYTHON_EXEC% -m pip install --upgrade pip setuptools packaging wheel if errorlevel 1 exit /b 1 diff --git a/.ci/pytorch/windows/setup_build.bat b/.ci/pytorch/windows/setup_build.bat index df925b4ba90bc..9b492eef664d7 100644 --- a/.ci/pytorch/windows/setup_build.bat +++ b/.ci/pytorch/windows/setup_build.bat @@ -7,9 +7,6 @@ call "internal\install_python.bat" %PYTHON_EXEC% --version set "PATH=%CD%\Python\Lib\site-packages\cmake\data\bin;%CD%\Python\Scripts;%CD%\Python;%PATH%" - -%PYTHON_EXEC% -m pip install "setuptools>=77.0.0" "packaging>=24.2" - if "%DESIRED_PYTHON%" == "3.13t" %PYTHON_EXEC% -m pip install numpy==2.2.1 cmake if "%DESIRED_PYTHON%" == "3.13" %PYTHON_EXEC% -m pip install numpy==2.1.2 cmake if "%DESIRED_PYTHON%" == "3.12" %PYTHON_EXEC% -m pip install numpy==2.0.2 cmake @@ -19,7 +16,7 @@ if "%DESIRED_PYTHON%" == "3.9" %PYTHON_EXEC% -m pip install numpy==2.0.2 cmake %PYTHON_EXEC% -m pip install pyyaml %PYTHON_EXEC% -m pip install mkl-include mkl-static -%PYTHON_EXEC% -m pip install boto3 ninja typing-extensions +%PYTHON_EXEC% -m pip install boto3 ninja typing_extensions setuptools==72.1.0 where cmake.exe diff --git a/.ci/wheel/build_wheel.sh b/.ci/wheel/build_wheel.sh index dc44f8ccc2922..878d6595c84c0 100755 --- a/.ci/wheel/build_wheel.sh +++ b/.ci/wheel/build_wheel.sh @@ -127,7 +127,7 @@ export INSTALL_TEST=0 # dont install test binaries into site-packages export MACOSX_DEPLOYMENT_TARGET=10.15 export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"} -SETUPTOOLS_PINNED_VERSION="==77.0.0" +SETUPTOOLS_PINNED_VERSION="==70.1.0" PYYAML_PINNED_VERSION="=5.3" EXTRA_CONDA_INSTALL_FLAGS="" CONDA_ENV_CREATE_FLAGS="" @@ -135,7 +135,7 @@ RENAME_WHEEL=true case $desired_python in 3.13t) echo "Using 3.13 deps" - SETUPTOOLS_PINNED_VERSION=">=77.0.0" + SETUPTOOLS_PINNED_VERSION=">=70.1.0" PYYAML_PINNED_VERSION=">=6.0.1" NUMPY_PINNED_VERSION="=2.1.0" CONDA_ENV_CREATE_FLAGS="python-freethreading" @@ -145,31 +145,31 @@ case $desired_python in ;; 3.13) echo "Using 3.13 deps" - SETUPTOOLS_PINNED_VERSION=">=77.0.0" + SETUPTOOLS_PINNED_VERSION=">=70.1.0" PYYAML_PINNED_VERSION=">=6.0.1" NUMPY_PINNED_VERSION="=2.1.0" ;; 3.12) echo "Using 3.12 deps" - SETUPTOOLS_PINNED_VERSION=">=77.0.0" + SETUPTOOLS_PINNED_VERSION=">=70.1.0" PYYAML_PINNED_VERSION=">=6.0.1" NUMPY_PINNED_VERSION="=2.0.2" ;; 3.11) echo "Using 3.11 deps" - SETUPTOOLS_PINNED_VERSION=">=77.0.0" + SETUPTOOLS_PINNED_VERSION=">=70.1.0" PYYAML_PINNED_VERSION=">=5.3" NUMPY_PINNED_VERSION="=2.0.2" ;; 3.10) echo "Using 3.10 deps" - SETUPTOOLS_PINNED_VERSION=">=77.0.0" + SETUPTOOLS_PINNED_VERSION=">=70.1.0" PYYAML_PINNED_VERSION=">=5.3" NUMPY_PINNED_VERSION="=2.0.2" ;; 3.9) echo "Using 3.9 deps" - SETUPTOOLS_PINNED_VERSION=">=77.0.0" + SETUPTOOLS_PINNED_VERSION=">=70.1.0" PYYAML_PINNED_VERSION=">=5.3" NUMPY_PINNED_VERSION="=2.0.2" ;; diff --git a/.github/requirements/pip-requirements-macOS.txt b/.github/requirements/pip-requirements-macOS.txt index 7929ecfe1e4bb..9c72c71523b7d 100644 --- a/.github/requirements/pip-requirements-macOS.txt +++ b/.github/requirements/pip-requirements-macOS.txt @@ -12,7 +12,7 @@ numba==0.59.0 numpy==1.26.4 opt-einsum>=3.3 optree==0.13.0 -packaging==25.0 +packaging==23.1 parameterized==0.8.1 pillow==10.3.0 protobuf==5.29.4 @@ -26,11 +26,11 @@ pytest-xdist==3.3.1 pytest==7.3.2 pyyaml==6.0.2 scipy==1.12.0 -setuptools==80.9.0 +setuptools==72.1.0 sympy==1.13.3 tlparse==0.3.30 tensorboard==2.13.0 typing-extensions==4.12.2 unittest-xml-reporting<=3.2.0,>=2.0.0 xdoctest==1.1.0 -z3-solver==4.15.1.0 +z3-solver==4.12.2.0 diff --git a/.github/scripts/lintrunner.sh b/.github/scripts/lintrunner.sh index 1411ff0397b53..ef4741444f942 100755 --- a/.github/scripts/lintrunner.sh +++ b/.github/scripts/lintrunner.sh @@ -2,7 +2,7 @@ set -ex # Use uv to speed up lintrunner init -python3 -m pip install -U uv setuptools +python3 -m pip install uv==0.1.45 setuptools CACHE_DIRECTORY="/tmp/.lintbin" # Try to recover the cached binaries diff --git a/.github/scripts/windows/build_triton.bat b/.github/scripts/windows/build_triton.bat index da2e86b40432a..97cd535a49889 100644 --- a/.github/scripts/windows/build_triton.bat +++ b/.github/scripts/windows/build_triton.bat @@ -10,7 +10,7 @@ if "%PY_VERS%" == "3.13t" ( call conda create -n %PYTHON_PREFIX% -y -c=conda-forge python=%PY_VERS% ) :: Fix cmake version for issue https://github.com/pytorch/pytorch/issues/150480 -call conda run -n %PYTHON_PREFIX% pip install wheel pybind11 certifi cython cmake==3.31.6 setuptools==78.1.1 ninja +call conda run -n %PYTHON_PREFIX% pip install wheel pybind11 certifi cython cmake==3.31.6 setuptools==72.1.0 ninja dir "%VC_INSTALL_PATH%" diff --git a/pyproject.toml b/pyproject.toml index 133da9289f5c9..b41ae87621f0f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,12 +2,13 @@ [build-system] requires = [ + # 70.1.0: min version for integrated bdist_wheel command from wheel package # 77.0.0: min version for SPDX expression support for project.license - "setuptools>=77.0.0,<80.0", + "setuptools>=70.1.0,<80.0", "cmake>=3.27", "ninja", "numpy", - "packaging>=24.2", + "packaging", "pyyaml", "requests", "six", # dependency chain: NNPACK -> PeachPy -> six @@ -20,7 +21,11 @@ name = "torch" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" readme = "README.md" requires-python = ">=3.9,<3.14" -license = "BSD-3-Clause" +# TODO: change to `license = "BSD-3-Clause"` and enable PEP 639 after pinning setuptools>=77 +# FIXME: As of 2025.06.20, it is hard to ensure the minimum version of setuptools in our CI environment. +# TOML-table-based license deprecated in setuptools>=77, and the deprecation warning will be changed +# to an error on 2026.02.18. See also: https://github.com/pypa/setuptools/issues/4903 +license = { text = "BSD-3-Clause" } authors = [{ name = "PyTorch Team", email = "packages@pytorch.org" }] keywords = ["pytorch", "machine learning"] classifiers = [ diff --git a/requirements-build.txt b/requirements-build.txt index 12332b0e1af01..be19d987f73db 100644 --- a/requirements-build.txt +++ b/requirements-build.txt @@ -1,9 +1,9 @@ # Build System requirements -setuptools>=77.0.0,<80.0 # setuptools develop deprecated on 80.0 +setuptools>=70.1.0,<80.0 # setuptools develop deprecated on 80.0 cmake>=3.27 ninja numpy -packaging>=24.2 +packaging pyyaml requests six # dependency chain: NNPACK -> PeachPy -> six diff --git a/test/dynamo/test_exc.py b/test/dynamo/test_exc.py index c340a2882d471..acc3fd55f6fb0 100644 --- a/test/dynamo/test_exc.py +++ b/test/dynamo/test_exc.py @@ -251,13 +251,13 @@ def fn(x, shape): Model: ==> L['shape'][0]: 0 - ==> L['shape'][1]: 0 - ==> L['shape'][2]: 0 + ==> L['shape'][1]: 1 + ==> L['shape'][2]: 1 ==> L['x'].size()[0]: 3 ==> L['x'].storage_offset(): 0 ==> L['x'].stride()[0]: 1 - ==> s3: 0 - ==> s52: 0 + ==> s3: 1 + ==> s52: 1 ==> s77: 3 ==> s86: 0 @@ -315,16 +315,16 @@ def fn(x, shape): %split : [num_users=3] = call_method[target=split](args = (%l_x_, (%l_shape_0_, %l_shape_1_, %l_shape_2_)), kwargs = {}) Model: - ==> L['shape'][0]: 0 - ==> L['shape'][1]: 0 + ==> L['shape'][0]: 1 + ==> L['shape'][1]: 1 ==> L['shape'][2]: 0 ==> L['x'].size()[0]: 3 ==> L['x'].storage_offset(): 0 ==> L['x'].stride()[0]: 1 ==> s3: 0 - ==> s52: 0 + ==> s52: 1 ==> s77: 3 - ==> s86: 0 + ==> s86: 1 Assertions: ==> (== 0 L['x'].storage_offset()) From 70b4a8880b1c3fb5e92c5fcd75bda6b6f299abac Mon Sep 17 00:00:00 2001 From: codingwithsurya Date: Sun, 20 Jul 2025 16:39:35 -0700 Subject: [PATCH 308/457] [SymmMem] Add NVSHMEM barrier_all, my_pe, n_pes support into Triton (#158511) Adds device-side barrier synchronization and PE identification functions for NVSHMEM Triton integration. Includes `barrier_all()` for collective synchronization and `my_pe()`/`n_pes()` for PE identification within kernels. We are launching with cooperative grid launch (for all the PRs in this stack) because the `nvshmemx_collective_launch` function must be used to launch kernels on the GPU when the kernels use NVSHMEM synchronization or collective APIs, and `nvshmemx_collective_launch` essentially boils down to a CUDA cooperative group launch. Tests: `python test/distributed/test_nvshmem_triton.py -k test_triton_barrier` Also tested that if you remove the barrier, you get an assertion error/race conditions. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158511 Approved by: https://github.com/fduwjj --- test/distributed/test_nvshmem_triton.py | 89 +++++++++++++++++-- .../_symmetric_memory/_nvshmem_triton.py | 33 +++++++ 2 files changed, 113 insertions(+), 9 deletions(-) diff --git a/test/distributed/test_nvshmem_triton.py b/test/distributed/test_nvshmem_triton.py index 8958ec4eb84e2..996ad25db5d7e 100644 --- a/test/distributed/test_nvshmem_triton.py +++ b/test/distributed/test_nvshmem_triton.py @@ -4,6 +4,8 @@ # python test/distributed/test_nvshmem_triton.py +import triton.language as tl + import torch import torch.distributed as dist import torch.distributed._symmetric_memory as symm_mem @@ -148,6 +150,37 @@ def put_with_quiet_kernel( nvshmem.putmem_block(flag_dst_ptr, flag_src_ptr, 1, peer) +@triton.jit +def barrier_test_kernel( + dst_ptr, + src_ptr, + numel, +): + # Testing barrier_all() requires coordinated operations across PEs within + # the same kernel execution. Unlike other kernels that just wrap NVSHMEM + # primitives, this one implements the full test logic to properly verify + # device-side barrier synchronization. + my_pe = nvshmem.my_pe() + n_pes = nvshmem.n_pes() + # Rank 0 broadcasts its value to all other ranks + if my_pe == 0: + # Write initial value + p_src = src_ptr.to(tl.pointer_type(tl.int32)) + tl.store(p_src, 42) + # Put to all other ranks + i = 1 + while i < n_pes: + nvshmem.putmem_block(dst_ptr, src_ptr, numel, i) + i += 1 + # Synchronize all PEs + nvshmem.barrier_all() + # Non-zero ranks increment the received value + if my_pe != 0: + p_dst = dst_ptr.to(tl.pointer_type(tl.int32)) + received = tl.load(p_dst) + tl.store(p_dst, received + 1) + + @instantiate_parametrized_tests @requires_nvshmem() class NVSHMEMTritonTest(MultiProcContinousTest): @@ -172,7 +205,7 @@ def test_triton_put(self) -> None: # Enable NVSHMEM for Triton nvshmem_lib = nvshmem.enable_triton() - group_name = dist.group.WORLD.group_name + group_name = dist.distributed_c10d._get_default_group().group_name symm_mem.enable_symm_mem_for_group(group_name) rank = self.rank @@ -211,7 +244,7 @@ def test_triton_get(self) -> None: self._init_device() nvshmem_lib = nvshmem.enable_triton() - group_name = dist.group.WORLD.group_name + group_name = dist.distributed_c10d._get_default_group().group_name symm_mem.enable_symm_mem_for_group(group_name) rank = self.rank msg_size_bytes = 8 @@ -249,7 +282,7 @@ def test_triton_get_ring(self) -> None: self._init_device() nvshmem_lib = nvshmem.enable_triton() - group_name = dist.group.WORLD.group_name + group_name = dist.distributed_c10d._get_default_group().group_name symm_mem.enable_symm_mem_for_group(group_name) rank = self.rank world_size = dist.get_world_size() @@ -292,7 +325,7 @@ def test_triton_put_signal_set(self) -> None: nvshmem_lib = nvshmem.enable_triton() - group_name = dist.group.WORLD.group_name + group_name = dist.distributed_c10d._get_default_group().group_name symm_mem.enable_symm_mem_for_group(group_name) rank = self.rank @@ -357,7 +390,7 @@ def test_triton_put_signal_add(self) -> None: nvshmem_lib = nvshmem.enable_triton() - group_name = dist.group.WORLD.group_name + group_name = dist.distributed_c10d._get_default_group().group_name symm_mem.enable_symm_mem_for_group(group_name) rank = self.rank @@ -419,7 +452,7 @@ def test_triton_wait_until(self) -> None: self._init_device() nvshmem_lib = nvshmem.enable_triton() - group_name = dist.group.WORLD.group_name + group_name = dist.distributed_c10d._get_default_group().group_name symm_mem.enable_symm_mem_for_group(group_name) rank = self.rank @@ -493,7 +526,7 @@ def test_triton_signal_wait_until(self) -> None: self._init_device() # Enable NVSHMEM for Triton nvshmem_lib = nvshmem.enable_triton() - group_name = dist.group.WORLD.group_name + group_name = dist.distributed_c10d._get_default_group().group_name symm_mem.enable_symm_mem_for_group(group_name) rank = self.rank peer = (self.world_size - 1) - rank @@ -569,7 +602,7 @@ def test_triton_fence(self) -> None: torch.manual_seed(42 + self.rank) self._init_device() nvshmem_lib = nvshmem.enable_triton() - group_name = dist.group.WORLD.group_name + group_name = dist.distributed_c10d._get_default_group().group_name symm_mem.enable_symm_mem_for_group(group_name) rank = self.rank peer = (self.world_size - 1) - rank @@ -644,7 +677,7 @@ def test_triton_quiet(self) -> None: self._init_device() # Enable NVSHMEM for Triton nvshmem_lib = nvshmem.enable_triton() - group_name = dist.group.WORLD.group_name + group_name = dist.distributed_c10d._get_default_group().group_name symm_mem.enable_symm_mem_for_group(group_name) rank = self.rank msg_size_bytes = 8 @@ -694,6 +727,44 @@ def test_triton_quiet(self) -> None: extern_libs=nvshmem_lib, ) + @skipIfRocm + @requires_triton() + def test_triton_barrier(self) -> None: + torch.manual_seed(42 + self.rank) + self._init_device() + nvshmem_lib = nvshmem.enable_triton() + group_name = dist.distributed_c10d._get_default_group().group_name + symm_mem.enable_symm_mem_for_group(group_name) + rank = self.rank + numel = 1 + dtype = torch.int32 + # Create symmetric buffers + src = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(0) + dst = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(0) + src_hdl = symm_mem.rendezvous(src, group=group_name) + dst_hdl = symm_mem.rendezvous(dst, group=group_name) + # Launch kernel with cooperative grid + barrier_test_kernel[(1,)]( + dst_hdl.buffer_ptrs[rank], + src_hdl.buffer_ptrs[rank], + numel=numel, + extern_libs=nvshmem_lib, + launch_cooperative_grid=True, + num_ctas=1, + ) + # Verify results + # Rank 0 should have 42, and then the rest should have incremented + 1 to 43 + if rank == 0: + # Rank 0 should have its original value (42) in src + torch.testing.assert_close( + src, torch.tensor([42], device=self.device, dtype=dtype) + ) + else: + # Other ranks should have received 42 and incremented to 43 + torch.testing.assert_close( + dst, torch.tensor([43], device=self.device, dtype=dtype) + ) + if __name__ == "__main__": run_tests() diff --git a/torch/distributed/_symmetric_memory/_nvshmem_triton.py b/torch/distributed/_symmetric_memory/_nvshmem_triton.py index 75abae38c755a..3cd8c99d34f36 100644 --- a/torch/distributed/_symmetric_memory/_nvshmem_triton.py +++ b/torch/distributed/_symmetric_memory/_nvshmem_triton.py @@ -197,3 +197,36 @@ def quiet(_builder=None): # type: ignore[no-untyped-def] is_pure=False, _builder=_builder, ) + + @core.extern + def my_pe(_builder=None): # type: ignore[no-untyped-def] + return core.extern_elementwise( + "", + "", + [], + {(): ("nvshmem_my_pe", core.dtype("int32"))}, + is_pure=True, + _builder=_builder, + ) + + @core.extern + def n_pes(_builder=None): # type: ignore[no-untyped-def] + return core.extern_elementwise( + "", + "", + [], + {(): ("nvshmem_n_pes", core.dtype("int32"))}, + is_pure=True, + _builder=_builder, + ) + + @core.extern + def barrier_all(_builder=None): # type: ignore[no-untyped-def] + return core.extern_elementwise( + "", + "", + [], + {(): ("nvshmem_barrier_all", core.dtype("int32"))}, + is_pure=False, + _builder=_builder, + ) From 1c6328a588d53fe7f779942e3d8c03ee45a79e80 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Sun, 20 Jul 2025 19:13:32 -0700 Subject: [PATCH 309/457] [EZ][BE] Fix compilation warning in Pooling.metal (#158729) This one ``` Compiling /Users/malfet/git/pytorch/pytorch/aten/src/ATen/native/mps/kernels/Pooling.metal to Pooling_30.air /Users/malfet/git/pytorch/pytorch/aten/src/ATen/native/mps/kernels/Pooling.metal:172:1: warning: non-void function does not return a value in all control paths [-Wreturn-type] } ^ 1 warning generated. ``` Although functionally one is not supposed to hit this codepath ever, it's not not to throw warning Pull Request resolved: https://github.com/pytorch/pytorch/pull/158729 Approved by: https://github.com/Skylion007 --- aten/src/ATen/native/mps/kernels/Pooling.metal | 1 + 1 file changed, 1 insertion(+) diff --git a/aten/src/ATen/native/mps/kernels/Pooling.metal b/aten/src/ATen/native/mps/kernels/Pooling.metal index 18982559a34b8..05ce39bd83163 100644 --- a/aten/src/ATen/native/mps/kernels/Pooling.metal +++ b/aten/src/ATen/native/mps/kernels/Pooling.metal @@ -169,6 +169,7 @@ PoolOffsets find_pool_offsets( return_indices, tid); } + return PoolOffsets(); } // Kernel computes one element of the output per kernel call. From a527e816935957a164d74dd7c5069310b2857695 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Mon, 21 Jul 2025 13:28:36 +0800 Subject: [PATCH 310/457] [CI] update flake8 and mypy lint dependencies (#158720) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158720 Approved by: https://github.com/Skylion007 --- .ci/docker/requirements-ci.txt | 6 +- .../actions/filter-test-configs/action.yml | 2 +- .github/requirements-gha-cache.txt | 4 +- .../requirements/pip-requirements-macOS.txt | 2 +- .../workflows/check_mergeability_ghstack.yml | 2 +- .github/workflows/cherry-pick.yml | 2 +- .github/workflows/revert.yml | 2 +- .github/workflows/trymerge.yml | 2 +- .github/workflows/tryrebase.yml | 2 +- .lintrunner.toml | 28 +- tools/build/bazel/requirements.in | 2 +- tools/build/bazel/requirements.txt | 332 +++++++++--------- tools/setup_helpers/build.bzl | 2 +- torchgen/build.bzl | 4 +- 14 files changed, 195 insertions(+), 197 deletions(-) diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt index fb773ff324af8..4a52ad5e951bc 100644 --- a/.ci/docker/requirements-ci.txt +++ b/.ci/docker/requirements-ci.txt @@ -221,9 +221,9 @@ pygments==2.15.0 #Pinned versions: 2.12.0 #test that import: the doctests -#PyYAML +#pyyaml #Description: data serialization format -#Pinned versions: +#Pinned versions: 6.0.2 #test that import: #requests @@ -233,7 +233,7 @@ pygments==2.15.0 #rich #Description: rich text and beautiful formatting in the terminal -#Pinned versions: 10.9.0 +#Pinned versions: 14.0.0 #test that import: scikit-image==0.19.3 ; python_version < "3.10" diff --git a/.github/actions/filter-test-configs/action.yml b/.github/actions/filter-test-configs/action.yml index ca6643f9e2fc1..338fc0c2a844c 100644 --- a/.github/actions/filter-test-configs/action.yml +++ b/.github/actions/filter-test-configs/action.yml @@ -70,7 +70,7 @@ runs: set -eux # PyYAML 6.0 doesn't work with MacOS x86 anymore # This must run on Python-3.7 (AmazonLinux2) so can't use request=3.32.2 - python3 -m pip install requests==2.27.1 pyyaml==6.0.1 + python3 -m pip install requests==2.27.1 pyyaml==6.0.2 - name: Parse ref id: parse-ref diff --git a/.github/requirements-gha-cache.txt b/.github/requirements-gha-cache.txt index 5c691e4bf9b31..8c4a877fdd193 100644 --- a/.github/requirements-gha-cache.txt +++ b/.github/requirements-gha-cache.txt @@ -10,6 +10,6 @@ jinja2==3.1.6 lintrunner==0.10.7 ninja==1.10.0.post1 nvidia-ml-py==11.525.84 -pyyaml==6.0 +pyyaml==6.0.2 requests==2.32.4 -rich==10.9.0 +rich==14.0.0 diff --git a/.github/requirements/pip-requirements-macOS.txt b/.github/requirements/pip-requirements-macOS.txt index 9c72c71523b7d..0f8276f1dda63 100644 --- a/.github/requirements/pip-requirements-macOS.txt +++ b/.github/requirements/pip-requirements-macOS.txt @@ -2,7 +2,7 @@ boto3==1.35.42 cmake==3.27.* expecttest==0.3.0 fbscribelogger==0.1.7 -filelock==3.6.0 +filelock==3.18.0 hypothesis==6.56.4 librosa>=0.6.2 mpmath==1.3.0 diff --git a/.github/workflows/check_mergeability_ghstack.yml b/.github/workflows/check_mergeability_ghstack.yml index 65193839e9b9d..569a174665ba8 100644 --- a/.github/workflows/check_mergeability_ghstack.yml +++ b/.github/workflows/check_mergeability_ghstack.yml @@ -56,7 +56,7 @@ jobs: cache: pip architecture: x64 - - run: pip install pyyaml==6.0 + - run: pip install pyyaml==6.0.2 shell: bash - name: Verify mergeability diff --git a/.github/workflows/cherry-pick.yml b/.github/workflows/cherry-pick.yml index 1d385b556277a..310857782ea14 100644 --- a/.github/workflows/cherry-pick.yml +++ b/.github/workflows/cherry-pick.yml @@ -26,7 +26,7 @@ jobs: cache: pip # Not the direct dependencies but the script uses trymerge - - run: pip install pyyaml==6.0 + - run: pip install pyyaml==6.0.2 - name: Setup committer id run: | diff --git a/.github/workflows/revert.yml b/.github/workflows/revert.yml index 3c8722930e22e..226d773e48977 100644 --- a/.github/workflows/revert.yml +++ b/.github/workflows/revert.yml @@ -26,7 +26,7 @@ jobs: architecture: x64 check-latest: false cache: pip - - run: pip install pyyaml==6.0 + - run: pip install pyyaml==6.0.2 - name: Setup committer id run: | diff --git a/.github/workflows/trymerge.yml b/.github/workflows/trymerge.yml index 19e169bd973b3..1fdb1da67a595 100644 --- a/.github/workflows/trymerge.yml +++ b/.github/workflows/trymerge.yml @@ -28,7 +28,7 @@ jobs: check-latest: false cache: pip architecture: x64 - - run: pip install pyyaml==6.0 + - run: pip install pyyaml==6.0.2 - name: Setup committer id run: | diff --git a/.github/workflows/tryrebase.yml b/.github/workflows/tryrebase.yml index 9af59bcb3662d..1a8e00e4390be 100644 --- a/.github/workflows/tryrebase.yml +++ b/.github/workflows/tryrebase.yml @@ -25,7 +25,7 @@ jobs: architecture: x64 check-latest: false cache: pip - - run: pip install pyyaml==6.0 + - run: pip install pyyaml==6.0.2 - name: Setup committer id run: | diff --git a/.lintrunner.toml b/.lintrunner.toml index 04664378d8bf8..6cc1164a785dd 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -39,16 +39,16 @@ init_command = [ 'python3', 'tools/linter/adapters/pip_init.py', '--dry-run={{DRYRUN}}', - 'flake8==6.1.0', - 'flake8-bugbear==23.3.23', - 'flake8-comprehensions==3.15.0', + 'flake8==7.3.0', + 'flake8-bugbear==24.12.12', + 'flake8-comprehensions==3.16.0', 'flake8-executable==2.1.3', - 'flake8-logging-format==0.9.0', - 'flake8-pyi==23.3.1', - 'flake8-simplify==0.19.3', + 'flake8-logging-format==2024.24.12', + 'flake8-pyi==25.5.0', + 'flake8-simplify==0.22.0', 'mccabe==0.7.0', - 'pycodestyle==2.11.1', - 'pyflakes==3.1.0', + 'pycodestyle==2.14.0', + 'pyflakes==3.4.0', 'torchfix==0.4.0 ; python_version >= "3.9" and python_version < "3.13"', ] @@ -158,16 +158,16 @@ init_command = [ 'mypy==1.16.0', 'sympy==1.13.3', 'types-requests==2.27.25', - 'types-pyyaml==6.0.1', + 'types-pyyaml==6.0.2', 'types-tabulate==0.8.8', 'types-protobuf==5.29.1.20250403', 'types-setuptools==79.0.0.20250422', 'types-jinja2==2.11.9', 'types-colorama==0.4.6', - 'filelock==3.13.1', + 'filelock==3.18.0', 'junitparser==2.1.1', - 'rich==10.9.0', - 'pyyaml==6.0.1', + 'rich==14.0.0', + 'pyyaml==6.0.2', 'optree==0.13.0', 'dataclasses-json==0.6.7', 'pandas==2.2.3', @@ -1111,7 +1111,7 @@ init_command = [ 'python3', 'tools/linter/adapters/pip_init.py', '--dry-run={{DRYRUN}}', - 'PyYAML==6.0.1', + 'pyyaml==6.0.2', ] [[linter]] @@ -1133,7 +1133,7 @@ init_command = [ 'python3', 'tools/linter/adapters/pip_init.py', '--dry-run={{DRYRUN}}', - 'PyYAML==6.0.1', + 'pyyaml==6.0.2', ] [[linter]] diff --git a/tools/build/bazel/requirements.in b/tools/build/bazel/requirements.in index 37750163da81e..8837501006624 100644 --- a/tools/build/bazel/requirements.in +++ b/tools/build/bazel/requirements.in @@ -1,4 +1,4 @@ -PyYAML==6.0.1 +pyyaml==6.0.2 numpy==1.26.4 requests==2.32.2 setuptools==78.1.1 diff --git a/tools/build/bazel/requirements.txt b/tools/build/bazel/requirements.txt index a15924660167d..a3383b60c1964 100644 --- a/tools/build/bazel/requirements.txt +++ b/tools/build/bazel/requirements.txt @@ -1,108 +1,106 @@ -# -# This file is autogenerated by pip-compile with Python 3.11 -# by the following command: -# -# pip-compile --allow-unsafe --generate-hashes tools/build/bazel/requirements.in -# -certifi==2024.7.4 \ - --hash=sha256:5a1e7645bc0ec61a09e26c36f6106dd4cf40c6db3a1fb6352b0244e7fb057c7b \ - --hash=sha256:c198e21b1289c2ab85ee4e67bb4b4ef3ead0892059901a8d5b622f24a1101e90 +# This file was autogenerated by uv via the following command: +# uv pip compile --generate-hashes tools/build/bazel/requirements.in +certifi==2025.7.14 \ + --hash=sha256:6b31f564a415d79ee77df69d757bb49a5bb53bd9f756cbbe24394ffd6fc1f4b2 \ + --hash=sha256:8ea99dbdfaaf2ba2f9bac77b9249ef62ec5218e7c2b2e903378ed5fccf765995 # via requests -charset-normalizer==3.3.2 \ - --hash=sha256:06435b539f889b1f6f4ac1758871aae42dc3a8c0e24ac9e60c2384973ad73027 \ - --hash=sha256:06a81e93cd441c56a9b65d8e1d043daeb97a3d0856d177d5c90ba85acb3db087 \ - --hash=sha256:0a55554a2fa0d408816b3b5cedf0045f4b8e1a6065aec45849de2d6f3f8e9786 \ - --hash=sha256:0b2b64d2bb6d3fb9112bafa732def486049e63de9618b5843bcdd081d8144cd8 \ - --hash=sha256:10955842570876604d404661fbccbc9c7e684caf432c09c715ec38fbae45ae09 \ - --hash=sha256:122c7fa62b130ed55f8f285bfd56d5f4b4a5b503609d181f9ad85e55c89f4185 \ - --hash=sha256:1ceae2f17a9c33cb48e3263960dc5fc8005351ee19db217e9b1bb15d28c02574 \ - --hash=sha256:1d3193f4a680c64b4b6a9115943538edb896edc190f0b222e73761716519268e \ - --hash=sha256:1f79682fbe303db92bc2b1136016a38a42e835d932bab5b3b1bfcfbf0640e519 \ - --hash=sha256:2127566c664442652f024c837091890cb1942c30937add288223dc895793f898 \ - --hash=sha256:22afcb9f253dac0696b5a4be4a1c0f8762f8239e21b99680099abd9b2b1b2269 \ - --hash=sha256:25baf083bf6f6b341f4121c2f3c548875ee6f5339300e08be3f2b2ba1721cdd3 \ - --hash=sha256:2e81c7b9c8979ce92ed306c249d46894776a909505d8f5a4ba55b14206e3222f \ - --hash=sha256:3287761bc4ee9e33561a7e058c72ac0938c4f57fe49a09eae428fd88aafe7bb6 \ - --hash=sha256:34d1c8da1e78d2e001f363791c98a272bb734000fcef47a491c1e3b0505657a8 \ - --hash=sha256:37e55c8e51c236f95b033f6fb391d7d7970ba5fe7ff453dad675e88cf303377a \ - --hash=sha256:3d47fa203a7bd9c5b6cee4736ee84ca03b8ef23193c0d1ca99b5089f72645c73 \ - --hash=sha256:3e4d1f6587322d2788836a99c69062fbb091331ec940e02d12d179c1d53e25fc \ - --hash=sha256:42cb296636fcc8b0644486d15c12376cb9fa75443e00fb25de0b8602e64c1714 \ - --hash=sha256:45485e01ff4d3630ec0d9617310448a8702f70e9c01906b0d0118bdf9d124cf2 \ - --hash=sha256:4a78b2b446bd7c934f5dcedc588903fb2f5eec172f3d29e52a9096a43722adfc \ - --hash=sha256:4ab2fe47fae9e0f9dee8c04187ce5d09f48eabe611be8259444906793ab7cbce \ - --hash=sha256:4d0d1650369165a14e14e1e47b372cfcb31d6ab44e6e33cb2d4e57265290044d \ - --hash=sha256:549a3a73da901d5bc3ce8d24e0600d1fa85524c10287f6004fbab87672bf3e1e \ - --hash=sha256:55086ee1064215781fff39a1af09518bc9255b50d6333f2e4c74ca09fac6a8f6 \ - --hash=sha256:572c3763a264ba47b3cf708a44ce965d98555f618ca42c926a9c1616d8f34269 \ - --hash=sha256:573f6eac48f4769d667c4442081b1794f52919e7edada77495aaed9236d13a96 \ - --hash=sha256:5b4c145409bef602a690e7cfad0a15a55c13320ff7a3ad7ca59c13bb8ba4d45d \ - --hash=sha256:6463effa3186ea09411d50efc7d85360b38d5f09b870c48e4600f63af490e56a \ - --hash=sha256:65f6f63034100ead094b8744b3b97965785388f308a64cf8d7c34f2f2e5be0c4 \ - --hash=sha256:663946639d296df6a2bb2aa51b60a2454ca1cb29835324c640dafb5ff2131a77 \ - --hash=sha256:6897af51655e3691ff853668779c7bad41579facacf5fd7253b0133308cf000d \ - --hash=sha256:68d1f8a9e9e37c1223b656399be5d6b448dea850bed7d0f87a8311f1ff3dabb0 \ - --hash=sha256:6ac7ffc7ad6d040517be39eb591cac5ff87416c2537df6ba3cba3bae290c0fed \ - --hash=sha256:6b3251890fff30ee142c44144871185dbe13b11bab478a88887a639655be1068 \ - --hash=sha256:6c4caeef8fa63d06bd437cd4bdcf3ffefe6738fb1b25951440d80dc7df8c03ac \ - --hash=sha256:6ef1d82a3af9d3eecdba2321dc1b3c238245d890843e040e41e470ffa64c3e25 \ - --hash=sha256:753f10e867343b4511128c6ed8c82f7bec3bd026875576dfd88483c5c73b2fd8 \ - --hash=sha256:7cd13a2e3ddeed6913a65e66e94b51d80a041145a026c27e6bb76c31a853c6ab \ - --hash=sha256:7ed9e526742851e8d5cc9e6cf41427dfc6068d4f5a3bb03659444b4cabf6bc26 \ - --hash=sha256:7f04c839ed0b6b98b1a7501a002144b76c18fb1c1850c8b98d458ac269e26ed2 \ - --hash=sha256:802fe99cca7457642125a8a88a084cef28ff0cf9407060f7b93dca5aa25480db \ - --hash=sha256:80402cd6ee291dcb72644d6eac93785fe2c8b9cb30893c1af5b8fdd753b9d40f \ - --hash=sha256:8465322196c8b4d7ab6d1e049e4c5cb460d0394da4a27d23cc242fbf0034b6b5 \ - --hash=sha256:86216b5cee4b06df986d214f664305142d9c76df9b6512be2738aa72a2048f99 \ - --hash=sha256:87d1351268731db79e0f8e745d92493ee2841c974128ef629dc518b937d9194c \ - --hash=sha256:8bdb58ff7ba23002a4c5808d608e4e6c687175724f54a5dade5fa8c67b604e4d \ - --hash=sha256:8c622a5fe39a48f78944a87d4fb8a53ee07344641b0562c540d840748571b811 \ - --hash=sha256:8d756e44e94489e49571086ef83b2bb8ce311e730092d2c34ca8f7d925cb20aa \ - --hash=sha256:8f4a014bc36d3c57402e2977dada34f9c12300af536839dc38c0beab8878f38a \ - --hash=sha256:9063e24fdb1e498ab71cb7419e24622516c4a04476b17a2dab57e8baa30d6e03 \ - --hash=sha256:90d558489962fd4918143277a773316e56c72da56ec7aa3dc3dbbe20fdfed15b \ - --hash=sha256:923c0c831b7cfcb071580d3f46c4baf50f174be571576556269530f4bbd79d04 \ - --hash=sha256:95f2a5796329323b8f0512e09dbb7a1860c46a39da62ecb2324f116fa8fdc85c \ - --hash=sha256:96b02a3dc4381e5494fad39be677abcb5e6634bf7b4fa83a6dd3112607547001 \ - --hash=sha256:9f96df6923e21816da7e0ad3fd47dd8f94b2a5ce594e00677c0013018b813458 \ - --hash=sha256:a10af20b82360ab00827f916a6058451b723b4e65030c5a18577c8b2de5b3389 \ - --hash=sha256:a50aebfa173e157099939b17f18600f72f84eed3049e743b68ad15bd69b6bf99 \ - --hash=sha256:a981a536974bbc7a512cf44ed14938cf01030a99e9b3a06dd59578882f06f985 \ - --hash=sha256:a9a8e9031d613fd2009c182b69c7b2c1ef8239a0efb1df3f7c8da66d5dd3d537 \ - --hash=sha256:ae5f4161f18c61806f411a13b0310bea87f987c7d2ecdbdaad0e94eb2e404238 \ - --hash=sha256:aed38f6e4fb3f5d6bf81bfa990a07806be9d83cf7bacef998ab1a9bd660a581f \ - --hash=sha256:b01b88d45a6fcb69667cd6d2f7a9aeb4bf53760d7fc536bf679ec94fe9f3ff3d \ - --hash=sha256:b261ccdec7821281dade748d088bb6e9b69e6d15b30652b74cbbac25e280b796 \ - --hash=sha256:b2b0a0c0517616b6869869f8c581d4eb2dd83a4d79e0ebcb7d373ef9956aeb0a \ - --hash=sha256:b4a23f61ce87adf89be746c8a8974fe1c823c891d8f86eb218bb957c924bb143 \ - --hash=sha256:bd8f7df7d12c2db9fab40bdd87a7c09b1530128315d047a086fa3ae3435cb3a8 \ - --hash=sha256:beb58fe5cdb101e3a055192ac291b7a21e3b7ef4f67fa1d74e331a7f2124341c \ - --hash=sha256:c002b4ffc0be611f0d9da932eb0f704fe2602a9a949d1f738e4c34c75b0863d5 \ - --hash=sha256:c083af607d2515612056a31f0a8d9e0fcb5876b7bfc0abad3ecd275bc4ebc2d5 \ - --hash=sha256:c180f51afb394e165eafe4ac2936a14bee3eb10debc9d9e4db8958fe36afe711 \ - --hash=sha256:c235ebd9baae02f1b77bcea61bce332cb4331dc3617d254df3323aa01ab47bd4 \ - --hash=sha256:cd70574b12bb8a4d2aaa0094515df2463cb429d8536cfb6c7ce983246983e5a6 \ - --hash=sha256:d0eccceffcb53201b5bfebb52600a5fb483a20b61da9dbc885f8b103cbe7598c \ - --hash=sha256:d965bba47ddeec8cd560687584e88cf699fd28f192ceb452d1d7ee807c5597b7 \ - --hash=sha256:db364eca23f876da6f9e16c9da0df51aa4f104a972735574842618b8c6d999d4 \ - --hash=sha256:ddbb2551d7e0102e7252db79ba445cdab71b26640817ab1e3e3648dad515003b \ - --hash=sha256:deb6be0ac38ece9ba87dea880e438f25ca3eddfac8b002a2ec3d9183a454e8ae \ - --hash=sha256:e06ed3eb3218bc64786f7db41917d4e686cc4856944f53d5bdf83a6884432e12 \ - --hash=sha256:e27ad930a842b4c5eb8ac0016b0a54f5aebbe679340c26101df33424142c143c \ - --hash=sha256:e537484df0d8f426ce2afb2d0f8e1c3d0b114b83f8850e5f2fbea0e797bd82ae \ - --hash=sha256:eb00ed941194665c332bf8e078baf037d6c35d7c4f3102ea2d4f16ca94a26dc8 \ - --hash=sha256:eb6904c354526e758fda7167b33005998fb68c46fbc10e013ca97f21ca5c8887 \ - --hash=sha256:eb8821e09e916165e160797a6c17edda0679379a4be5c716c260e836e122f54b \ - --hash=sha256:efcb3f6676480691518c177e3b465bcddf57cea040302f9f4e6e191af91174d4 \ - --hash=sha256:f27273b60488abe721a075bcca6d7f3964f9f6f067c8c4c605743023d7d3944f \ - --hash=sha256:f30c3cb33b24454a82faecaf01b19c18562b1e89558fb6c56de4d9118a032fd5 \ - --hash=sha256:fb69256e180cb6c8a894fee62b3afebae785babc1ee98b81cdf68bbca1987f33 \ - --hash=sha256:fd1abc0d89e30cc4e02e4064dc67fcc51bd941eb395c502aac3ec19fab46b519 \ - --hash=sha256:ff8fa367d09b717b2a17a052544193ad76cd49979c805768879cb63d9ca50561 +charset-normalizer==3.4.2 \ + --hash=sha256:005fa3432484527f9732ebd315da8da8001593e2cf46a3d817669f062c3d9ed4 \ + --hash=sha256:046595208aae0120559a67693ecc65dd75d46f7bf687f159127046628178dc45 \ + --hash=sha256:0c29de6a1a95f24b9a1aa7aefd27d2487263f00dfd55a77719b530788f75cff7 \ + --hash=sha256:0c8c57f84ccfc871a48a47321cfa49ae1df56cd1d965a09abe84066f6853b9c0 \ + --hash=sha256:0f5d9ed7f254402c9e7d35d2f5972c9bbea9040e99cd2861bd77dc68263277c7 \ + --hash=sha256:18dd2e350387c87dabe711b86f83c9c78af772c748904d372ade190b5c7c9d4d \ + --hash=sha256:1b1bde144d98e446b056ef98e59c256e9294f6b74d7af6846bf5ffdafd687a7d \ + --hash=sha256:1c95a1e2902a8b722868587c0e1184ad5c55631de5afc0eb96bc4b0d738092c0 \ + --hash=sha256:1cad5f45b3146325bb38d6855642f6fd609c3f7cad4dbaf75549bf3b904d3184 \ + --hash=sha256:21b2899062867b0e1fde9b724f8aecb1af14f2778d69aacd1a5a1853a597a5db \ + --hash=sha256:24498ba8ed6c2e0b56d4acbf83f2d989720a93b41d712ebd4f4979660db4417b \ + --hash=sha256:25a23ea5c7edc53e0f29bae2c44fcb5a1aa10591aae107f2a2b2583a9c5cbc64 \ + --hash=sha256:289200a18fa698949d2b39c671c2cc7a24d44096784e76614899a7ccf2574b7b \ + --hash=sha256:28a1005facc94196e1fb3e82a3d442a9d9110b8434fc1ded7a24a2983c9888d8 \ + --hash=sha256:32fc0341d72e0f73f80acb0a2c94216bd704f4f0bce10aedea38f30502b271ff \ + --hash=sha256:36b31da18b8890a76ec181c3cf44326bf2c48e36d393ca1b72b3f484113ea344 \ + --hash=sha256:3c21d4fca343c805a52c0c78edc01e3477f6dd1ad7c47653241cf2a206d4fc58 \ + --hash=sha256:3fddb7e2c84ac87ac3a947cb4e66d143ca5863ef48e4a5ecb83bd48619e4634e \ + --hash=sha256:43e0933a0eff183ee85833f341ec567c0980dae57c464d8a508e1b2ceb336471 \ + --hash=sha256:4a476b06fbcf359ad25d34a057b7219281286ae2477cc5ff5e3f70a246971148 \ + --hash=sha256:4e594135de17ab3866138f496755f302b72157d115086d100c3f19370839dd3a \ + --hash=sha256:50bf98d5e563b83cc29471fa114366e6806bc06bc7a25fd59641e41445327836 \ + --hash=sha256:5a9979887252a82fefd3d3ed2a8e3b937a7a809f65dcb1e068b090e165bbe99e \ + --hash=sha256:5baececa9ecba31eff645232d59845c07aa030f0c81ee70184a90d35099a0e63 \ + --hash=sha256:5bf4545e3b962767e5c06fe1738f951f77d27967cb2caa64c28be7c4563e162c \ + --hash=sha256:6333b3aa5a12c26b2a4d4e7335a28f1475e0e5e17d69d55141ee3cab736f66d1 \ + --hash=sha256:65c981bdbd3f57670af8b59777cbfae75364b483fa8a9f420f08094531d54a01 \ + --hash=sha256:68a328e5f55ec37c57f19ebb1fdc56a248db2e3e9ad769919a58672958e8f366 \ + --hash=sha256:6a0289e4589e8bdfef02a80478f1dfcb14f0ab696b5a00e1f4b8a14a307a3c58 \ + --hash=sha256:6b66f92b17849b85cad91259efc341dce9c1af48e2173bf38a85c6329f1033e5 \ + --hash=sha256:6c9379d65defcab82d07b2a9dfbfc2e95bc8fe0ebb1b176a3190230a3ef0e07c \ + --hash=sha256:6fc1f5b51fa4cecaa18f2bd7a003f3dd039dd615cd69a2afd6d3b19aed6775f2 \ + --hash=sha256:70f7172939fdf8790425ba31915bfbe8335030f05b9913d7ae00a87d4395620a \ + --hash=sha256:721c76e84fe669be19c5791da68232ca2e05ba5185575086e384352e2c309597 \ + --hash=sha256:7222ffd5e4de8e57e03ce2cef95a4c43c98fcb72ad86909abdfc2c17d227fc1b \ + --hash=sha256:75d10d37a47afee94919c4fab4c22b9bc2a8bf7d4f46f87363bcf0573f3ff4f5 \ + --hash=sha256:76af085e67e56c8816c3ccf256ebd136def2ed9654525348cfa744b6802b69eb \ + --hash=sha256:770cab594ecf99ae64c236bc9ee3439c3f46be49796e265ce0cc8bc17b10294f \ + --hash=sha256:7a6ab32f7210554a96cd9e33abe3ddd86732beeafc7a28e9955cdf22ffadbab0 \ + --hash=sha256:7c48ed483eb946e6c04ccbe02c6b4d1d48e51944b6db70f697e089c193404941 \ + --hash=sha256:7f56930ab0abd1c45cd15be65cc741c28b1c9a34876ce8c17a2fa107810c0af0 \ + --hash=sha256:8075c35cd58273fee266c58c0c9b670947c19df5fb98e7b66710e04ad4e9ff86 \ + --hash=sha256:8272b73e1c5603666618805fe821edba66892e2870058c94c53147602eab29c7 \ + --hash=sha256:82d8fd25b7f4675d0c47cf95b594d4e7b158aca33b76aa63d07186e13c0e0ab7 \ + --hash=sha256:844da2b5728b5ce0e32d863af26f32b5ce61bc4273a9c720a9f3aa9df73b1455 \ + --hash=sha256:8755483f3c00d6c9a77f490c17e6ab0c8729e39e6390328e42521ef175380ae6 \ + --hash=sha256:915f3849a011c1f593ab99092f3cecfcb4d65d8feb4a64cf1bf2d22074dc0ec4 \ + --hash=sha256:926ca93accd5d36ccdabd803392ddc3e03e6d4cd1cf17deff3b989ab8e9dbcf0 \ + --hash=sha256:982bb1e8b4ffda883b3d0a521e23abcd6fd17418f6d2c4118d257a10199c0ce3 \ + --hash=sha256:98f862da73774290f251b9df8d11161b6cf25b599a66baf087c1ffe340e9bfd1 \ + --hash=sha256:9cbfacf36cb0ec2897ce0ebc5d08ca44213af24265bd56eca54bee7923c48fd6 \ + --hash=sha256:a370b3e078e418187da8c3674eddb9d983ec09445c99a3a263c2011993522981 \ + --hash=sha256:a955b438e62efdf7e0b7b52a64dc5c3396e2634baa62471768a64bc2adb73d5c \ + --hash=sha256:aa6af9e7d59f9c12b33ae4e9450619cf2488e2bbe9b44030905877f0b2324980 \ + --hash=sha256:aa88ca0b1932e93f2d961bf3addbb2db902198dca337d88c89e1559e066e7645 \ + --hash=sha256:aaeeb6a479c7667fbe1099af9617c83aaca22182d6cf8c53966491a0f1b7ffb7 \ + --hash=sha256:aaf27faa992bfee0264dc1f03f4c75e9fcdda66a519db6b957a3f826e285cf12 \ + --hash=sha256:b2680962a4848b3c4f155dc2ee64505a9c57186d0d56b43123b17ca3de18f0fa \ + --hash=sha256:b2d318c11350e10662026ad0eb71bb51c7812fc8590825304ae0bdd4ac283acd \ + --hash=sha256:b33de11b92e9f75a2b545d6e9b6f37e398d86c3e9e9653c4864eb7e89c5773ef \ + --hash=sha256:b3daeac64d5b371dea99714f08ffc2c208522ec6b06fbc7866a450dd446f5c0f \ + --hash=sha256:be1e352acbe3c78727a16a455126d9ff83ea2dfdcbc83148d2982305a04714c2 \ + --hash=sha256:bee093bf902e1d8fc0ac143c88902c3dfc8941f7ea1d6a8dd2bcb786d33db03d \ + --hash=sha256:c72fbbe68c6f32f251bdc08b8611c7b3060612236e960ef848e0a517ddbe76c5 \ + --hash=sha256:c9e36a97bee9b86ef9a1cf7bb96747eb7a15c2f22bdb5b516434b00f2a599f02 \ + --hash=sha256:cddf7bd982eaa998934a91f69d182aec997c6c468898efe6679af88283b498d3 \ + --hash=sha256:cf713fe9a71ef6fd5adf7a79670135081cd4431c2943864757f0fa3a65b1fafd \ + --hash=sha256:d11b54acf878eef558599658b0ffca78138c8c3655cf4f3a4a673c437e67732e \ + --hash=sha256:d41c4d287cfc69060fa91cae9683eacffad989f1a10811995fa309df656ec214 \ + --hash=sha256:d524ba3f1581b35c03cb42beebab4a13e6cdad7b36246bd22541fa585a56cccd \ + --hash=sha256:daac4765328a919a805fa5e2720f3e94767abd632ae410a9062dff5412bae65a \ + --hash=sha256:db4c7bf0e07fc3b7d89ac2a5880a6a8062056801b83ff56d8464b70f65482b6c \ + --hash=sha256:dc7039885fa1baf9be153a0626e337aa7ec8bf96b0128605fb0d77788ddc1681 \ + --hash=sha256:dccab8d5fa1ef9bfba0590ecf4d46df048d18ffe3eec01eeb73a42e0d9e7a8ba \ + --hash=sha256:dedb8adb91d11846ee08bec4c8236c8549ac721c245678282dcb06b221aab59f \ + --hash=sha256:e45ba65510e2647721e35323d6ef54c7974959f6081b58d4ef5d87c60c84919a \ + --hash=sha256:e53efc7c7cee4c1e70661e2e112ca46a575f90ed9ae3fef200f2a25e954f4b28 \ + --hash=sha256:e635b87f01ebc977342e2697d05b56632f5f879a4f15955dfe8cef2448b51691 \ + --hash=sha256:e70e990b2137b29dc5564715de1e12701815dacc1d056308e2b17e9095372a82 \ + --hash=sha256:e8082b26888e2f8b36a042a58307d5b917ef2b1cacab921ad3323ef91901c71a \ + --hash=sha256:e8323a9b031aa0393768b87f04b4164a40037fb2a3c11ac06a03ffecd3618027 \ + --hash=sha256:e92fca20c46e9f5e1bb485887d074918b13543b1c2a1185e69bb8d17ab6236a7 \ + --hash=sha256:eb30abc20df9ab0814b5a2524f23d75dcf83cde762c161917a2b4b7b55b1e518 \ + --hash=sha256:eba9904b0f38a143592d9fc0e19e2df0fa2e41c3c3745554761c5f6447eedabf \ + --hash=sha256:ef8de666d6179b009dce7bcb2ad4c4a779f113f12caf8dc77f0162c29d20490b \ + --hash=sha256:efd387a49825780ff861998cd959767800d54f8308936b21025326de4b5a42b9 \ + --hash=sha256:f0aa37f3c979cf2546b73e8222bbfa3dc07a641585340179d768068e3455e544 \ + --hash=sha256:f4074c5a429281bf056ddd4c5d3b740ebca4d43ffffe2ef4bf4d2d05114299da \ + --hash=sha256:f69a27e45c43520f5487f27627059b64aaf160415589230992cec34c5e18a509 \ + --hash=sha256:fb707f3e15060adf5b7ada797624a6c6e0138e2a26baa089df64c68ee98e040f \ + --hash=sha256:fcbe676a55d7445b22c10967bceaaf0ee69407fbe0ece4d032b6eb8d4565982a \ + --hash=sha256:fdb20a30fe1175ecabed17cbf7812f7b804b8a315a25f24678bcdf120a90077f # via requests -idna==3.7 \ - --hash=sha256:028ff3aadf0609c1fd278d8ea3089299412a7a8b9bd005dd08b9f8285bcb5cfc \ - --hash=sha256:82fee1fc78add43492d3a1898bfa6d8a904cc97d8427f683ed8e798d07761aa0 +idna==3.10 \ + --hash=sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9 \ + --hash=sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3 # via requests mpmath==1.3.0 \ --hash=sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f \ @@ -111,7 +109,7 @@ mpmath==1.3.0 \ networkx==2.8.8 \ --hash=sha256:230d388117af870fce5647a3c52401fcf753e94720e6ea6b4197a5355648885e \ --hash=sha256:e435dfa75b1d7195c7b8378c3859f0445cd88c6b0375c181ed66823a9ceb7524 - # via -r requirements.in + # via -r tools/build/bazel/requirements.in numpy==1.26.4 \ --hash=sha256:03a8c78d01d9781b28a6989f6fa1bb2c4f2d51201cf99d3dd875df6fbd96b23b \ --hash=sha256:08beddf13648eb95f8d867350f6a018a4be2e5ad54c8d8caed89ebca558b2818 \ @@ -149,79 +147,79 @@ numpy==1.26.4 \ --hash=sha256:edd8b5fe47dab091176d21bb6de568acdd906d1887a4584a15a9a96a1dca06ef \ --hash=sha256:f870204a840a60da0b12273ef34f7051e98c3b5961b61b0c2c1be6dfd64fbcd3 \ --hash=sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f - # via -r requirements.in -pyyaml==6.0.1 \ - --hash=sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5 \ - --hash=sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc \ - --hash=sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df \ - --hash=sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741 \ - --hash=sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206 \ - --hash=sha256:18aeb1bf9a78867dc38b259769503436b7c72f7a1f1f4c93ff9a17de54319b27 \ - --hash=sha256:1d4c7e777c441b20e32f52bd377e0c409713e8bb1386e1099c2415f26e479595 \ - --hash=sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62 \ - --hash=sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98 \ - --hash=sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696 \ - --hash=sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290 \ - --hash=sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9 \ - --hash=sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d \ - --hash=sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6 \ - --hash=sha256:4fb147e7a67ef577a588a0e2c17b6db51dda102c71de36f8549b6816a96e1867 \ - --hash=sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47 \ - --hash=sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486 \ - --hash=sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6 \ - --hash=sha256:596106435fa6ad000c2991a98fa58eeb8656ef2325d7e158344fb33864ed87e3 \ - --hash=sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007 \ - --hash=sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938 \ - --hash=sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0 \ - --hash=sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c \ - --hash=sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735 \ - --hash=sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d \ - --hash=sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28 \ - --hash=sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4 \ - --hash=sha256:9046c58c4395dff28dd494285c82ba00b546adfc7ef001486fbf0324bc174fba \ - --hash=sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8 \ - --hash=sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef \ - --hash=sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5 \ - --hash=sha256:afd7e57eddb1a54f0f1a974bc4391af8bcce0b444685d936840f125cf046d5bd \ - --hash=sha256:b1275ad35a5d18c62a7220633c913e1b42d44b46ee12554e5fd39c70a243d6a3 \ - --hash=sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0 \ - --hash=sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515 \ - --hash=sha256:baa90d3f661d43131ca170712d903e6295d1f7a0f595074f151c0aed377c9b9c \ - --hash=sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c \ - --hash=sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924 \ - --hash=sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34 \ - --hash=sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43 \ - --hash=sha256:c8098ddcc2a85b61647b2590f825f3db38891662cfc2fc776415143f599bb859 \ - --hash=sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673 \ - --hash=sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54 \ - --hash=sha256:d858aa552c999bc8a8d57426ed01e40bef403cd8ccdd0fc5f6f04a00414cac2a \ - --hash=sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b \ - --hash=sha256:f003ed9ad21d6a4713f0a9b5a7a0a79e08dd0f221aff4525a2be4c346ee60aab \ - --hash=sha256:f22ac1c3cac4dbc50079e965eba2c1058622631e526bd9afd45fedd49ba781fa \ - --hash=sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c \ - --hash=sha256:fca0e3a251908a499833aa292323f32437106001d436eca0e6e7833256674585 \ - --hash=sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d \ - --hash=sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f - # via -r requirements.in + # via -r tools/build/bazel/requirements.in +pyyaml==6.0.2 \ + --hash=sha256:01179a4a8559ab5de078078f37e5c1a30d76bb88519906844fd7bdea1b7729ff \ + --hash=sha256:0833f8694549e586547b576dcfaba4a6b55b9e96098b36cdc7ebefe667dfed48 \ + --hash=sha256:0a9a2848a5b7feac301353437eb7d5957887edbf81d56e903999a75a3d743086 \ + --hash=sha256:0b69e4ce7a131fe56b7e4d770c67429700908fc0752af059838b1cfb41960e4e \ + --hash=sha256:0ffe8360bab4910ef1b9e87fb812d8bc0a308b0d0eef8c8f44e0254ab3b07133 \ + --hash=sha256:11d8f3dd2b9c1207dcaf2ee0bbbfd5991f571186ec9cc78427ba5bd32afae4b5 \ + --hash=sha256:17e311b6c678207928d649faa7cb0d7b4c26a0ba73d41e99c4fff6b6c3276484 \ + --hash=sha256:1e2120ef853f59c7419231f3bf4e7021f1b936f6ebd222406c3b60212205d2ee \ + --hash=sha256:1f71ea527786de97d1a0cc0eacd1defc0985dcf6b3f17bb77dcfc8c34bec4dc5 \ + --hash=sha256:23502f431948090f597378482b4812b0caae32c22213aecf3b55325e049a6c68 \ + --hash=sha256:24471b829b3bf607e04e88d79542a9d48bb037c2267d7927a874e6c205ca7e9a \ + --hash=sha256:29717114e51c84ddfba879543fb232a6ed60086602313ca38cce623c1d62cfbf \ + --hash=sha256:2e99c6826ffa974fe6e27cdb5ed0021786b03fc98e5ee3c5bfe1fd5015f42b99 \ + --hash=sha256:39693e1f8320ae4f43943590b49779ffb98acb81f788220ea932a6b6c51004d8 \ + --hash=sha256:3ad2a3decf9aaba3d29c8f537ac4b243e36bef957511b4766cb0057d32b0be85 \ + --hash=sha256:3b1fdb9dc17f5a7677423d508ab4f243a726dea51fa5e70992e59a7411c89d19 \ + --hash=sha256:41e4e3953a79407c794916fa277a82531dd93aad34e29c2a514c2c0c5fe971cc \ + --hash=sha256:43fa96a3ca0d6b1812e01ced1044a003533c47f6ee8aca31724f78e93ccc089a \ + --hash=sha256:50187695423ffe49e2deacb8cd10510bc361faac997de9efef88badc3bb9e2d1 \ + --hash=sha256:5ac9328ec4831237bec75defaf839f7d4564be1e6b25ac710bd1a96321cc8317 \ + --hash=sha256:5d225db5a45f21e78dd9358e58a98702a0302f2659a3c6cd320564b75b86f47c \ + --hash=sha256:6395c297d42274772abc367baaa79683958044e5d3835486c16da75d2a694631 \ + --hash=sha256:688ba32a1cffef67fd2e9398a2efebaea461578b0923624778664cc1c914db5d \ + --hash=sha256:68ccc6023a3400877818152ad9a1033e3db8625d899c72eacb5a668902e4d652 \ + --hash=sha256:70b189594dbe54f75ab3a1acec5f1e3faa7e8cf2f1e08d9b561cb41b845f69d5 \ + --hash=sha256:797b4f722ffa07cc8d62053e4cff1486fa6dc094105d13fea7b1de7d8bf71c9e \ + --hash=sha256:7c36280e6fb8385e520936c3cb3b8042851904eba0e58d277dca80a5cfed590b \ + --hash=sha256:7e7401d0de89a9a855c839bc697c079a4af81cf878373abd7dc625847d25cbd8 \ + --hash=sha256:80bab7bfc629882493af4aa31a4cfa43a4c57c83813253626916b8c7ada83476 \ + --hash=sha256:82d09873e40955485746739bcb8b4586983670466c23382c19cffecbf1fd8706 \ + --hash=sha256:8388ee1976c416731879ac16da0aff3f63b286ffdd57cdeb95f3f2e085687563 \ + --hash=sha256:8824b5a04a04a047e72eea5cec3bc266db09e35de6bdfe34c9436ac5ee27d237 \ + --hash=sha256:8b9c7197f7cb2738065c481a0461e50ad02f18c78cd75775628afb4d7137fb3b \ + --hash=sha256:9056c1ecd25795207ad294bcf39f2db3d845767be0ea6e6a34d856f006006083 \ + --hash=sha256:936d68689298c36b53b29f23c6dbb74de12b4ac12ca6cfe0e047bedceea56180 \ + --hash=sha256:9b22676e8097e9e22e36d6b7bda33190d0d400f345f23d4065d48f4ca7ae0425 \ + --hash=sha256:a4d3091415f010369ae4ed1fc6b79def9416358877534caf6a0fdd2146c87a3e \ + --hash=sha256:a8786accb172bd8afb8be14490a16625cbc387036876ab6ba70912730faf8e1f \ + --hash=sha256:a9f8c2e67970f13b16084e04f134610fd1d374bf477b17ec1599185cf611d725 \ + --hash=sha256:bc2fa7c6b47d6bc618dd7fb02ef6fdedb1090ec036abab80d4681424b84c1183 \ + --hash=sha256:c70c95198c015b85feafc136515252a261a84561b7b1d51e3384e0655ddf25ab \ + --hash=sha256:cc1c1159b3d456576af7a3e4d1ba7e6924cb39de8f67111c735f6fc832082774 \ + --hash=sha256:ce826d6ef20b1bc864f0a68340c8b3287705cae2f8b4b1d932177dcc76721725 \ + --hash=sha256:d584d9ec91ad65861cc08d42e834324ef890a082e591037abe114850ff7bbc3e \ + --hash=sha256:d7fded462629cfa4b685c5416b949ebad6cec74af5e2d42905d41e257e0869f5 \ + --hash=sha256:d84a1718ee396f54f3a086ea0a66d8e552b2ab2017ef8b420e92edbc841c352d \ + --hash=sha256:d8e03406cac8513435335dbab54c0d385e4a49e4945d2909a581c83647ca0290 \ + --hash=sha256:e10ce637b18caea04431ce14fabcf5c64a1c61ec9c56b071a4b7ca131ca52d44 \ + --hash=sha256:ec031d5d2feb36d1d1a24380e4db6d43695f3748343d99434e6f5f9156aaa2ed \ + --hash=sha256:ef6107725bd54b262d6dedcc2af448a266975032bc85ef0172c5f059da6325b4 \ + --hash=sha256:efdca5630322a10774e8e98e1af481aad470dd62c3170801852d752aa7a783ba \ + --hash=sha256:f753120cb8181e736c57ef7636e83f31b9c0d1722c516f7e86cf15b7aa57ff12 \ + --hash=sha256:ff3824dc5261f50c9b0dfb3be22b4567a6f938ccce4587b38952d85fd9e9afe4 + # via -r tools/build/bazel/requirements.in requests==2.32.2 \ --hash=sha256:dd951ff5ecf3e3b3aa26b40703ba77495dab41da839ae72ef3c8e5d8e2433289 \ --hash=sha256:fc06670dd0ed212426dfeb94fc1b983d917c4f9847c863f313c9dfaaffb7c23c - # via -r requirements.in + # via -r tools/build/bazel/requirements.in +setuptools==78.1.1 \ + --hash=sha256:c3a9c4211ff4c309edb8b8c4f1cbfa7ae324c4ba9f91ff254e3d305b9fd54561 \ + --hash=sha256:fcc17fd9cd898242f6b4adfaca46137a9edef687f43e6f78469692a5e70d851d + # via -r tools/build/bazel/requirements.in sympy==1.12 \ --hash=sha256:c3588cd4295d0c0f603d0f2ae780587e64e2efeedb3521e46b9bb1d08d184fa5 \ --hash=sha256:ebf595c8dac3e0fdc4152c51878b498396ec7f30e7a914d6071e674d49420fb8 - # via -r requirements.in + # via -r tools/build/bazel/requirements.in typing-extensions==4.11.0 \ --hash=sha256:83f085bd5ca59c80295fc2a82ab5dac679cbe02b9f33f7d83af68e241bea51b0 \ --hash=sha256:c1f94d72897edaf4ce775bb7558d5b79d8126906a14ea5ed1635921406c0387a - # via -r requirements.in + # via -r tools/build/bazel/requirements.in urllib3==2.5.0 \ --hash=sha256:3fc47733c7e419d4bc3f6b3dc2b4f890bb743906a30d56ba4a5bfa4bbff92760 \ --hash=sha256:e6b01673c0fa6a13e374b50871808eb3bf7046c4b125b216f6bf1cc604cff0dc # via requests - -# The following packages are considered to be unsafe in a requirements file: -setuptools==78.1.1 \ - --hash=sha256:c3a9c4211ff4c309edb8b8c4f1cbfa7ae324c4ba9f91ff254e3d305b9fd54561 \ - --hash=sha256:fcc17fd9cd898242f6b4adfaca46137a9edef687f43e6f78469692a5e70d851d - # via -r requirements.in diff --git a/tools/setup_helpers/build.bzl b/tools/setup_helpers/build.bzl index c5be13e4603b4..5210b6d485552 100644 --- a/tools/setup_helpers/build.bzl +++ b/tools/setup_helpers/build.bzl @@ -4,7 +4,7 @@ def define_targets(rules): srcs = ["generate_code.py"], visibility = ["//:__pkg__"], deps = [ - rules.requirement("PyYAML"), + rules.requirement("pyyaml"), "//tools/autograd", "//torchgen", ], diff --git a/torchgen/build.bzl b/torchgen/build.bzl index 50765869f8d5d..0adcf24e1a4c2 100644 --- a/torchgen/build.bzl +++ b/torchgen/build.bzl @@ -4,7 +4,7 @@ def define_targets(rules): srcs = rules.glob(["**/*.py"]), visibility = ["//visibility:public"], deps = [ - rules.requirement("PyYAML"), + rules.requirement("pyyaml"), rules.requirement("typing-extensions"), ], ) @@ -14,7 +14,7 @@ def define_targets(rules): srcs = [":torchgen"], visibility = ["//visibility:public"], deps = [ - rules.requirement("PyYAML"), + rules.requirement("pyyaml"), rules.requirement("typing-extensions"), ], ) From bbc32d680fdd6c23ee0e57d18f5643edd0750a3f Mon Sep 17 00:00:00 2001 From: codingwithsurya Date: Sun, 20 Jul 2025 16:39:36 -0700 Subject: [PATCH 311/457] [SymmMem] Add NVSHMEM sync_all support into Triton (#158512) Adds `sync_all()` function for local store visibility synchronization in NVSHMEM Triton kernels. Provides memory ordering for local operations without remote completion guarantees. Tests: `python test/distributed/test_nvshmem_triton.py -k test_triton_sync` Pull Request resolved: https://github.com/pytorch/pytorch/pull/158512 Approved by: https://github.com/fduwjj ghstack dependencies: #158511 --- test/distributed/test_nvshmem_triton.py | 66 +++++++++++++++++++ .../_symmetric_memory/_nvshmem_triton.py | 11 ++++ 2 files changed, 77 insertions(+) diff --git a/test/distributed/test_nvshmem_triton.py b/test/distributed/test_nvshmem_triton.py index 996ad25db5d7e..6c7f38686a4c6 100644 --- a/test/distributed/test_nvshmem_triton.py +++ b/test/distributed/test_nvshmem_triton.py @@ -181,6 +181,35 @@ def barrier_test_kernel( tl.store(p_dst, received + 1) +@triton.jit +def sync_test_kernel( + dst_ptr, + src_ptr, + numel, +): + my_pe = nvshmem.my_pe() + n_pes = nvshmem.n_pes() + + # Rank 0 broadcasts its value to all other ranks + if my_pe == 0: + # Write initial value + p_src = src_ptr.to(tl.pointer_type(tl.int32)) + tl.store(p_src, 42) + # Put to all other ranks + i = 1 + while i < n_pes: + nvshmem.putmem_block(dst_ptr, src_ptr, numel, i) + i += 1 + # Synchronize all PEs (this is more lightweight than barrier_all() b/c it only ensures local store visibility + # and doesn't wait for remote ops to complete) + nvshmem.sync_all() + # Non-zero ranks increment the received value + if my_pe != 0: + p_dst = dst_ptr.to(tl.pointer_type(tl.int32)) + received = tl.load(p_dst) + tl.store(p_dst, received + 1) + + @instantiate_parametrized_tests @requires_nvshmem() class NVSHMEMTritonTest(MultiProcContinousTest): @@ -765,6 +794,43 @@ def test_triton_barrier(self) -> None: dst, torch.tensor([43], device=self.device, dtype=dtype) ) + @skipIfRocm + @requires_triton() + def test_triton_sync(self) -> None: + torch.manual_seed(42 + self.rank) + self._init_device() + nvshmem_lib = nvshmem.enable_triton() + group_name = dist.group.WORLD.group_name + symm_mem.enable_symm_mem_for_group(group_name) + rank = self.rank + numel = 1 + dtype = torch.int32 + # Create symmetric buffers + src = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(0) + dst = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(0) + src_hdl = symm_mem.rendezvous(src, group=group_name) + dst_hdl = symm_mem.rendezvous(dst, group=group_name) + # Launch kernel with cooperative grid + sync_test_kernel[(1,)]( + dst_hdl.buffer_ptrs[rank], + src_hdl.buffer_ptrs[rank], + numel=numel, + extern_libs=nvshmem_lib, + launch_cooperative_grid=True, + num_ctas=1, + ) + # Verify results + if rank == 0: + # Rank 0 should have its original value (42) in src + torch.testing.assert_close( + src, torch.tensor([42], device=self.device, dtype=dtype) + ) + else: + # Other ranks should have received 42 and incremented to 43 + torch.testing.assert_close( + dst, torch.tensor([43], device=self.device, dtype=dtype) + ) + if __name__ == "__main__": run_tests() diff --git a/torch/distributed/_symmetric_memory/_nvshmem_triton.py b/torch/distributed/_symmetric_memory/_nvshmem_triton.py index 3cd8c99d34f36..ed195a4225f13 100644 --- a/torch/distributed/_symmetric_memory/_nvshmem_triton.py +++ b/torch/distributed/_symmetric_memory/_nvshmem_triton.py @@ -230,3 +230,14 @@ def barrier_all(_builder=None): # type: ignore[no-untyped-def] is_pure=False, _builder=_builder, ) + + @core.extern + def sync_all(_builder=None): # type: ignore[no-untyped-def] + return core.extern_elementwise( + "", + "", + [], + {(): ("nvshmem_sync_all", core.dtype("int32"))}, + is_pure=False, + _builder=_builder, + ) From 1eb6b2089fbc1a01e38448222ef0e6daa7504924 Mon Sep 17 00:00:00 2001 From: "Sun, Jiayi" Date: Mon, 21 Jul 2025 03:43:25 +0000 Subject: [PATCH 312/457] [Inductor] Set the default value of min_chunk_size to 512 (#150762) Change the default value of min_chunk_size from 4096 to 512 to allow more for loops to be parallelized. I tested the Inductor benchmark with this PR on CPU, and saw ~10% improvement in torchbench geomean speedup, and no change in huggingface/timm_models. There are about 15 torchbench models with different degrees of performance improvement, among which functorch_dp_cifar10, opacus_cifar10, hf_Reformer, and pyhpc_turbulent_kinetic_energy have more than 50% performance improvement. Pull Request resolved: https://github.com/pytorch/pytorch/pull/150762 Approved by: https://github.com/leslie-fang-intel, https://github.com/jansel --- .../codegen/cpp_flex_attention_template.py | 12 ++++++++-- .../_inductor/codegen/cpp_template_kernel.py | 23 ++++++++++++++++--- torch/_inductor/config.py | 2 +- 3 files changed, 31 insertions(+), 6 deletions(-) diff --git a/torch/_inductor/codegen/cpp_flex_attention_template.py b/torch/_inductor/codegen/cpp_flex_attention_template.py index 64e11b00fcc04..80fd3014a643c 100644 --- a/torch/_inductor/codegen/cpp_flex_attention_template.py +++ b/torch/_inductor/codegen/cpp_flex_attention_template.py @@ -814,7 +814,7 @@ def modification(self, subgraph_buffer, output_name, output_idx): from ..loop_body import LoopBody from ..utils import sympy_index_symbol_with_prefix, SymT from ..virtualized import V - from .cpp import CppKernelProxy, KernelGroup + from .cpp import CppKernelProxy, KernelGroup, ParallelDepth kernel_group = KernelGroup() kernel_input_args = { @@ -883,7 +883,15 @@ def fn(*args): var_sizes_list.append((var_sizes, ())) cpp_kernel_proxy.codegen_loop_bodies(bodies, var_sizes_list) - kernel_group.finalize_kernel(cpp_kernel_proxy, []) + + def max_parallel_depth(): + return ParallelDepth(parallel_depth=0, start_depth=0) + + # This loop is not parallelized since it is not the outermost loop. + with patch.object( + cpp_kernel_proxy.loop_nest, "max_parallel_depth", max_parallel_depth + ): + kernel_group.finalize_kernel(cpp_kernel_proxy, []) output_code = kernel_group.loops_code.getvalue() var_q_symbol, var_kv_symbol = self.block_vars diff --git a/torch/_inductor/codegen/cpp_template_kernel.py b/torch/_inductor/codegen/cpp_template_kernel.py index b7a830a501051..184c0fe889af9 100644 --- a/torch/_inductor/codegen/cpp_template_kernel.py +++ b/torch/_inductor/codegen/cpp_template_kernel.py @@ -2,6 +2,7 @@ import itertools from collections.abc import Iterable from typing import Any, Callable, Optional, Union +from unittest.mock import patch import sympy from sympy.parsing.sympy_parser import parse_expr @@ -18,7 +19,7 @@ from ..utils import sympy_index_symbol, sympy_index_symbol_with_prefix from ..virtualized import V from .common import REMOVED -from .cpp import CppKernel, CppKernelProxy, KernelGroup +from .cpp import CppKernel, CppKernelProxy, KernelGroup, ParallelDepth from .cpp_utils import cexpr_index, DTYPE_TO_CPP, LocalBufferContext @@ -288,7 +289,15 @@ def fn(*args): var_sizes_list.append(var_sizes) cpp_kernel_proxy.codegen_loop_bodies(bodies, var_sizes_list) - kernel_group.finalize_kernel(cpp_kernel_proxy, []) + + def max_parallel_depth(): + return ParallelDepth(parallel_depth=0, start_depth=0) + + # This loop is not parallelized since it is not the outermost loop. + with patch.object( + cpp_kernel_proxy.loop_nest, "max_parallel_depth", max_parallel_depth + ): + kernel_group.finalize_kernel(cpp_kernel_proxy, []) return kernel_group.loops_code.getvalue() def store_grouped_gemm_pointwise_nodes( @@ -342,7 +351,15 @@ def fn(*args): var_sizes_list.append(var_sizes) cpp_kernel_proxy.codegen_loop_bodies(bodies, var_sizes_list) - kernel_group.finalize_kernel(cpp_kernel_proxy, []) + + def max_parallel_depth(): + return ParallelDepth(parallel_depth=0, start_depth=0) + + # This loop is not parallelized since it is not the outermost loop. + with patch.object( + cpp_kernel_proxy.loop_nest, "max_parallel_depth", max_parallel_depth + ): + kernel_group.finalize_kernel(cpp_kernel_proxy, []) return kernel_group.loops_code.getvalue() def store_output( diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index ef7de961149e5..25bd81ea7f8af 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -1024,7 +1024,7 @@ class cpp: dynamic_threads = os.environ.get("TORCHINDUCTOR_CPP_DYNAMIC_THREADS", "0") == "1" simdlen: Optional[int] = None - min_chunk_size = int(os.environ.get("TORCHINDUCTOR_CPP_MIN_CHUNK_SIZE", "4096")) + min_chunk_size = int(os.environ.get("TORCHINDUCTOR_CPP_MIN_CHUNK_SIZE", "512")) cxx: tuple[Literal[None], str] = ( None, # download gcc12 from conda-forge if conda is installed From 979fae761c544b207762b9f679894a2a90bd30d5 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Sun, 20 Jul 2025 21:27:45 -0700 Subject: [PATCH 313/457] Rename modules in AOTAutograd (#158449) Fixes https://github.com/pytorch/pytorch/issues/158382 ``` renamed: torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py -> torch/_functorch/_aot_autograd/graph_capture.py renamed: torch/_functorch/_aot_autograd/traced_function_transforms.py -> torch/_functorch/_aot_autograd/graph_capture_wrappers.py renamed: torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py -> torch/_functorch/_aot_autograd/graph_compile.py ``` Everything else is ONLY import changes. I did not rename any functions even if we probably should have. Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/158449 Approved by: https://github.com/jamesjwu --- test/dynamo/test_aot_autograd.py | 12 ++++---- test/functorch/test_aotdispatch.py | 2 +- test/inductor/test_auto_functionalize.py | 2 +- test/test_dynamic_shapes.py | 8 +++--- ..._and_compile_graph.py => graph_capture.py} | 4 +-- ...ransforms.py => graph_capture_wrappers.py} | 0 ...e_runtime_wrappers.py => graph_compile.py} | 7 ++--- .../_aot_autograd/runtime_wrappers.py | 2 +- torch/_functorch/aot_autograd.py | 28 +++++++++---------- torch/export/_trace.py | 4 +-- 10 files changed, 32 insertions(+), 37 deletions(-) rename torch/_functorch/_aot_autograd/{dispatch_and_compile_graph.py => graph_capture.py} (99%) rename torch/_functorch/_aot_autograd/{traced_function_transforms.py => graph_capture_wrappers.py} (100%) rename torch/_functorch/_aot_autograd/{jit_compile_runtime_wrappers.py => graph_compile.py} (99%) diff --git a/test/dynamo/test_aot_autograd.py b/test/dynamo/test_aot_autograd.py index af162b41ccd76..0de83cd2dc317 100644 --- a/test/dynamo/test_aot_autograd.py +++ b/test/dynamo/test_aot_autograd.py @@ -1213,7 +1213,7 @@ def fn(x): @torch._functorch.config.patch(donated_buffer=True) def test_donated_buffer1(self): - logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers" + logger_name = "torch._functorch._aot_autograd.graph_compile" @torch.compile() def relu(x): @@ -1233,7 +1233,7 @@ def relu(x): @torch._functorch.config.patch("donated_buffer", True) def test_donated_buffer2(self): - logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers" + logger_name = "torch._functorch._aot_autograd.graph_compile" # we will reuse the graph for g across f1 and f2 @torch.compile() @@ -1255,7 +1255,7 @@ def f(inp, param1, param2): @torch._functorch.config.patch("donated_buffer", True) def test_donated_buffer3(self): - logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers" + logger_name = "torch._functorch._aot_autograd.graph_compile" # we will reuse the graph for g across f1 and f2 @torch.compile() @@ -1278,7 +1278,7 @@ def f(inp, param1, param2): @torch._functorch.config.patch("donated_buffer", True) def test_donated_buffer4(self): - logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers" + logger_name = "torch._functorch._aot_autograd.graph_compile" class Mod(torch.nn.Module): def __init__(self) -> None: @@ -1309,7 +1309,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: @torch._functorch.config.patch("donated_buffer", True) def test_donated_buffer5(self): - logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers" + logger_name = "torch._functorch._aot_autograd.graph_compile" @torch.compile() def f(x, z): @@ -1346,7 +1346,7 @@ def test_donated_buffer6(self): # SymNodeVariable() is not a constant return - logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers" + logger_name = "torch._functorch._aot_autograd.graph_compile" def fn(x): p = torch.nn.Parameter(x + 123) diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index 698bab89935fb..869fc6964f2fd 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -7469,7 +7469,7 @@ def test_saved_tensors_hooks_donated_buffers(self): "pack_hash", "unpack_hash", ) - logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers" + logger_name = "torch._functorch._aot_autograd.graph_compile" class SAF(torch.autograd.Function): @staticmethod diff --git a/test/inductor/test_auto_functionalize.py b/test/inductor/test_auto_functionalize.py index da146acd6368f..c91dde52780ac 100644 --- a/test/inductor/test_auto_functionalize.py +++ b/test/inductor/test_auto_functionalize.py @@ -433,7 +433,7 @@ def run_aot_eager(self, f, orig_args, _dynamic=False): aot_eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) log_stream, ctx = logs_to_string( - "torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs" + "torch._functorch._aot_autograd.graph_capture", "aot_graphs" ) result = None diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index 0f299cd6b6c79..af16a8f325fc3 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -3211,7 +3211,7 @@ def make_non_contiguous_tensor_and_test(cnt): self.assertEqual(compiled_result, eager_result) log_stream, ctx = logs_to_string( - "torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs" + "torch._functorch._aot_autograd.graph_capture", "aot_graphs" ) with ctx(): make_non_contiguous_tensor_and_test(4) @@ -3246,7 +3246,7 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "Sym(s7)", torch._dynamo.decorators.mark_unbacked(x, 0) log_stream, ctx = logs_to_string( - "torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs" + "torch._functorch._aot_autograd.graph_capture", "aot_graphs" ) with ctx(): compiled_result = compiled_func(x, torch.tensor([10])) @@ -3305,7 +3305,7 @@ def func(x, y): torch._dynamo.decorators.mark_unbacked(x, 1) log_stream, ctx = logs_to_string( - "torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs" + "torch._functorch._aot_autograd.graph_capture", "aot_graphs" ) with ctx(): result_eager = func(x, torch.tensor([5, 20])) @@ -3355,7 +3355,7 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)", # Pass a contiguous tensor. A recompilation will happen due to 0/1 speciialization on stride. log_stream, ctx = logs_to_string( - "torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs" + "torch._functorch._aot_autograd.graph_capture", "aot_graphs" ) with ctx(): # This used to hit could guard on data-dependent expression Eq(10, u3) x.stride[0]==10. and x.size()=[u2, u3]. diff --git a/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py b/torch/_functorch/_aot_autograd/graph_capture.py similarity index 99% rename from torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py rename to torch/_functorch/_aot_autograd/graph_capture.py index be3226ca01f57..f4710bc8000ce 100644 --- a/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py +++ b/torch/_functorch/_aot_autograd/graph_capture.py @@ -23,8 +23,7 @@ assert_functional_graph, propagate_input_mutation_stacktraces, ) -from .schemas import AOTConfig, SubclassMeta, ViewAndMutationMeta -from .traced_function_transforms import ( +from .graph_capture_wrappers import ( aot_dispatch_subclass, create_functionalized_fn, create_joint, @@ -32,6 +31,7 @@ fn_prepped_for_autograd, handle_effect_tokens_fn, ) +from .schemas import AOTConfig, SubclassMeta, ViewAndMutationMeta from .utils import ( copy_fwd_metadata_to_bw_nodes, register_buffer_assignment_hook, diff --git a/torch/_functorch/_aot_autograd/traced_function_transforms.py b/torch/_functorch/_aot_autograd/graph_capture_wrappers.py similarity index 100% rename from torch/_functorch/_aot_autograd/traced_function_transforms.py rename to torch/_functorch/_aot_autograd/graph_capture_wrappers.py diff --git a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py b/torch/_functorch/_aot_autograd/graph_compile.py similarity index 99% rename from torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py rename to torch/_functorch/_aot_autograd/graph_compile.py index 73d6ab1c19596..c197dcfbd8ed7 100644 --- a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/graph_compile.py @@ -51,10 +51,7 @@ should_bundle_autograd_cache, should_use_remote_autograd_cache, ) -from .dispatch_and_compile_graph import ( - aot_dispatch_autograd_graph, - aot_dispatch_base_graph, -) +from .graph_capture import aot_dispatch_autograd_graph, aot_dispatch_base_graph from .logging_utils import track_graph_compiling from .runtime_wrappers import ( AOTDedupeWrapper, @@ -896,7 +893,7 @@ def _wrapper(*args): def prepare_hook_gm(aot_config, fn, args): - from torch._functorch._aot_autograd.dispatch_and_compile_graph import _create_graph + from torch._functorch._aot_autograd.graph_capture import _create_graph fn, args = create_wrap_fn(fn, args) gm = _create_graph(fn, args, aot_config=aot_config) diff --git a/torch/_functorch/_aot_autograd/runtime_wrappers.py b/torch/_functorch/_aot_autograd/runtime_wrappers.py index 46bd0ad774793..fbdc81670a329 100644 --- a/torch/_functorch/_aot_autograd/runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/runtime_wrappers.py @@ -41,6 +41,7 @@ from .. import config from .collect_metadata_analysis import run_functionalized_fw_and_collect_metadata from .functional_utils import gen_alias_from_base +from .graph_capture_wrappers import aot_dispatch_subclass from .input_output_analysis import ( compute_overlapping_inputs, create_synthetic_base_metadata, @@ -65,7 +66,6 @@ runtime_unwrap_tensor_subclasses, wrap_tensor_subclasses, ) -from .traced_function_transforms import aot_dispatch_subclass from .utils import ( call_func_at_runtime_with_args, make_boxed_func, diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py index 824fa1e0c25c8..495193c89f61c 100644 --- a/torch/_functorch/aot_autograd.py +++ b/torch/_functorch/aot_autograd.py @@ -62,17 +62,26 @@ sync_functional_tensor, to_fun, ) +from ._aot_autograd.graph_capture_wrappers import ( # noqa: F401 + aot_dispatch_subclass, + create_functional_call, + create_functionalized_fn, + create_functionalized_rng_ops_wrapper, + create_joint, + fn_input_mutations_to_outputs, + fn_prepped_for_autograd, +) +from ._aot_autograd.graph_compile import ( # noqa: F401 + aot_stage1_graph_capture, + aot_stage2_compile, + aot_stage2_export, +) from ._aot_autograd.input_output_analysis import ( # noqa: F401 compute_overlapping_inputs, create_graph_signature, create_synthetic_base_metadata, remove_dupe_metadata, ) -from ._aot_autograd.jit_compile_runtime_wrappers import ( # noqa: F401 - aot_stage1_graph_capture, - aot_stage2_compile, - aot_stage2_export, -) from ._aot_autograd.logging_utils import ( # noqa: F401 callback_set, describe_input, @@ -118,15 +127,6 @@ wrap_tensor_subclasses, wrap_tensor_subclasses_maybe_joint, ) -from ._aot_autograd.traced_function_transforms import ( # noqa: F401 - aot_dispatch_subclass, - create_functional_call, - create_functionalized_fn, - create_functionalized_rng_ops_wrapper, - create_joint, - fn_input_mutations_to_outputs, - fn_prepped_for_autograd, -) from ._aot_autograd.utils import ( # noqa: F401 _get_autocast_states, _get_symint_hints, diff --git a/torch/export/_trace.py b/torch/export/_trace.py index 09163e3bffa8a..4183fe22cda85 100644 --- a/torch/export/_trace.py +++ b/torch/export/_trace.py @@ -49,15 +49,13 @@ ) from torch._export.verifier import SpecViolationError from torch._export.wrappers import _wrap_submodules +from torch._functorch._aot_autograd.graph_capture_wrappers import create_functional_call from torch._functorch._aot_autograd.input_output_analysis import ( _graph_input_names, _graph_output_names, ) from torch._functorch._aot_autograd.schemas import GraphSignature from torch._functorch._aot_autograd.subclass_utils import get_subclass_typing_container -from torch._functorch._aot_autograd.traced_function_transforms import ( - create_functional_call, -) from torch._functorch._aot_autograd.utils import ( create_tree_flattened_fn, register_buffer_assignment_hook, From d5a29fc58a0f974871841075072164f852c61b65 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Sun, 20 Jul 2025 21:27:45 -0700 Subject: [PATCH 314/457] De-abstract premature generalization with InductorWrapper (#158528) See docblock on InductorWrapper for the distinction. This will matter on a later refactor PR where I will change the signature for one of these but not the other. Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/158528 Approved by: https://github.com/jamesjwu ghstack dependencies: #158449 --- .../_functorch/_aot_autograd/graph_compile.py | 24 +----- .../_aot_autograd/runtime_wrappers.py | 13 ++- torch/_functorch/_aot_autograd/schemas.py | 80 +++++++++++++++++-- 3 files changed, 85 insertions(+), 32 deletions(-) diff --git a/torch/_functorch/_aot_autograd/graph_compile.py b/torch/_functorch/_aot_autograd/graph_compile.py index c197dcfbd8ed7..cc64c82c2920c 100644 --- a/torch/_functorch/_aot_autograd/graph_compile.py +++ b/torch/_functorch/_aot_autograd/graph_compile.py @@ -236,19 +236,11 @@ def aot_stage2_inference( ) fakified_out_wrapper = FakifiedOutWrapper() - ( - fw_module, - updated_flat_args, - fw_metadata, - ) = fakified_out_wrapper.pre_compile( + fakified_out_wrapper.pre_compile( fw_module, updated_flat_args, aot_config, fw_metadata=fw_metadata ) functionalized_rng_wrapper = FunctionalizedRngRuntimeWrapper() - ( - fw_module, - updated_flat_args, - fw_metadata, - ) = functionalized_rng_wrapper.pre_compile( + functionalized_rng_wrapper.pre_compile( fw_module, updated_flat_args, aot_config, fw_metadata=fw_metadata ) assert isinstance(fw_module, GraphModule) @@ -1612,11 +1604,7 @@ def aot_stage2_autograd( adjusted_flat_args = joint_inputs[0] fakified_out_wrapper = FakifiedOutWrapper() - ( - fw_module, - adjusted_flat_args, - fw_metadata, - ) = fakified_out_wrapper.pre_compile( + fakified_out_wrapper.pre_compile( fw_module, adjusted_flat_args, aot_config, fw_metadata=fw_metadata ) @@ -1633,11 +1621,7 @@ def aot_stage2_autograd( ] adjusted_flat_args.extend(rng_states) # type: ignore[arg-type] - ( - fw_module, - adjusted_flat_args, - fw_metadata, - ) = functionalized_rng_wrapper.pre_compile( + functionalized_rng_wrapper.pre_compile( fw_module, adjusted_flat_args, aot_config, fw_metadata=fw_metadata ) if tracing_context := torch._guards.TracingContext.try_get(): diff --git a/torch/_functorch/_aot_autograd/runtime_wrappers.py b/torch/_functorch/_aot_autograd/runtime_wrappers.py index fbdc81670a329..805bb5d79c8ad 100644 --- a/torch/_functorch/_aot_autograd/runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/runtime_wrappers.py @@ -51,6 +51,7 @@ from .schemas import ( AOTConfig, CompilerWrapper, + InductorWrapper, InputAliasInfo, MemoryFormatMeta, MutationType, @@ -458,7 +459,7 @@ def _runtime_wrapper(*args, **kwargs): @dataclass -class FunctionalizedRngRuntimeWrapper(CompilerWrapper): +class FunctionalizedRngRuntimeWrapper(InductorWrapper): # TODO: I would love to get rid of this argument, but it's # Wrapped pretty tightly around our aot_dispatch_autograd logic. # Specifically, tensors_saved_for_backwards_slice's value is both used for calculating indices @@ -470,12 +471,12 @@ class FunctionalizedRngRuntimeWrapper(CompilerWrapper): def pre_compile( self, - flat_fn, + flat_fn: torch.fx.GraphModule, flat_args, aot_config, *, fw_metadata, - ) -> tuple[Callable, list[Tensor], ViewAndMutationMeta]: + ) -> None: if config.functionalize_rng_ops: # Update example inputs for the fw_compiler fake_mode = detect_fake_mode() @@ -484,7 +485,6 @@ def pre_compile( # We are not clearing flat_args here because # 1) There is a check in the debug compiler at the end # 2) It does not matter as these are fake tensors - return flat_fn, flat_args, fw_metadata def post_compile( self, @@ -533,7 +533,7 @@ def _functionalized_rng_runtime_epilogue( @dataclass -class FakifiedOutWrapper(CompilerWrapper): +class FakifiedOutWrapper(InductorWrapper): out_metas: list[torch.Tensor] = field(default_factory=list) # TracingContext.fwd_output_strides # Generated from actually doing compile @@ -548,7 +548,7 @@ def pre_compile( aot_config, *, fw_metadata, - ) -> tuple[Callable, list[Tensor], ViewAndMutationMeta]: + ) -> None: tracing_context = torch._guards.TracingContext.try_get() if tracing_context and tracing_context.fakify_first_call: self.out_metas = [ @@ -556,7 +556,6 @@ def pre_compile( ] else: self.needs_post_compile = False - return fw_module, flat_args, fw_metadata def _compute_output_meta_with_inductor_strides(self): out = self.out_metas diff --git a/torch/_functorch/_aot_autograd/schemas.py b/torch/_functorch/_aot_autograd/schemas.py index f8b60d4f7060c..efb16234c20cc 100644 --- a/torch/_functorch/_aot_autograd/schemas.py +++ b/torch/_functorch/_aot_autograd/schemas.py @@ -1042,13 +1042,28 @@ class AOTState: class CompilerWrapper: """ - A wrapper around the inputs and outputs to the compiler_fn. We separate these into two parts: + AOTAutograd needs to do many transformations to the calling convention of the user function + it is tracing, e.g., deduplicating inputs, unpacking subclasses, etc. CompilerWrapper lets + us factor these into compositional stages so we can handle each transformation incrementally + instead of having to do it all at once. - 1. The prologue, which edits the input to the compiler_fn(flat_fn, flat_args, etc) - 2. The epilogue, which edits the outputs of the compiler_fn (compiled_fn, real arguments) + Since there is a calling convention change, there are two parts to the wrpaper: + + 1. The prologue, which is about compile-time behavior: given this original function, what + is the new function with modified calling convention that we should trace with AOTAutograd + to get the FX graph we will do joint passes, partitioning and ultimate Inductor compilation on? + We get (flat_fn, flat_args), the original function under trace and inputs we were + going to feed it, and produce a new function and new inputs to feed it. + + 2. The epilogue, which is about run-time behavior: we have now compiled the modified calling + convention function, we need to wrap it so that we have a new function that has the + original calling convention of the original function, so that our users can call it + at the old signature they expected. We get (compiled_fn, real arguments), the newly + compiled function we need to wrap. + + Note about caching: we do NOT directly serialize the runtime wrappers; instead, they + are reapplied to compiled_fn after we have finished deserializing the compiled_fn. - Each wrapper below should be implemented as a CompilerWrapper, so that we can facilitate - caching on the compiled output, and re-wrapping the output via epilogues. Extra metadata that is needed to compute pre or post compile can be passed in via attributes. """ @@ -1088,6 +1103,61 @@ def wrapped_compiled_fn(args): return compiled_fn +class InductorWrapper: + """ + This is sort of like CompilerWrapper, but it happens at a different part of the lifecycle: + it talks about transformations we do to the traced and partitioned FX graph before we + send it to the Inductor compiler. + + Once again, there are two parts: + + 1. The prologue, which "modifies" the FX graph before we send it to + Inductor. I say "modifies" because... we don't really actually do + anything nontrivial in either of our two implementations. + 2. The epilogue, which modifies the compiled function produced by Inductor + + Although hypothetically these wrappers could be used compositionally in a centralized + wrappers list, in practice they seem to just be invoked manually when needed. + + NB: The flat_args input is sometimes mutated. This is probably naughty but whatever. + """ + + def pre_compile( + self, + fw_module: torch.fx.GraphModule, + flat_args: list[Tensor], + aot_config: AOTConfig, + *, + fw_metadata: ViewAndMutationMeta, + ) -> None: + """ + Process the inputs to the compiler_fn. You can pass in extra metadata via kwargs. + Args: + flat_fn: The function to compile + flat_args: Metadata from example inputs of the function to compile + aot_config: AOTConfig passed in at compile time + fw_metadata: ViewAndMutationMeta generated from flat_fn and flat_args + """ + return + + def post_compile(self, compiled_fn, aot_config, *, runtime_metadata) -> Callable: + """ + Given an output of the compiler, wrap it with information received from prologue. + Args: + compiled_fn: Callable after calling compiler_fn + aot_config: AOTConfig after calling prologue + runtime_metadata: ViewAndMutationMeta after calling all wrappers's pre_compile steps. + Example: + + def wrapped_compiled_fn(args): + # do something with args, aot_config, fw_metadata + return compiled_fn(args) + + return wrapped_compiled_fn + """ + return compiled_fn + + @dataclass class AOTGraphCapture: # Produced by aot_stage1_graph_capture # AOTAutograd typically operates by taking complicated graphs and From 8e57cdb746b4ab28865fdf01532f87b0d21700e9 Mon Sep 17 00:00:00 2001 From: James Wu Date: Fri, 18 Jul 2025 10:05:09 -0700 Subject: [PATCH 315/457] Still run TritonBundler with BundledAOTAutogradCache, save autotune results (#158048) When running BundledAOTAutogradCache with precompile, we still need to run triton bundling so that the precompiled CompiledFxGraph has triton cuda kernels. We also pre save the autotune results in the precompile artifact. It would be even better to pre trim the cuda kernels on save and apply them, which we can work on later. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158048 Approved by: https://github.com/zhxchen17 --- test/dynamo/test_package.py | 34 +++++++++++++++++++++++ torch/_dynamo/precompile_context.py | 9 ++++-- torch/_inductor/compile_fx.py | 29 ++++++++++++++++++- torch/_inductor/runtime/autotune_cache.py | 10 +++++++ 4 files changed, 79 insertions(+), 3 deletions(-) diff --git a/test/dynamo/test_package.py b/test/dynamo/test_package.py index 3160007774090..51f6ca91136c9 100644 --- a/test/dynamo/test_package.py +++ b/test/dynamo/test_package.py @@ -15,6 +15,7 @@ from torch._dynamo.package import CompilePackage, DiskDynamoStore, DynamoCache from torch._dynamo.precompile_context import PrecompileContext from torch._functorch import config as functorch_config +from torch._inductor.mock_cache import global_stats, PatchCaches, Stats from torch._inductor.runtime.runtime_utils import cache_dir from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, @@ -428,6 +429,39 @@ def fn2(x): self.assertEqual(expected, [result1, result2]) self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames) + @parametrize("device", ("cuda", "xpu")) + @torch._dynamo.config.patch(caching_precompile=True) + def test_automatic_dynamo_autotune_cache(self, device): + if device == "cuda" and not HAS_CUDA: + raise unittest.SkipTest("Requires CUDA/Triton") + if device == "xpu" and not HAS_XPU: + raise unittest.SkipTest("Requires XPU/Triton") + + def fn(x, y): + return x.sin() + y + + arg1 = torch.randn(3, 3, device=device) + arg2 = torch.randn(3, 3, device=device) + expected = fn(arg1, arg2).clone() + + with PatchCaches(): + compiled_fn1 = torch.compile(fn, mode="max-autotune") + result = compiled_fn1(arg1, arg2).clone() + self.assertEqual(expected, result) + self.assertEqual(global_stats.autotune_local, Stats(1, 0, 1)) + DynamoCache.clear() + + total_frames = torch._dynamo.convert_frame.FRAME_COUNTER + self._save_and_reload( + expected_backends=1, expected_dynamo=1, expected_autotune=1 + ) + compiled_fn1 = torch.compile(fn, mode="max-autotune") + with torch.compiler.set_stance("fail_on_recompile"): + result1 = compiled_fn1(arg1, arg2).clone() + self.assertEqual(expected, result1) + self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames) + self.assertEqual(global_stats.autotune_local, Stats(2, 1, 1)) + @parametrize("device", ("cpu", "cuda", "xpu")) @torch._dynamo.config.patch(caching_precompile=True) def test_automatic_dynamo_recompiles(self, device): diff --git a/torch/_dynamo/precompile_context.py b/torch/_dynamo/precompile_context.py index 6bb42bb34bc35..040f54ce70db2 100644 --- a/torch/_dynamo/precompile_context.py +++ b/torch/_dynamo/precompile_context.py @@ -70,7 +70,8 @@ class PrecompileContext(CacheArtifactManager): The following artifact types are supported by PrecompileContext: - BundledAOTAutogradCacheArtifact - - CodeStateArtifact (from torch._dynamo.package once available) + - DynamoCodeStateArtifact + - AutotuneCacheArtifact (regular autotune results, same as Megacache) """ # Protected by the compile_lock @@ -149,8 +150,12 @@ def populate_caches(artifacts: CacheArtifactsResult) -> CacheInfo: artifacts_by_key = {} cache_info = CacheInfo() for artifact in chain(*artifacts.values()): + if artifact.type() == "autotune": + # Populate autotune cache artifacts + artifact.populate_cache() + else: + artifacts_by_key[artifact.key] = artifact cache_info.add(artifact) - artifacts_by_key[artifact.key] = artifact from torch._dynamo.package import _BackendId, DynamoCache diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index c14f3fd7d534f..4d02c353693f7 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -909,10 +909,37 @@ def _compile_fx_inner( else: log.debug("Failed to generate FX cache key") + if torch._functorch.config.bundled_autograd_cache: + assert mb_compiled_graph is None + assert cache_info is None + # When using bundled autograd cache, we still want + # to use the TritonBundler, but we don't want to save + # the results here. The results will get saved directly + # to AOTAutogradCache. + TritonBundler.begin_compile() + try: + mb_compiled_graph = fx_codegen_and_compile( + gm, example_inputs, inputs_to_check, **graph_kwargs + ) + assert mb_compiled_graph is not None + ( + triton_bundle, + triton_bundler_meta, + ) = TritonBundler.collect() + mb_compiled_graph.set_triton_bundle(triton_bundle) + except (ShortenTraceback, SkipFrame): + raise + except Exception as e: + raise InductorError(e, currentframe()).with_traceback( + e.__traceback__ + ) from None + finally: + TritonBundler.end_compile() + # CACHE BYPASS: Compile the graph, don't save it to the cache # (this can happen either because cache was disabled, or we # determined the input is uncacheable) - if cache_info is None or cache_info["cache_state"] == "bypass": + elif cache_info is None or cache_info["cache_state"] == "bypass": assert mb_compiled_graph is None log.debug( "FX cache bypass reason: %s", diff --git a/torch/_inductor/runtime/autotune_cache.py b/torch/_inductor/runtime/autotune_cache.py index 01d038aab8e7b..88b9c80c77146 100644 --- a/torch/_inductor/runtime/autotune_cache.py +++ b/torch/_inductor/runtime/autotune_cache.py @@ -35,6 +35,7 @@ from typing_extensions import override import torch +from torch._dynamo.precompile_context import PrecompileContext from torch._inductor.runtime.runtime_utils import cache_dir from torch.compiler._cache import ( CacheArtifact, @@ -125,6 +126,7 @@ def create( ) -> Optional[AutotuneCache]: cache = AutotuneCache(configs_hash) key = AutotuneCache._prepare_key(filename) + cache._setup_local_cache(inductor_meta, os.path.dirname(filename), key) cache._setup_remote_autotune_cache(inductor_meta, key) if cache.local_cache or cache.remote_cache: @@ -300,6 +302,10 @@ def save( CacheArtifactManager.record_artifact( AutotuneCacheArtifact.type(), autotune_artifact_key, data ) + if torch._dynamo.config.caching_precompile: + PrecompileContext.record_artifact( + AutotuneCacheArtifact.type(), autotune_artifact_key, data + ) if log.isEnabledFor(logging.DEBUG): type_str = "coordesc" if found_by_coordesc else "heuristic" @@ -625,6 +631,10 @@ def _get(self, key: str, sample: Optional[Sample]) -> Optional[JsonDataTy]: CacheArtifactManager.record_artifact( AutotuneCacheArtifact.type(), autotune_artifact_key, result ) + if torch._dynamo.config.caching_precompile: + PrecompileContext.record_artifact( + AutotuneCacheArtifact.type(), autotune_artifact_key, result + ) return result @override From 393377d2156cf4dfb0a7d53c79a85a8b24055ae0 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 21 Jul 2025 13:58:50 +0000 Subject: [PATCH 316/457] Revert "[CI] update flake8 and mypy lint dependencies (#158720)" This reverts commit a527e816935957a164d74dd7c5069310b2857695. Reverted https://github.com/pytorch/pytorch/pull/158720 on behalf of https://github.com/malfet due to This broke lint, see https://hud.pytorch.org/hud/pytorch/pytorch/8e57cdb746b4ab28865fdf01532f87b0d21700e9/1?per_page=50&name_filter=lintrunner-noclang ([comment](https://github.com/pytorch/pytorch/pull/158720#issuecomment-3096893256)) --- .ci/docker/requirements-ci.txt | 6 +- .../actions/filter-test-configs/action.yml | 2 +- .github/requirements-gha-cache.txt | 4 +- .../requirements/pip-requirements-macOS.txt | 2 +- .../workflows/check_mergeability_ghstack.yml | 2 +- .github/workflows/cherry-pick.yml | 2 +- .github/workflows/revert.yml | 2 +- .github/workflows/trymerge.yml | 2 +- .github/workflows/tryrebase.yml | 2 +- .lintrunner.toml | 28 +- tools/build/bazel/requirements.in | 2 +- tools/build/bazel/requirements.txt | 332 +++++++++--------- tools/setup_helpers/build.bzl | 2 +- torchgen/build.bzl | 4 +- 14 files changed, 197 insertions(+), 195 deletions(-) diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt index 4a52ad5e951bc..fb773ff324af8 100644 --- a/.ci/docker/requirements-ci.txt +++ b/.ci/docker/requirements-ci.txt @@ -221,9 +221,9 @@ pygments==2.15.0 #Pinned versions: 2.12.0 #test that import: the doctests -#pyyaml +#PyYAML #Description: data serialization format -#Pinned versions: 6.0.2 +#Pinned versions: #test that import: #requests @@ -233,7 +233,7 @@ pygments==2.15.0 #rich #Description: rich text and beautiful formatting in the terminal -#Pinned versions: 14.0.0 +#Pinned versions: 10.9.0 #test that import: scikit-image==0.19.3 ; python_version < "3.10" diff --git a/.github/actions/filter-test-configs/action.yml b/.github/actions/filter-test-configs/action.yml index 338fc0c2a844c..ca6643f9e2fc1 100644 --- a/.github/actions/filter-test-configs/action.yml +++ b/.github/actions/filter-test-configs/action.yml @@ -70,7 +70,7 @@ runs: set -eux # PyYAML 6.0 doesn't work with MacOS x86 anymore # This must run on Python-3.7 (AmazonLinux2) so can't use request=3.32.2 - python3 -m pip install requests==2.27.1 pyyaml==6.0.2 + python3 -m pip install requests==2.27.1 pyyaml==6.0.1 - name: Parse ref id: parse-ref diff --git a/.github/requirements-gha-cache.txt b/.github/requirements-gha-cache.txt index 8c4a877fdd193..5c691e4bf9b31 100644 --- a/.github/requirements-gha-cache.txt +++ b/.github/requirements-gha-cache.txt @@ -10,6 +10,6 @@ jinja2==3.1.6 lintrunner==0.10.7 ninja==1.10.0.post1 nvidia-ml-py==11.525.84 -pyyaml==6.0.2 +pyyaml==6.0 requests==2.32.4 -rich==14.0.0 +rich==10.9.0 diff --git a/.github/requirements/pip-requirements-macOS.txt b/.github/requirements/pip-requirements-macOS.txt index 0f8276f1dda63..9c72c71523b7d 100644 --- a/.github/requirements/pip-requirements-macOS.txt +++ b/.github/requirements/pip-requirements-macOS.txt @@ -2,7 +2,7 @@ boto3==1.35.42 cmake==3.27.* expecttest==0.3.0 fbscribelogger==0.1.7 -filelock==3.18.0 +filelock==3.6.0 hypothesis==6.56.4 librosa>=0.6.2 mpmath==1.3.0 diff --git a/.github/workflows/check_mergeability_ghstack.yml b/.github/workflows/check_mergeability_ghstack.yml index 569a174665ba8..65193839e9b9d 100644 --- a/.github/workflows/check_mergeability_ghstack.yml +++ b/.github/workflows/check_mergeability_ghstack.yml @@ -56,7 +56,7 @@ jobs: cache: pip architecture: x64 - - run: pip install pyyaml==6.0.2 + - run: pip install pyyaml==6.0 shell: bash - name: Verify mergeability diff --git a/.github/workflows/cherry-pick.yml b/.github/workflows/cherry-pick.yml index 310857782ea14..1d385b556277a 100644 --- a/.github/workflows/cherry-pick.yml +++ b/.github/workflows/cherry-pick.yml @@ -26,7 +26,7 @@ jobs: cache: pip # Not the direct dependencies but the script uses trymerge - - run: pip install pyyaml==6.0.2 + - run: pip install pyyaml==6.0 - name: Setup committer id run: | diff --git a/.github/workflows/revert.yml b/.github/workflows/revert.yml index 226d773e48977..3c8722930e22e 100644 --- a/.github/workflows/revert.yml +++ b/.github/workflows/revert.yml @@ -26,7 +26,7 @@ jobs: architecture: x64 check-latest: false cache: pip - - run: pip install pyyaml==6.0.2 + - run: pip install pyyaml==6.0 - name: Setup committer id run: | diff --git a/.github/workflows/trymerge.yml b/.github/workflows/trymerge.yml index 1fdb1da67a595..19e169bd973b3 100644 --- a/.github/workflows/trymerge.yml +++ b/.github/workflows/trymerge.yml @@ -28,7 +28,7 @@ jobs: check-latest: false cache: pip architecture: x64 - - run: pip install pyyaml==6.0.2 + - run: pip install pyyaml==6.0 - name: Setup committer id run: | diff --git a/.github/workflows/tryrebase.yml b/.github/workflows/tryrebase.yml index 1a8e00e4390be..9af59bcb3662d 100644 --- a/.github/workflows/tryrebase.yml +++ b/.github/workflows/tryrebase.yml @@ -25,7 +25,7 @@ jobs: architecture: x64 check-latest: false cache: pip - - run: pip install pyyaml==6.0.2 + - run: pip install pyyaml==6.0 - name: Setup committer id run: | diff --git a/.lintrunner.toml b/.lintrunner.toml index 6cc1164a785dd..04664378d8bf8 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -39,16 +39,16 @@ init_command = [ 'python3', 'tools/linter/adapters/pip_init.py', '--dry-run={{DRYRUN}}', - 'flake8==7.3.0', - 'flake8-bugbear==24.12.12', - 'flake8-comprehensions==3.16.0', + 'flake8==6.1.0', + 'flake8-bugbear==23.3.23', + 'flake8-comprehensions==3.15.0', 'flake8-executable==2.1.3', - 'flake8-logging-format==2024.24.12', - 'flake8-pyi==25.5.0', - 'flake8-simplify==0.22.0', + 'flake8-logging-format==0.9.0', + 'flake8-pyi==23.3.1', + 'flake8-simplify==0.19.3', 'mccabe==0.7.0', - 'pycodestyle==2.14.0', - 'pyflakes==3.4.0', + 'pycodestyle==2.11.1', + 'pyflakes==3.1.0', 'torchfix==0.4.0 ; python_version >= "3.9" and python_version < "3.13"', ] @@ -158,16 +158,16 @@ init_command = [ 'mypy==1.16.0', 'sympy==1.13.3', 'types-requests==2.27.25', - 'types-pyyaml==6.0.2', + 'types-pyyaml==6.0.1', 'types-tabulate==0.8.8', 'types-protobuf==5.29.1.20250403', 'types-setuptools==79.0.0.20250422', 'types-jinja2==2.11.9', 'types-colorama==0.4.6', - 'filelock==3.18.0', + 'filelock==3.13.1', 'junitparser==2.1.1', - 'rich==14.0.0', - 'pyyaml==6.0.2', + 'rich==10.9.0', + 'pyyaml==6.0.1', 'optree==0.13.0', 'dataclasses-json==0.6.7', 'pandas==2.2.3', @@ -1111,7 +1111,7 @@ init_command = [ 'python3', 'tools/linter/adapters/pip_init.py', '--dry-run={{DRYRUN}}', - 'pyyaml==6.0.2', + 'PyYAML==6.0.1', ] [[linter]] @@ -1133,7 +1133,7 @@ init_command = [ 'python3', 'tools/linter/adapters/pip_init.py', '--dry-run={{DRYRUN}}', - 'pyyaml==6.0.2', + 'PyYAML==6.0.1', ] [[linter]] diff --git a/tools/build/bazel/requirements.in b/tools/build/bazel/requirements.in index 8837501006624..37750163da81e 100644 --- a/tools/build/bazel/requirements.in +++ b/tools/build/bazel/requirements.in @@ -1,4 +1,4 @@ -pyyaml==6.0.2 +PyYAML==6.0.1 numpy==1.26.4 requests==2.32.2 setuptools==78.1.1 diff --git a/tools/build/bazel/requirements.txt b/tools/build/bazel/requirements.txt index a3383b60c1964..a15924660167d 100644 --- a/tools/build/bazel/requirements.txt +++ b/tools/build/bazel/requirements.txt @@ -1,106 +1,108 @@ -# This file was autogenerated by uv via the following command: -# uv pip compile --generate-hashes tools/build/bazel/requirements.in -certifi==2025.7.14 \ - --hash=sha256:6b31f564a415d79ee77df69d757bb49a5bb53bd9f756cbbe24394ffd6fc1f4b2 \ - --hash=sha256:8ea99dbdfaaf2ba2f9bac77b9249ef62ec5218e7c2b2e903378ed5fccf765995 +# +# This file is autogenerated by pip-compile with Python 3.11 +# by the following command: +# +# pip-compile --allow-unsafe --generate-hashes tools/build/bazel/requirements.in +# +certifi==2024.7.4 \ + --hash=sha256:5a1e7645bc0ec61a09e26c36f6106dd4cf40c6db3a1fb6352b0244e7fb057c7b \ + --hash=sha256:c198e21b1289c2ab85ee4e67bb4b4ef3ead0892059901a8d5b622f24a1101e90 # via requests -charset-normalizer==3.4.2 \ - --hash=sha256:005fa3432484527f9732ebd315da8da8001593e2cf46a3d817669f062c3d9ed4 \ - --hash=sha256:046595208aae0120559a67693ecc65dd75d46f7bf687f159127046628178dc45 \ - --hash=sha256:0c29de6a1a95f24b9a1aa7aefd27d2487263f00dfd55a77719b530788f75cff7 \ - --hash=sha256:0c8c57f84ccfc871a48a47321cfa49ae1df56cd1d965a09abe84066f6853b9c0 \ - --hash=sha256:0f5d9ed7f254402c9e7d35d2f5972c9bbea9040e99cd2861bd77dc68263277c7 \ - --hash=sha256:18dd2e350387c87dabe711b86f83c9c78af772c748904d372ade190b5c7c9d4d \ - --hash=sha256:1b1bde144d98e446b056ef98e59c256e9294f6b74d7af6846bf5ffdafd687a7d \ - --hash=sha256:1c95a1e2902a8b722868587c0e1184ad5c55631de5afc0eb96bc4b0d738092c0 \ - --hash=sha256:1cad5f45b3146325bb38d6855642f6fd609c3f7cad4dbaf75549bf3b904d3184 \ - --hash=sha256:21b2899062867b0e1fde9b724f8aecb1af14f2778d69aacd1a5a1853a597a5db \ - --hash=sha256:24498ba8ed6c2e0b56d4acbf83f2d989720a93b41d712ebd4f4979660db4417b \ - --hash=sha256:25a23ea5c7edc53e0f29bae2c44fcb5a1aa10591aae107f2a2b2583a9c5cbc64 \ - --hash=sha256:289200a18fa698949d2b39c671c2cc7a24d44096784e76614899a7ccf2574b7b \ - --hash=sha256:28a1005facc94196e1fb3e82a3d442a9d9110b8434fc1ded7a24a2983c9888d8 \ - --hash=sha256:32fc0341d72e0f73f80acb0a2c94216bd704f4f0bce10aedea38f30502b271ff \ - --hash=sha256:36b31da18b8890a76ec181c3cf44326bf2c48e36d393ca1b72b3f484113ea344 \ - --hash=sha256:3c21d4fca343c805a52c0c78edc01e3477f6dd1ad7c47653241cf2a206d4fc58 \ - --hash=sha256:3fddb7e2c84ac87ac3a947cb4e66d143ca5863ef48e4a5ecb83bd48619e4634e \ - --hash=sha256:43e0933a0eff183ee85833f341ec567c0980dae57c464d8a508e1b2ceb336471 \ - --hash=sha256:4a476b06fbcf359ad25d34a057b7219281286ae2477cc5ff5e3f70a246971148 \ - --hash=sha256:4e594135de17ab3866138f496755f302b72157d115086d100c3f19370839dd3a \ - --hash=sha256:50bf98d5e563b83cc29471fa114366e6806bc06bc7a25fd59641e41445327836 \ - --hash=sha256:5a9979887252a82fefd3d3ed2a8e3b937a7a809f65dcb1e068b090e165bbe99e \ - --hash=sha256:5baececa9ecba31eff645232d59845c07aa030f0c81ee70184a90d35099a0e63 \ - --hash=sha256:5bf4545e3b962767e5c06fe1738f951f77d27967cb2caa64c28be7c4563e162c \ - --hash=sha256:6333b3aa5a12c26b2a4d4e7335a28f1475e0e5e17d69d55141ee3cab736f66d1 \ - --hash=sha256:65c981bdbd3f57670af8b59777cbfae75364b483fa8a9f420f08094531d54a01 \ - --hash=sha256:68a328e5f55ec37c57f19ebb1fdc56a248db2e3e9ad769919a58672958e8f366 \ - --hash=sha256:6a0289e4589e8bdfef02a80478f1dfcb14f0ab696b5a00e1f4b8a14a307a3c58 \ - --hash=sha256:6b66f92b17849b85cad91259efc341dce9c1af48e2173bf38a85c6329f1033e5 \ - --hash=sha256:6c9379d65defcab82d07b2a9dfbfc2e95bc8fe0ebb1b176a3190230a3ef0e07c \ - --hash=sha256:6fc1f5b51fa4cecaa18f2bd7a003f3dd039dd615cd69a2afd6d3b19aed6775f2 \ - --hash=sha256:70f7172939fdf8790425ba31915bfbe8335030f05b9913d7ae00a87d4395620a \ - --hash=sha256:721c76e84fe669be19c5791da68232ca2e05ba5185575086e384352e2c309597 \ - --hash=sha256:7222ffd5e4de8e57e03ce2cef95a4c43c98fcb72ad86909abdfc2c17d227fc1b \ - --hash=sha256:75d10d37a47afee94919c4fab4c22b9bc2a8bf7d4f46f87363bcf0573f3ff4f5 \ - --hash=sha256:76af085e67e56c8816c3ccf256ebd136def2ed9654525348cfa744b6802b69eb \ - --hash=sha256:770cab594ecf99ae64c236bc9ee3439c3f46be49796e265ce0cc8bc17b10294f \ - --hash=sha256:7a6ab32f7210554a96cd9e33abe3ddd86732beeafc7a28e9955cdf22ffadbab0 \ - --hash=sha256:7c48ed483eb946e6c04ccbe02c6b4d1d48e51944b6db70f697e089c193404941 \ - --hash=sha256:7f56930ab0abd1c45cd15be65cc741c28b1c9a34876ce8c17a2fa107810c0af0 \ - --hash=sha256:8075c35cd58273fee266c58c0c9b670947c19df5fb98e7b66710e04ad4e9ff86 \ - --hash=sha256:8272b73e1c5603666618805fe821edba66892e2870058c94c53147602eab29c7 \ - --hash=sha256:82d8fd25b7f4675d0c47cf95b594d4e7b158aca33b76aa63d07186e13c0e0ab7 \ - --hash=sha256:844da2b5728b5ce0e32d863af26f32b5ce61bc4273a9c720a9f3aa9df73b1455 \ - --hash=sha256:8755483f3c00d6c9a77f490c17e6ab0c8729e39e6390328e42521ef175380ae6 \ - --hash=sha256:915f3849a011c1f593ab99092f3cecfcb4d65d8feb4a64cf1bf2d22074dc0ec4 \ - --hash=sha256:926ca93accd5d36ccdabd803392ddc3e03e6d4cd1cf17deff3b989ab8e9dbcf0 \ - --hash=sha256:982bb1e8b4ffda883b3d0a521e23abcd6fd17418f6d2c4118d257a10199c0ce3 \ - --hash=sha256:98f862da73774290f251b9df8d11161b6cf25b599a66baf087c1ffe340e9bfd1 \ - --hash=sha256:9cbfacf36cb0ec2897ce0ebc5d08ca44213af24265bd56eca54bee7923c48fd6 \ - --hash=sha256:a370b3e078e418187da8c3674eddb9d983ec09445c99a3a263c2011993522981 \ - --hash=sha256:a955b438e62efdf7e0b7b52a64dc5c3396e2634baa62471768a64bc2adb73d5c \ - --hash=sha256:aa6af9e7d59f9c12b33ae4e9450619cf2488e2bbe9b44030905877f0b2324980 \ - --hash=sha256:aa88ca0b1932e93f2d961bf3addbb2db902198dca337d88c89e1559e066e7645 \ - --hash=sha256:aaeeb6a479c7667fbe1099af9617c83aaca22182d6cf8c53966491a0f1b7ffb7 \ - --hash=sha256:aaf27faa992bfee0264dc1f03f4c75e9fcdda66a519db6b957a3f826e285cf12 \ - --hash=sha256:b2680962a4848b3c4f155dc2ee64505a9c57186d0d56b43123b17ca3de18f0fa \ - --hash=sha256:b2d318c11350e10662026ad0eb71bb51c7812fc8590825304ae0bdd4ac283acd \ - --hash=sha256:b33de11b92e9f75a2b545d6e9b6f37e398d86c3e9e9653c4864eb7e89c5773ef \ - --hash=sha256:b3daeac64d5b371dea99714f08ffc2c208522ec6b06fbc7866a450dd446f5c0f \ - --hash=sha256:be1e352acbe3c78727a16a455126d9ff83ea2dfdcbc83148d2982305a04714c2 \ - --hash=sha256:bee093bf902e1d8fc0ac143c88902c3dfc8941f7ea1d6a8dd2bcb786d33db03d \ - --hash=sha256:c72fbbe68c6f32f251bdc08b8611c7b3060612236e960ef848e0a517ddbe76c5 \ - --hash=sha256:c9e36a97bee9b86ef9a1cf7bb96747eb7a15c2f22bdb5b516434b00f2a599f02 \ - --hash=sha256:cddf7bd982eaa998934a91f69d182aec997c6c468898efe6679af88283b498d3 \ - --hash=sha256:cf713fe9a71ef6fd5adf7a79670135081cd4431c2943864757f0fa3a65b1fafd \ - --hash=sha256:d11b54acf878eef558599658b0ffca78138c8c3655cf4f3a4a673c437e67732e \ - --hash=sha256:d41c4d287cfc69060fa91cae9683eacffad989f1a10811995fa309df656ec214 \ - --hash=sha256:d524ba3f1581b35c03cb42beebab4a13e6cdad7b36246bd22541fa585a56cccd \ - --hash=sha256:daac4765328a919a805fa5e2720f3e94767abd632ae410a9062dff5412bae65a \ - --hash=sha256:db4c7bf0e07fc3b7d89ac2a5880a6a8062056801b83ff56d8464b70f65482b6c \ - --hash=sha256:dc7039885fa1baf9be153a0626e337aa7ec8bf96b0128605fb0d77788ddc1681 \ - --hash=sha256:dccab8d5fa1ef9bfba0590ecf4d46df048d18ffe3eec01eeb73a42e0d9e7a8ba \ - --hash=sha256:dedb8adb91d11846ee08bec4c8236c8549ac721c245678282dcb06b221aab59f \ - --hash=sha256:e45ba65510e2647721e35323d6ef54c7974959f6081b58d4ef5d87c60c84919a \ - --hash=sha256:e53efc7c7cee4c1e70661e2e112ca46a575f90ed9ae3fef200f2a25e954f4b28 \ - --hash=sha256:e635b87f01ebc977342e2697d05b56632f5f879a4f15955dfe8cef2448b51691 \ - --hash=sha256:e70e990b2137b29dc5564715de1e12701815dacc1d056308e2b17e9095372a82 \ - --hash=sha256:e8082b26888e2f8b36a042a58307d5b917ef2b1cacab921ad3323ef91901c71a \ - --hash=sha256:e8323a9b031aa0393768b87f04b4164a40037fb2a3c11ac06a03ffecd3618027 \ - --hash=sha256:e92fca20c46e9f5e1bb485887d074918b13543b1c2a1185e69bb8d17ab6236a7 \ - --hash=sha256:eb30abc20df9ab0814b5a2524f23d75dcf83cde762c161917a2b4b7b55b1e518 \ - --hash=sha256:eba9904b0f38a143592d9fc0e19e2df0fa2e41c3c3745554761c5f6447eedabf \ - --hash=sha256:ef8de666d6179b009dce7bcb2ad4c4a779f113f12caf8dc77f0162c29d20490b \ - --hash=sha256:efd387a49825780ff861998cd959767800d54f8308936b21025326de4b5a42b9 \ - --hash=sha256:f0aa37f3c979cf2546b73e8222bbfa3dc07a641585340179d768068e3455e544 \ - --hash=sha256:f4074c5a429281bf056ddd4c5d3b740ebca4d43ffffe2ef4bf4d2d05114299da \ - --hash=sha256:f69a27e45c43520f5487f27627059b64aaf160415589230992cec34c5e18a509 \ - --hash=sha256:fb707f3e15060adf5b7ada797624a6c6e0138e2a26baa089df64c68ee98e040f \ - --hash=sha256:fcbe676a55d7445b22c10967bceaaf0ee69407fbe0ece4d032b6eb8d4565982a \ - --hash=sha256:fdb20a30fe1175ecabed17cbf7812f7b804b8a315a25f24678bcdf120a90077f +charset-normalizer==3.3.2 \ + --hash=sha256:06435b539f889b1f6f4ac1758871aae42dc3a8c0e24ac9e60c2384973ad73027 \ + --hash=sha256:06a81e93cd441c56a9b65d8e1d043daeb97a3d0856d177d5c90ba85acb3db087 \ + --hash=sha256:0a55554a2fa0d408816b3b5cedf0045f4b8e1a6065aec45849de2d6f3f8e9786 \ + --hash=sha256:0b2b64d2bb6d3fb9112bafa732def486049e63de9618b5843bcdd081d8144cd8 \ + --hash=sha256:10955842570876604d404661fbccbc9c7e684caf432c09c715ec38fbae45ae09 \ + --hash=sha256:122c7fa62b130ed55f8f285bfd56d5f4b4a5b503609d181f9ad85e55c89f4185 \ + --hash=sha256:1ceae2f17a9c33cb48e3263960dc5fc8005351ee19db217e9b1bb15d28c02574 \ + --hash=sha256:1d3193f4a680c64b4b6a9115943538edb896edc190f0b222e73761716519268e \ + --hash=sha256:1f79682fbe303db92bc2b1136016a38a42e835d932bab5b3b1bfcfbf0640e519 \ + --hash=sha256:2127566c664442652f024c837091890cb1942c30937add288223dc895793f898 \ + --hash=sha256:22afcb9f253dac0696b5a4be4a1c0f8762f8239e21b99680099abd9b2b1b2269 \ + --hash=sha256:25baf083bf6f6b341f4121c2f3c548875ee6f5339300e08be3f2b2ba1721cdd3 \ + --hash=sha256:2e81c7b9c8979ce92ed306c249d46894776a909505d8f5a4ba55b14206e3222f \ + --hash=sha256:3287761bc4ee9e33561a7e058c72ac0938c4f57fe49a09eae428fd88aafe7bb6 \ + --hash=sha256:34d1c8da1e78d2e001f363791c98a272bb734000fcef47a491c1e3b0505657a8 \ + --hash=sha256:37e55c8e51c236f95b033f6fb391d7d7970ba5fe7ff453dad675e88cf303377a \ + --hash=sha256:3d47fa203a7bd9c5b6cee4736ee84ca03b8ef23193c0d1ca99b5089f72645c73 \ + --hash=sha256:3e4d1f6587322d2788836a99c69062fbb091331ec940e02d12d179c1d53e25fc \ + --hash=sha256:42cb296636fcc8b0644486d15c12376cb9fa75443e00fb25de0b8602e64c1714 \ + --hash=sha256:45485e01ff4d3630ec0d9617310448a8702f70e9c01906b0d0118bdf9d124cf2 \ + --hash=sha256:4a78b2b446bd7c934f5dcedc588903fb2f5eec172f3d29e52a9096a43722adfc \ + --hash=sha256:4ab2fe47fae9e0f9dee8c04187ce5d09f48eabe611be8259444906793ab7cbce \ + --hash=sha256:4d0d1650369165a14e14e1e47b372cfcb31d6ab44e6e33cb2d4e57265290044d \ + --hash=sha256:549a3a73da901d5bc3ce8d24e0600d1fa85524c10287f6004fbab87672bf3e1e \ + --hash=sha256:55086ee1064215781fff39a1af09518bc9255b50d6333f2e4c74ca09fac6a8f6 \ + --hash=sha256:572c3763a264ba47b3cf708a44ce965d98555f618ca42c926a9c1616d8f34269 \ + --hash=sha256:573f6eac48f4769d667c4442081b1794f52919e7edada77495aaed9236d13a96 \ + --hash=sha256:5b4c145409bef602a690e7cfad0a15a55c13320ff7a3ad7ca59c13bb8ba4d45d \ + --hash=sha256:6463effa3186ea09411d50efc7d85360b38d5f09b870c48e4600f63af490e56a \ + --hash=sha256:65f6f63034100ead094b8744b3b97965785388f308a64cf8d7c34f2f2e5be0c4 \ + --hash=sha256:663946639d296df6a2bb2aa51b60a2454ca1cb29835324c640dafb5ff2131a77 \ + --hash=sha256:6897af51655e3691ff853668779c7bad41579facacf5fd7253b0133308cf000d \ + --hash=sha256:68d1f8a9e9e37c1223b656399be5d6b448dea850bed7d0f87a8311f1ff3dabb0 \ + --hash=sha256:6ac7ffc7ad6d040517be39eb591cac5ff87416c2537df6ba3cba3bae290c0fed \ + --hash=sha256:6b3251890fff30ee142c44144871185dbe13b11bab478a88887a639655be1068 \ + --hash=sha256:6c4caeef8fa63d06bd437cd4bdcf3ffefe6738fb1b25951440d80dc7df8c03ac \ + --hash=sha256:6ef1d82a3af9d3eecdba2321dc1b3c238245d890843e040e41e470ffa64c3e25 \ + --hash=sha256:753f10e867343b4511128c6ed8c82f7bec3bd026875576dfd88483c5c73b2fd8 \ + --hash=sha256:7cd13a2e3ddeed6913a65e66e94b51d80a041145a026c27e6bb76c31a853c6ab \ + --hash=sha256:7ed9e526742851e8d5cc9e6cf41427dfc6068d4f5a3bb03659444b4cabf6bc26 \ + --hash=sha256:7f04c839ed0b6b98b1a7501a002144b76c18fb1c1850c8b98d458ac269e26ed2 \ + --hash=sha256:802fe99cca7457642125a8a88a084cef28ff0cf9407060f7b93dca5aa25480db \ + --hash=sha256:80402cd6ee291dcb72644d6eac93785fe2c8b9cb30893c1af5b8fdd753b9d40f \ + --hash=sha256:8465322196c8b4d7ab6d1e049e4c5cb460d0394da4a27d23cc242fbf0034b6b5 \ + --hash=sha256:86216b5cee4b06df986d214f664305142d9c76df9b6512be2738aa72a2048f99 \ + --hash=sha256:87d1351268731db79e0f8e745d92493ee2841c974128ef629dc518b937d9194c \ + --hash=sha256:8bdb58ff7ba23002a4c5808d608e4e6c687175724f54a5dade5fa8c67b604e4d \ + --hash=sha256:8c622a5fe39a48f78944a87d4fb8a53ee07344641b0562c540d840748571b811 \ + --hash=sha256:8d756e44e94489e49571086ef83b2bb8ce311e730092d2c34ca8f7d925cb20aa \ + --hash=sha256:8f4a014bc36d3c57402e2977dada34f9c12300af536839dc38c0beab8878f38a \ + --hash=sha256:9063e24fdb1e498ab71cb7419e24622516c4a04476b17a2dab57e8baa30d6e03 \ + --hash=sha256:90d558489962fd4918143277a773316e56c72da56ec7aa3dc3dbbe20fdfed15b \ + --hash=sha256:923c0c831b7cfcb071580d3f46c4baf50f174be571576556269530f4bbd79d04 \ + --hash=sha256:95f2a5796329323b8f0512e09dbb7a1860c46a39da62ecb2324f116fa8fdc85c \ + --hash=sha256:96b02a3dc4381e5494fad39be677abcb5e6634bf7b4fa83a6dd3112607547001 \ + --hash=sha256:9f96df6923e21816da7e0ad3fd47dd8f94b2a5ce594e00677c0013018b813458 \ + --hash=sha256:a10af20b82360ab00827f916a6058451b723b4e65030c5a18577c8b2de5b3389 \ + --hash=sha256:a50aebfa173e157099939b17f18600f72f84eed3049e743b68ad15bd69b6bf99 \ + --hash=sha256:a981a536974bbc7a512cf44ed14938cf01030a99e9b3a06dd59578882f06f985 \ + --hash=sha256:a9a8e9031d613fd2009c182b69c7b2c1ef8239a0efb1df3f7c8da66d5dd3d537 \ + --hash=sha256:ae5f4161f18c61806f411a13b0310bea87f987c7d2ecdbdaad0e94eb2e404238 \ + --hash=sha256:aed38f6e4fb3f5d6bf81bfa990a07806be9d83cf7bacef998ab1a9bd660a581f \ + --hash=sha256:b01b88d45a6fcb69667cd6d2f7a9aeb4bf53760d7fc536bf679ec94fe9f3ff3d \ + --hash=sha256:b261ccdec7821281dade748d088bb6e9b69e6d15b30652b74cbbac25e280b796 \ + --hash=sha256:b2b0a0c0517616b6869869f8c581d4eb2dd83a4d79e0ebcb7d373ef9956aeb0a \ + --hash=sha256:b4a23f61ce87adf89be746c8a8974fe1c823c891d8f86eb218bb957c924bb143 \ + --hash=sha256:bd8f7df7d12c2db9fab40bdd87a7c09b1530128315d047a086fa3ae3435cb3a8 \ + --hash=sha256:beb58fe5cdb101e3a055192ac291b7a21e3b7ef4f67fa1d74e331a7f2124341c \ + --hash=sha256:c002b4ffc0be611f0d9da932eb0f704fe2602a9a949d1f738e4c34c75b0863d5 \ + --hash=sha256:c083af607d2515612056a31f0a8d9e0fcb5876b7bfc0abad3ecd275bc4ebc2d5 \ + --hash=sha256:c180f51afb394e165eafe4ac2936a14bee3eb10debc9d9e4db8958fe36afe711 \ + --hash=sha256:c235ebd9baae02f1b77bcea61bce332cb4331dc3617d254df3323aa01ab47bd4 \ + --hash=sha256:cd70574b12bb8a4d2aaa0094515df2463cb429d8536cfb6c7ce983246983e5a6 \ + --hash=sha256:d0eccceffcb53201b5bfebb52600a5fb483a20b61da9dbc885f8b103cbe7598c \ + --hash=sha256:d965bba47ddeec8cd560687584e88cf699fd28f192ceb452d1d7ee807c5597b7 \ + --hash=sha256:db364eca23f876da6f9e16c9da0df51aa4f104a972735574842618b8c6d999d4 \ + --hash=sha256:ddbb2551d7e0102e7252db79ba445cdab71b26640817ab1e3e3648dad515003b \ + --hash=sha256:deb6be0ac38ece9ba87dea880e438f25ca3eddfac8b002a2ec3d9183a454e8ae \ + --hash=sha256:e06ed3eb3218bc64786f7db41917d4e686cc4856944f53d5bdf83a6884432e12 \ + --hash=sha256:e27ad930a842b4c5eb8ac0016b0a54f5aebbe679340c26101df33424142c143c \ + --hash=sha256:e537484df0d8f426ce2afb2d0f8e1c3d0b114b83f8850e5f2fbea0e797bd82ae \ + --hash=sha256:eb00ed941194665c332bf8e078baf037d6c35d7c4f3102ea2d4f16ca94a26dc8 \ + --hash=sha256:eb6904c354526e758fda7167b33005998fb68c46fbc10e013ca97f21ca5c8887 \ + --hash=sha256:eb8821e09e916165e160797a6c17edda0679379a4be5c716c260e836e122f54b \ + --hash=sha256:efcb3f6676480691518c177e3b465bcddf57cea040302f9f4e6e191af91174d4 \ + --hash=sha256:f27273b60488abe721a075bcca6d7f3964f9f6f067c8c4c605743023d7d3944f \ + --hash=sha256:f30c3cb33b24454a82faecaf01b19c18562b1e89558fb6c56de4d9118a032fd5 \ + --hash=sha256:fb69256e180cb6c8a894fee62b3afebae785babc1ee98b81cdf68bbca1987f33 \ + --hash=sha256:fd1abc0d89e30cc4e02e4064dc67fcc51bd941eb395c502aac3ec19fab46b519 \ + --hash=sha256:ff8fa367d09b717b2a17a052544193ad76cd49979c805768879cb63d9ca50561 # via requests -idna==3.10 \ - --hash=sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9 \ - --hash=sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3 +idna==3.7 \ + --hash=sha256:028ff3aadf0609c1fd278d8ea3089299412a7a8b9bd005dd08b9f8285bcb5cfc \ + --hash=sha256:82fee1fc78add43492d3a1898bfa6d8a904cc97d8427f683ed8e798d07761aa0 # via requests mpmath==1.3.0 \ --hash=sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f \ @@ -109,7 +111,7 @@ mpmath==1.3.0 \ networkx==2.8.8 \ --hash=sha256:230d388117af870fce5647a3c52401fcf753e94720e6ea6b4197a5355648885e \ --hash=sha256:e435dfa75b1d7195c7b8378c3859f0445cd88c6b0375c181ed66823a9ceb7524 - # via -r tools/build/bazel/requirements.in + # via -r requirements.in numpy==1.26.4 \ --hash=sha256:03a8c78d01d9781b28a6989f6fa1bb2c4f2d51201cf99d3dd875df6fbd96b23b \ --hash=sha256:08beddf13648eb95f8d867350f6a018a4be2e5ad54c8d8caed89ebca558b2818 \ @@ -147,79 +149,79 @@ numpy==1.26.4 \ --hash=sha256:edd8b5fe47dab091176d21bb6de568acdd906d1887a4584a15a9a96a1dca06ef \ --hash=sha256:f870204a840a60da0b12273ef34f7051e98c3b5961b61b0c2c1be6dfd64fbcd3 \ --hash=sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f - # via -r tools/build/bazel/requirements.in -pyyaml==6.0.2 \ - --hash=sha256:01179a4a8559ab5de078078f37e5c1a30d76bb88519906844fd7bdea1b7729ff \ - --hash=sha256:0833f8694549e586547b576dcfaba4a6b55b9e96098b36cdc7ebefe667dfed48 \ - --hash=sha256:0a9a2848a5b7feac301353437eb7d5957887edbf81d56e903999a75a3d743086 \ - --hash=sha256:0b69e4ce7a131fe56b7e4d770c67429700908fc0752af059838b1cfb41960e4e \ - --hash=sha256:0ffe8360bab4910ef1b9e87fb812d8bc0a308b0d0eef8c8f44e0254ab3b07133 \ - --hash=sha256:11d8f3dd2b9c1207dcaf2ee0bbbfd5991f571186ec9cc78427ba5bd32afae4b5 \ - --hash=sha256:17e311b6c678207928d649faa7cb0d7b4c26a0ba73d41e99c4fff6b6c3276484 \ - --hash=sha256:1e2120ef853f59c7419231f3bf4e7021f1b936f6ebd222406c3b60212205d2ee \ - --hash=sha256:1f71ea527786de97d1a0cc0eacd1defc0985dcf6b3f17bb77dcfc8c34bec4dc5 \ - --hash=sha256:23502f431948090f597378482b4812b0caae32c22213aecf3b55325e049a6c68 \ - --hash=sha256:24471b829b3bf607e04e88d79542a9d48bb037c2267d7927a874e6c205ca7e9a \ - --hash=sha256:29717114e51c84ddfba879543fb232a6ed60086602313ca38cce623c1d62cfbf \ - --hash=sha256:2e99c6826ffa974fe6e27cdb5ed0021786b03fc98e5ee3c5bfe1fd5015f42b99 \ - --hash=sha256:39693e1f8320ae4f43943590b49779ffb98acb81f788220ea932a6b6c51004d8 \ - --hash=sha256:3ad2a3decf9aaba3d29c8f537ac4b243e36bef957511b4766cb0057d32b0be85 \ - --hash=sha256:3b1fdb9dc17f5a7677423d508ab4f243a726dea51fa5e70992e59a7411c89d19 \ - --hash=sha256:41e4e3953a79407c794916fa277a82531dd93aad34e29c2a514c2c0c5fe971cc \ - --hash=sha256:43fa96a3ca0d6b1812e01ced1044a003533c47f6ee8aca31724f78e93ccc089a \ - --hash=sha256:50187695423ffe49e2deacb8cd10510bc361faac997de9efef88badc3bb9e2d1 \ - --hash=sha256:5ac9328ec4831237bec75defaf839f7d4564be1e6b25ac710bd1a96321cc8317 \ - --hash=sha256:5d225db5a45f21e78dd9358e58a98702a0302f2659a3c6cd320564b75b86f47c \ - --hash=sha256:6395c297d42274772abc367baaa79683958044e5d3835486c16da75d2a694631 \ - --hash=sha256:688ba32a1cffef67fd2e9398a2efebaea461578b0923624778664cc1c914db5d \ - --hash=sha256:68ccc6023a3400877818152ad9a1033e3db8625d899c72eacb5a668902e4d652 \ - --hash=sha256:70b189594dbe54f75ab3a1acec5f1e3faa7e8cf2f1e08d9b561cb41b845f69d5 \ - --hash=sha256:797b4f722ffa07cc8d62053e4cff1486fa6dc094105d13fea7b1de7d8bf71c9e \ - --hash=sha256:7c36280e6fb8385e520936c3cb3b8042851904eba0e58d277dca80a5cfed590b \ - --hash=sha256:7e7401d0de89a9a855c839bc697c079a4af81cf878373abd7dc625847d25cbd8 \ - --hash=sha256:80bab7bfc629882493af4aa31a4cfa43a4c57c83813253626916b8c7ada83476 \ - --hash=sha256:82d09873e40955485746739bcb8b4586983670466c23382c19cffecbf1fd8706 \ - --hash=sha256:8388ee1976c416731879ac16da0aff3f63b286ffdd57cdeb95f3f2e085687563 \ - --hash=sha256:8824b5a04a04a047e72eea5cec3bc266db09e35de6bdfe34c9436ac5ee27d237 \ - --hash=sha256:8b9c7197f7cb2738065c481a0461e50ad02f18c78cd75775628afb4d7137fb3b \ - --hash=sha256:9056c1ecd25795207ad294bcf39f2db3d845767be0ea6e6a34d856f006006083 \ - --hash=sha256:936d68689298c36b53b29f23c6dbb74de12b4ac12ca6cfe0e047bedceea56180 \ - --hash=sha256:9b22676e8097e9e22e36d6b7bda33190d0d400f345f23d4065d48f4ca7ae0425 \ - --hash=sha256:a4d3091415f010369ae4ed1fc6b79def9416358877534caf6a0fdd2146c87a3e \ - --hash=sha256:a8786accb172bd8afb8be14490a16625cbc387036876ab6ba70912730faf8e1f \ - --hash=sha256:a9f8c2e67970f13b16084e04f134610fd1d374bf477b17ec1599185cf611d725 \ - --hash=sha256:bc2fa7c6b47d6bc618dd7fb02ef6fdedb1090ec036abab80d4681424b84c1183 \ - --hash=sha256:c70c95198c015b85feafc136515252a261a84561b7b1d51e3384e0655ddf25ab \ - --hash=sha256:cc1c1159b3d456576af7a3e4d1ba7e6924cb39de8f67111c735f6fc832082774 \ - --hash=sha256:ce826d6ef20b1bc864f0a68340c8b3287705cae2f8b4b1d932177dcc76721725 \ - --hash=sha256:d584d9ec91ad65861cc08d42e834324ef890a082e591037abe114850ff7bbc3e \ - --hash=sha256:d7fded462629cfa4b685c5416b949ebad6cec74af5e2d42905d41e257e0869f5 \ - --hash=sha256:d84a1718ee396f54f3a086ea0a66d8e552b2ab2017ef8b420e92edbc841c352d \ - --hash=sha256:d8e03406cac8513435335dbab54c0d385e4a49e4945d2909a581c83647ca0290 \ - --hash=sha256:e10ce637b18caea04431ce14fabcf5c64a1c61ec9c56b071a4b7ca131ca52d44 \ - --hash=sha256:ec031d5d2feb36d1d1a24380e4db6d43695f3748343d99434e6f5f9156aaa2ed \ - --hash=sha256:ef6107725bd54b262d6dedcc2af448a266975032bc85ef0172c5f059da6325b4 \ - --hash=sha256:efdca5630322a10774e8e98e1af481aad470dd62c3170801852d752aa7a783ba \ - --hash=sha256:f753120cb8181e736c57ef7636e83f31b9c0d1722c516f7e86cf15b7aa57ff12 \ - --hash=sha256:ff3824dc5261f50c9b0dfb3be22b4567a6f938ccce4587b38952d85fd9e9afe4 - # via -r tools/build/bazel/requirements.in + # via -r requirements.in +pyyaml==6.0.1 \ + --hash=sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5 \ + --hash=sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc \ + --hash=sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df \ + --hash=sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741 \ + --hash=sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206 \ + --hash=sha256:18aeb1bf9a78867dc38b259769503436b7c72f7a1f1f4c93ff9a17de54319b27 \ + --hash=sha256:1d4c7e777c441b20e32f52bd377e0c409713e8bb1386e1099c2415f26e479595 \ + --hash=sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62 \ + --hash=sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98 \ + --hash=sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696 \ + --hash=sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290 \ + --hash=sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9 \ + --hash=sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d \ + --hash=sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6 \ + --hash=sha256:4fb147e7a67ef577a588a0e2c17b6db51dda102c71de36f8549b6816a96e1867 \ + --hash=sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47 \ + --hash=sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486 \ + --hash=sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6 \ + --hash=sha256:596106435fa6ad000c2991a98fa58eeb8656ef2325d7e158344fb33864ed87e3 \ + --hash=sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007 \ + --hash=sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938 \ + --hash=sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0 \ + --hash=sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c \ + --hash=sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735 \ + --hash=sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d \ + --hash=sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28 \ + --hash=sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4 \ + --hash=sha256:9046c58c4395dff28dd494285c82ba00b546adfc7ef001486fbf0324bc174fba \ + --hash=sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8 \ + --hash=sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef \ + --hash=sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5 \ + --hash=sha256:afd7e57eddb1a54f0f1a974bc4391af8bcce0b444685d936840f125cf046d5bd \ + --hash=sha256:b1275ad35a5d18c62a7220633c913e1b42d44b46ee12554e5fd39c70a243d6a3 \ + --hash=sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0 \ + --hash=sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515 \ + --hash=sha256:baa90d3f661d43131ca170712d903e6295d1f7a0f595074f151c0aed377c9b9c \ + --hash=sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c \ + --hash=sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924 \ + --hash=sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34 \ + --hash=sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43 \ + --hash=sha256:c8098ddcc2a85b61647b2590f825f3db38891662cfc2fc776415143f599bb859 \ + --hash=sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673 \ + --hash=sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54 \ + --hash=sha256:d858aa552c999bc8a8d57426ed01e40bef403cd8ccdd0fc5f6f04a00414cac2a \ + --hash=sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b \ + --hash=sha256:f003ed9ad21d6a4713f0a9b5a7a0a79e08dd0f221aff4525a2be4c346ee60aab \ + --hash=sha256:f22ac1c3cac4dbc50079e965eba2c1058622631e526bd9afd45fedd49ba781fa \ + --hash=sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c \ + --hash=sha256:fca0e3a251908a499833aa292323f32437106001d436eca0e6e7833256674585 \ + --hash=sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d \ + --hash=sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f + # via -r requirements.in requests==2.32.2 \ --hash=sha256:dd951ff5ecf3e3b3aa26b40703ba77495dab41da839ae72ef3c8e5d8e2433289 \ --hash=sha256:fc06670dd0ed212426dfeb94fc1b983d917c4f9847c863f313c9dfaaffb7c23c - # via -r tools/build/bazel/requirements.in -setuptools==78.1.1 \ - --hash=sha256:c3a9c4211ff4c309edb8b8c4f1cbfa7ae324c4ba9f91ff254e3d305b9fd54561 \ - --hash=sha256:fcc17fd9cd898242f6b4adfaca46137a9edef687f43e6f78469692a5e70d851d - # via -r tools/build/bazel/requirements.in + # via -r requirements.in sympy==1.12 \ --hash=sha256:c3588cd4295d0c0f603d0f2ae780587e64e2efeedb3521e46b9bb1d08d184fa5 \ --hash=sha256:ebf595c8dac3e0fdc4152c51878b498396ec7f30e7a914d6071e674d49420fb8 - # via -r tools/build/bazel/requirements.in + # via -r requirements.in typing-extensions==4.11.0 \ --hash=sha256:83f085bd5ca59c80295fc2a82ab5dac679cbe02b9f33f7d83af68e241bea51b0 \ --hash=sha256:c1f94d72897edaf4ce775bb7558d5b79d8126906a14ea5ed1635921406c0387a - # via -r tools/build/bazel/requirements.in + # via -r requirements.in urllib3==2.5.0 \ --hash=sha256:3fc47733c7e419d4bc3f6b3dc2b4f890bb743906a30d56ba4a5bfa4bbff92760 \ --hash=sha256:e6b01673c0fa6a13e374b50871808eb3bf7046c4b125b216f6bf1cc604cff0dc # via requests + +# The following packages are considered to be unsafe in a requirements file: +setuptools==78.1.1 \ + --hash=sha256:c3a9c4211ff4c309edb8b8c4f1cbfa7ae324c4ba9f91ff254e3d305b9fd54561 \ + --hash=sha256:fcc17fd9cd898242f6b4adfaca46137a9edef687f43e6f78469692a5e70d851d + # via -r requirements.in diff --git a/tools/setup_helpers/build.bzl b/tools/setup_helpers/build.bzl index 5210b6d485552..c5be13e4603b4 100644 --- a/tools/setup_helpers/build.bzl +++ b/tools/setup_helpers/build.bzl @@ -4,7 +4,7 @@ def define_targets(rules): srcs = ["generate_code.py"], visibility = ["//:__pkg__"], deps = [ - rules.requirement("pyyaml"), + rules.requirement("PyYAML"), "//tools/autograd", "//torchgen", ], diff --git a/torchgen/build.bzl b/torchgen/build.bzl index 0adcf24e1a4c2..50765869f8d5d 100644 --- a/torchgen/build.bzl +++ b/torchgen/build.bzl @@ -4,7 +4,7 @@ def define_targets(rules): srcs = rules.glob(["**/*.py"]), visibility = ["//visibility:public"], deps = [ - rules.requirement("pyyaml"), + rules.requirement("PyYAML"), rules.requirement("typing-extensions"), ], ) @@ -14,7 +14,7 @@ def define_targets(rules): srcs = [":torchgen"], visibility = ["//visibility:public"], deps = [ - rules.requirement("pyyaml"), + rules.requirement("PyYAML"), rules.requirement("typing-extensions"), ], ) From f168cf49a8e2e81dcd4cd3c631325221a4f3faac Mon Sep 17 00:00:00 2001 From: Zain Rizvi Date: Mon, 21 Jul 2025 15:19:27 +0000 Subject: [PATCH 317/457] [BE] Always use python 3.9 for pre-push hook's lintrunner (#158693) A follow up to https://github.com/pytorch/pytorch/pull/158389 Sets up the pre-push lintrunner to always use python 3.9 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158693 Approved by: https://github.com/atalman --- scripts/setup_hooks.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/scripts/setup_hooks.py b/scripts/setup_hooks.py index 7c467befd9498..41f08d45e98b6 100644 --- a/scripts/setup_hooks.py +++ b/scripts/setup_hooks.py @@ -17,6 +17,7 @@ import subprocess import sys from pathlib import Path +from typing import Tuple # ─────────────────────────────────────────── @@ -64,7 +65,9 @@ def ensure_uv() -> None: ) -def ensure_tool_installed(tool: str, force_update: bool = False) -> None: +def ensure_tool_installed( + tool: str, force_update: bool = False, python_ver: Tuple[int, int] = None +) -> None: """ Checks to see if the tool is available and if not (or if force update requested) then it reinstalls it. @@ -74,7 +77,11 @@ def ensure_tool_installed(tool: str, force_update: bool = False) -> None: """ if force_update or not which(tool): print(f"Ensuring latest {tool} via uv …") - run(["uv", "tool", "install", "--force", tool]) + command = ["uv", "tool", "install", "--force", tool] + if python_ver: + # Add the Python version to the command if specified + command.extend(["--python", f"{python_ver[0]}.{python_ver[1]}"]) + run(command) if not which(tool): print( f"\n⚠️ {tool} installation succeed, but it's not on PATH. Launch a new terminal if your git pushes don't work.\n" @@ -94,7 +101,7 @@ def ensure_tool_installed(tool: str, force_update: bool = False) -> None: ensure_uv() # Ensure pre-commit is installed globally via uv -ensure_tool_installed("pre-commit", force_update=True) +ensure_tool_installed("pre-commit", force_update=True, python_ver=(3, 9)) # Don't force a lintrunner update because it might break folks # who already have it installed in a different way From 9894d43b6cba4bd30271c8dc065e5f3af4b6859d Mon Sep 17 00:00:00 2001 From: Xu Han Date: Mon, 21 Jul 2025 15:59:40 +0000 Subject: [PATCH 318/457] [AOTI] explicit aoti wrapper functions for Windows. (#158713) On Windows, we need to explicit declaration for export APIs. Because the package loader call these API via GetProcAddress. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158713 Approved by: https://github.com/desertfire --- torch/csrc/inductor/aoti_runtime/interface.h | 71 ++++++++++++-------- 1 file changed, 42 insertions(+), 29 deletions(-) diff --git a/torch/csrc/inductor/aoti_runtime/interface.h b/torch/csrc/inductor/aoti_runtime/interface.h index f2b29049bb811..fab9a87a725e8 100644 --- a/torch/csrc/inductor/aoti_runtime/interface.h +++ b/torch/csrc/inductor/aoti_runtime/interface.h @@ -6,6 +6,17 @@ // applies to other files under torch/csrc/inductor/aoti_runtime/. #include +#ifdef _WIN32 +/* +On Windows, we need to explicit declaration for export APIs. And because the +package loader call these API via GetProcAddress(ldsym on Linux), we can ignore +the import case. +*/ +#define AOTI_API __declspec(dllexport) +#else +#define AOTI_API __attribute__((__visibility__("default"))) +#endif + extern "C" { struct AOTInductorModelOpaque; using AOTInductorModelHandle = AOTInductorModelOpaque*; @@ -21,7 +32,7 @@ using AOTInductorConstantMapHandle = AOTInductorConstantMap*; // TODO: Deprecate this API. This was kept for BC compatibility. // Please use AOTInductorModelContainerCreateWithDevice instead. -AOTIRuntimeError AOTInductorModelContainerCreate( +AOTI_API AOTIRuntimeError AOTInductorModelContainerCreate( AOTInductorModelContainerHandle* container_handle, size_t num_models, bool is_cpu, @@ -34,18 +45,18 @@ AOTIRuntimeError AOTInductorModelContainerCreate( // "cpu", "cuda", "cuda:0", etc. If the device index is not specified for CUDA // device, runtime will use the device index returned by // "cudaGetDevice(&device_idx)" -AOTIRuntimeError AOTInductorModelContainerCreateWithDevice( +AOTI_API AOTIRuntimeError AOTInductorModelContainerCreateWithDevice( AOTInductorModelContainerHandle* container_handle, size_t num_models, const char* device_str, const char* cubin_dir); // Deletes the AOTInductor model container. -AOTIRuntimeError AOTInductorModelContainerDelete( +AOTI_API AOTIRuntimeError AOTInductorModelContainerDelete( AOTInductorModelContainerHandle container_handle); // Runs the inference. -AOTIRuntimeError AOTInductorModelContainerRun( +AOTI_API AOTIRuntimeError AOTInductorModelContainerRun( AOTInductorModelContainerHandle container_handle, AtenTensorHandle* input_handles, // array of input AtenTensorHandle; handles // are stolen; the array itself is borrowed @@ -59,7 +70,7 @@ AOTIRuntimeError AOTInductorModelContainerRun( AOTIProxyExecutorHandle proxy_executor_handle); // Single-threaded variant of previous. -AOTIRuntimeError AOTInductorModelContainerRunSingleThreaded( +AOTI_API AOTIRuntimeError AOTInductorModelContainerRunSingleThreaded( AOTInductorModelContainerHandle container_handle, AtenTensorHandle* input_handles, // array of input AtenTensorHandle; handles // are stolen; the array itself is borrowed @@ -73,14 +84,14 @@ AOTIRuntimeError AOTInductorModelContainerRunSingleThreaded( AOTIProxyExecutorHandle proxy_executor_handle); // Retrieves the number of constants for the model. -AOTIRuntimeError AOTInductorModelContainerGetNumConstants( +AOTI_API AOTIRuntimeError AOTInductorModelContainerGetNumConstants( AOTInductorModelContainerHandle container_handle, size_t* num_constants); // Retrieves a constant's name. // idx is the index of the internal's constants. // Need idx < num_constants from AOTInductorModelContainerGetNumConstants -AOTIRuntimeError AOTInductorModelContainerGetConstantName( +AOTI_API AOTIRuntimeError AOTInductorModelContainerGetConstantName( AOTInductorModelContainerHandle container_handle, size_t idx, const char** name); @@ -88,7 +99,7 @@ AOTIRuntimeError AOTInductorModelContainerGetConstantName( // Retrieves a constant's original FQN. // idx is the index of the internal's constants. // Need idx < num_constants from AOTInductorModelContainerGetNumConstants -AOTIRuntimeError AOTInductorModelContainerGetConstantOriginalFQN( +AOTI_API AOTIRuntimeError AOTInductorModelContainerGetConstantOriginalFQN( AOTInductorModelContainerHandle container_handle, size_t idx, const char** original_fqn); @@ -96,7 +107,7 @@ AOTIRuntimeError AOTInductorModelContainerGetConstantOriginalFQN( // Retrieves whether a constant is from folded. // idx is the index of the internal's constants. // Need idx < num_constants from AOTInductorModelContainerGetNumConstants -AOTIRuntimeError AOTInductorModelContainerGetConstantFromFolded( +AOTI_API AOTIRuntimeError AOTInductorModelContainerGetConstantFromFolded( AOTInductorModelContainerHandle container_handle, size_t idx, bool* from_folded); @@ -104,7 +115,7 @@ AOTIRuntimeError AOTInductorModelContainerGetConstantFromFolded( // Retrieves the inductor constant type. // idx is the index of the internal's constants. // Need idx < num_constants from AOTInductorModelContainerGetNumConstants -AOTIRuntimeError AOTInductorModelContainerGetConstantType( +AOTI_API AOTIRuntimeError AOTInductorModelContainerGetConstantType( AOTInductorModelContainerHandle container_handle, size_t idx, int32_t* type); @@ -112,7 +123,7 @@ AOTIRuntimeError AOTInductorModelContainerGetConstantType( // Retrieves a constant's dtype. // idx is the index of the internal's constants. // Need idx < num_constants from AOTInductorModelContainerGetNumConstants -AOTIRuntimeError AOTInductorModelContainerGetConstantDtype( +AOTI_API AOTIRuntimeError AOTInductorModelContainerGetConstantDtype( AOTInductorModelContainerHandle container_handle, size_t idx, int32_t* dtype); @@ -120,20 +131,21 @@ AOTIRuntimeError AOTInductorModelContainerGetConstantDtype( // Retrieves a constant's data size. // idx is the index of the internal's constants. // Need idx < num_constants from AOTInductorModelContainerGetNumConstants -AOTIRuntimeError AOTInductorModelContainerGetConstantDataSize( +AOTI_API AOTIRuntimeError AOTInductorModelContainerGetConstantDataSize( AOTInductorModelContainerHandle container_handle, size_t idx, size_t* data_size); // Extract the constants that is being used in the container. -AOTIRuntimeError AOTInductorModelContainerExtractConstantsMap( +AOTI_API AOTIRuntimeError AOTInductorModelContainerExtractConstantsMap( AOTInductorModelContainerHandle container_handle, AOTInductorConstantMapHandle constant_map_handle, bool use_inactive); // Setup the constant buffer in model container with provided ConstantMap. // The ConstantMap is user managed, and the user would retain ownership. -AOTIRuntimeError AOTInductorModelContainerUpdateUserManagedConstantBuffer( +AOTI_API AOTIRuntimeError +AOTInductorModelContainerUpdateUserManagedConstantBuffer( AOTInductorModelContainerHandle container_handle, AOTInductorConstantMapHandle constant_map_handle, bool use_inactive, @@ -142,7 +154,7 @@ AOTIRuntimeError AOTInductorModelContainerUpdateUserManagedConstantBuffer( // Setup the constant buffer in model container with provided ConstantMap // use_inactive should be set as true if the inactive buffer is to be updated. // validate_full_update checks if all constants are included in the ConstantMap -AOTIRuntimeError AOTInductorModelContainerUpdateConstantBuffer( +AOTI_API AOTIRuntimeError AOTInductorModelContainerUpdateConstantBuffer( AOTInductorModelContainerHandle container_handle, AOTInductorConstantMapHandle constant_map_handle, bool use_inactive, @@ -150,43 +162,43 @@ AOTIRuntimeError AOTInductorModelContainerUpdateConstantBuffer( // Setup the inactive constant buffer in model container with provided // ConstantMap -AOTIRuntimeError AOTInductorModelContainerUpdateInactiveConstantBuffer( +AOTI_API AOTIRuntimeError AOTInductorModelContainerUpdateInactiveConstantBuffer( AOTInductorModelContainerHandle container_handle, AOTInductorConstantMapHandle constant_map_handle); // Free the inactive constant buffer in model container. -AOTIRuntimeError AOTInductorModelContainerFreeInactiveConstantBuffer( +AOTI_API AOTIRuntimeError AOTInductorModelContainerFreeInactiveConstantBuffer( AOTInductorModelContainerHandle container_handle); // Run constant folding on constant buffer. -AOTIRuntimeError AOTInductorModelContainerRunConstantFolding( +AOTI_API AOTIRuntimeError AOTInductorModelContainerRunConstantFolding( AOTInductorModelContainerHandle container_handle, bool use_inactive, AOTInductorStreamHandle stream_handle, AOTIProxyExecutorHandle proxy_executor_handle); // Swap the constant buffer being used to the inactive one. -AOTIRuntimeError AOTInductorModelContainerSwapConstantBuffer( +AOTI_API AOTIRuntimeError AOTInductorModelContainerSwapConstantBuffer( AOTInductorModelContainerHandle container_handle); // Retrieves the number of inputs for the model. -AOTIRuntimeError AOTInductorModelContainerGetNumInputs( +AOTI_API AOTIRuntimeError AOTInductorModelContainerGetNumInputs( AOTInductorModelContainerHandle container_handle, size_t* ret_num_inputs); // Retrieves the input name at the given index. -AOTIRuntimeError AOTInductorModelContainerGetInputName( +AOTI_API AOTIRuntimeError AOTInductorModelContainerGetInputName( AOTInductorModelContainerHandle container_handle, size_t input_idx, const char** ret_input_names); // Retrieves the number of outputs for the model. -AOTIRuntimeError AOTInductorModelContainerGetNumOutputs( +AOTI_API AOTIRuntimeError AOTInductorModelContainerGetNumOutputs( AOTInductorModelContainerHandle container_handle, size_t* ret_num_outputs); // Retrieves the output name at the given index. -AOTIRuntimeError AOTInductorModelContainerGetOutputName( +AOTI_API AOTIRuntimeError AOTInductorModelContainerGetOutputName( AOTInductorModelContainerHandle container_handle, size_t output_idx, const char** ret_output_names); @@ -198,31 +210,32 @@ AOTIRuntimeError AOTInductorModelContainerGetOutputName( // // constant_map_handle is an opaque type to satisfy the C ABI. It should be a // std::unordered_map*. -AOTIRuntimeError AOTInductorModelCreate( +AOTI_API AOTIRuntimeError AOTInductorModelCreate( AOTInductorModelHandle* model_handle, AOTInductorConstantMapHandle constant_map_handle); // Run an AOTInductorModel (see AOTInductorModelCreate for when one should use // this function versus AOTInductorModelContainerRun). -AOTIRuntimeError AOTInductorModelRun( +AOTI_API AOTIRuntimeError AOTInductorModelRun( AOTInductorModelHandle model_handle, AtenTensorHandle* input_handles, AtenTensorHandle* output_handles); // Replace AOTInductorModel's constant map. Note it doesn't handle concurrency // so be sure to handle ordering if AOTInductorModelRun is ran concurrently. -AOTIRuntimeError AOTInductorModelUpdateConstantsMap( +AOTI_API AOTIRuntimeError AOTInductorModelUpdateConstantsMap( AOTInductorModelHandle model_handle, AOTInductorConstantMapHandle constant_map_handle); // Delete an AOTInductorModel created by AOTInductorModelCreate. -AOTIRuntimeError AOTInductorModelDelete(AOTInductorModelHandle model_handle); +AOTI_API AOTIRuntimeError +AOTInductorModelDelete(AOTInductorModelHandle model_handle); -AOTIRuntimeError AOTInductorModelGetNumOutputs( +AOTI_API AOTIRuntimeError AOTInductorModelGetNumOutputs( AOTInductorModelHandle model_handle, size_t* ret_num_outputs); -AOTIRuntimeError AOTInductorModelContainerGetCallSpec( +AOTI_API AOTIRuntimeError AOTInductorModelContainerGetCallSpec( AOTInductorModelContainerHandle container_handle, const char** in_spec, const char** out_spec); From cbe1cb70183dd0d08dd555353eeca72399401ae8 Mon Sep 17 00:00:00 2001 From: "Yu, Guangye" Date: Mon, 21 Jul 2025 09:31:09 +0000 Subject: [PATCH 319/457] [CMake] Move xpu flag to xpu.cmake (#158542) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158542 Approved by: https://github.com/gujinghui, https://github.com/ezyang --- CMakeLists.txt | 4 ---- cmake/public/xpu.cmake | 3 +++ 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index d1f8a13fb9fd3..63a2f74404c1e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1190,10 +1190,6 @@ if(APPLE) append_cxx_flag_if_supported("-Wno-missing-braces" CMAKE_CXX_FLAGS) endif() -if(USE_XPU) - string(APPEND CMAKE_CXX_FLAGS " -DUSE_XPU") -endif() - if(EMSCRIPTEN) string( APPEND diff --git a/cmake/public/xpu.cmake b/cmake/public/xpu.cmake index be083cb93af10..b39e31d0ade8a 100644 --- a/cmake/public/xpu.cmake +++ b/cmake/public/xpu.cmake @@ -11,6 +11,7 @@ set(XPU_HOST_CXX_FLAGS) find_package(SYCLToolkit REQUIRED) if(NOT SYCL_FOUND) set(PYTORCH_FOUND_XPU FALSE) + # Exit early to avoid populating XPU_HOST_CXX_FLAGS. return() endif() set(PYTORCH_FOUND_XPU TRUE) @@ -36,6 +37,8 @@ torch_xpu_get_arch_list(XPU_ARCH_FLAGS) # propagate to torch-xpu-ops set(TORCH_XPU_ARCH_LIST ${XPU_ARCH_FLAGS}) +# Ensure USE_XPU is enabled. +string(APPEND XPU_HOST_CXX_FLAGS " -DUSE_XPU") string(APPEND XPU_HOST_CXX_FLAGS " -DSYCL_COMPILER_VERSION=${SYCL_COMPILER_VERSION}") if(DEFINED ENV{XPU_ENABLE_KINETO}) From 35f1b4ad9ef022ce59a1084fe237ceb35c7aab99 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 21 Jul 2025 17:31:42 +0000 Subject: [PATCH 320/457] Revert "Fused RMSNorm implementation (#153666)" This reverts commit 15ef4f28df0a14e9f0d55a57a4e2db415a303be7. Reverted https://github.com/pytorch/pytorch/pull/153666 on behalf of https://github.com/ZainRizvi due to Sorry but this is breaking tests internally. @albanD can you please help land this change?You can follow the instructions here: https://fburl.com/fixing-ghfirst-reverts. See D78599667 for more info ([comment](https://github.com/pytorch/pytorch/pull/153666#issuecomment-3097690935)) --- .../functorch/BatchRulesDecompositions.cpp | 1 - .../src/ATen/native/cuda/layer_norm_kernel.cu | 590 +++++------------- aten/src/ATen/native/layer_norm.cpp | 77 +-- aten/src/ATen/native/layer_norm.h | 6 - .../src/ATen/native/mps/operations/RMSNorm.mm | 13 +- aten/src/ATen/native/native_functions.yaml | 8 +- ...asDecompTest.test_has_decomposition.expect | 1 + .../check_forward_backward_compatibility.py | 2 - test/test_decomp.py | 29 +- tools/autograd/derivatives.yaml | 5 - torch/_decomp/__init__.py | 1 - torch/_decomp/decompositions.py | 75 --- torch/csrc/autograd/FunctionsManual.cpp | 189 ------ torch/csrc/autograd/FunctionsManual.h | 23 - .../aoti_torch/generated/c_shim_cpu.h | 1 - .../aoti_torch/generated/c_shim_cuda.h | 1 - .../aoti_torch/generated/c_shim_mps.h | 2 +- .../aoti_torch/generated/c_shim_xpu.h | 1 - torch/overrides.py | 1 - 19 files changed, 183 insertions(+), 843 deletions(-) diff --git a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp index d58d436c511d1..4b66b30b62e7f 100644 --- a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp +++ b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp @@ -158,7 +158,6 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) { OP_DECOMPOSE(kron); OP_DECOMPOSE(l1_loss); m.impl("layer_norm", native::layer_norm_symint); - m.impl("_fused_rms_norm", native::rms_norm_composite); OP_DECOMPOSE2(ldexp, Tensor); OP_DECOMPOSE2(less_equal, Tensor ); OP_DECOMPOSE2(less, Tensor ); diff --git a/aten/src/ATen/native/cuda/layer_norm_kernel.cu b/aten/src/ATen/native/cuda/layer_norm_kernel.cu index f765b515cd0bc..bdb169e26b142 100644 --- a/aten/src/ATen/native/cuda/layer_norm_kernel.cu +++ b/aten/src/ATen/native/cuda/layer_norm_kernel.cu @@ -50,7 +50,7 @@ bool can_vectorize(const T * ptr, int alignment) { }; -template +template __global__ void RowwiseMomentsCUDAKernel( int64_t N, T_ACC eps, @@ -84,17 +84,12 @@ __global__ void RowwiseMomentsCUDAKernel( T_ACC m1; T_ACC m2; thrust::tie(m2, m1) = welford_op.project(val); - if constexpr (!rms_norm){ - mean[i] = m1; - rstd[i] = c10::cuda::compat::rsqrt(m2 + eps); - } else { - rstd[i] = c10::cuda::compat::rsqrt(m2 + m1 * m1 + eps); - } - + mean[i] = m1; + rstd[i] = c10::cuda::compat::rsqrt(m2 + eps); } } -template +template __global__ void LayerNormForwardCUDAKernel( int64_t N, const T* X, @@ -108,15 +103,11 @@ __global__ void LayerNormForwardCUDAKernel( const int64_t index = i * N + j; const T_ACC gamma_v = gamma == nullptr ? T_ACC(1) : static_cast(gamma[j]); - if constexpr (!rms_norm){ - const T_ACC beta_v = - beta == nullptr ? T_ACC(0) : static_cast(beta[j]); - Y[index] = (static_cast(X[index]) - static_cast(mean[i])) * - static_cast(rstd[i]) * gamma_v + - beta_v; - } else { - Y[index] = (static_cast(X[index])) * static_cast(rstd[i]) * gamma_v; - } + const T_ACC beta_v = + beta == nullptr ? T_ACC(0) : static_cast(beta[j]); + Y[index] = (static_cast(X[index]) - static_cast(mean[i])) * + static_cast(rstd[i]) * gamma_v + + beta_v; } } @@ -128,48 +119,40 @@ struct WelfordDataLN{ C10_HOST_DEVICE WelfordDataLN(float mean, float sigma2, float count): mean(mean), sigma2(sigma2), count(count) {} }; -template __device__ +template __device__ WelfordDataLN cuWelfordOnlineSum( const U val, const WelfordDataLN& curr_sum) { - if constexpr (!rms_norm){ - U delta = val - curr_sum.mean; - U new_count = curr_sum.count + 1.f; - U new_mean = curr_sum.mean + delta * (1.f/new_count); //proper division is slow, this is less accurate but noticeably faster - return {new_mean, curr_sum.sigma2 + delta * (val - new_mean), new_count}; - } else{ - return {0.f, curr_sum.sigma2 + val * val, 0}; - } + U delta = val - curr_sum.mean; + U new_count = curr_sum.count + 1.f; + U new_mean = curr_sum.mean + delta * (1.f/new_count); //proper division is slow, this is less accurate but noticeably faster + return {new_mean, curr_sum.sigma2 + delta * (val - new_mean), new_count}; } -template __device__ +__device__ WelfordDataLN cuWelfordCombine( const WelfordDataLN dataB, const WelfordDataLN dataA ) { - if constexpr (!rms_norm){ - using U = decltype(dataB.count); - U delta = dataB.mean - dataA.mean; - U count = dataA.count + dataB.count; - U mean, sigma2; - if (count > decltype(dataB.count){0}) { - auto coef = 1.f/count; //NB we don't use --use_fast_math, but this is emulation, 1./count goes to intrinsic, `* coef` is multiplication, instead of slow fp division - auto nA = dataA.count * coef; - auto nB = dataB.count * coef; - mean = nA*dataA.mean + nB*dataB.mean; - sigma2 = dataA.sigma2 + dataB.sigma2 + delta * delta * dataA.count * nB; - } else { - mean = U(0); - sigma2 = U(0); - } - return {mean, sigma2, count}; + using U = decltype(dataB.count); + U delta = dataB.mean - dataA.mean; + U count = dataA.count + dataB.count; + U mean, sigma2; + if (count > decltype(dataB.count){0}) { + auto coef = 1.f/count; //NB we don't use --use_fast_math, but this is emulation, 1./count goes to intrinsic, `* coef` is multiplication, instead of slow fp division + auto nA = dataA.count * coef; + auto nB = dataB.count * coef; + mean = nA*dataA.mean + nB*dataB.mean; + sigma2 = dataA.sigma2 + dataB.sigma2 + delta * delta * dataA.count * nB; } else { - return {0.f, dataB.sigma2 + dataA.sigma2, 0}; + mean = U(0); + sigma2 = U(0); } + return {mean, sigma2, count}; } -template +template __device__ WelfordDataLN compute_stats( const T* __restrict__ X, const int N, @@ -188,13 +171,14 @@ __device__ WelfordDataLN compute_stats( vec_t data = X_vec[i]; #pragma unroll for (int ii=0; ii < vec_size; ii++){ - wd = cuWelfordOnlineSum(static_cast(data.val[ii]), wd); + wd = cuWelfordOnlineSum(static_cast(data.val[ii]), wd); } } // intra-warp reduction for (int offset = (C10_WARP_SIZE >> 1); offset > 0; offset >>= 1) { - WelfordDataLN wdB{WARP_SHFL_DOWN(wd.mean, offset), WARP_SHFL_DOWN(wd.sigma2, offset), WARP_SHFL_DOWN(wd.count, offset)}; - wd = cuWelfordCombine(wd, wdB); + WelfordDataLN wdB{WARP_SHFL_DOWN(wd.mean, offset), + WARP_SHFL_DOWN(wd.sigma2, offset), WARP_SHFL_DOWN(wd.count, offset)}; + wd = cuWelfordCombine(wd, wdB); } // threadIdx.x == 0 has correct values for each warp // inter-warp reductions @@ -215,7 +199,7 @@ __device__ WelfordDataLN compute_stats( WelfordDataLN wdB{meansigmabuf[2*threadIdx.y], meansigmabuf[2*threadIdx.y+1], countbuf[threadIdx.y]}; - wd = cuWelfordCombine(wd, wdB); + wd = cuWelfordCombine(wd, wdB); } __syncthreads(); } @@ -232,7 +216,7 @@ __device__ WelfordDataLN compute_stats( } -template , int> = 0> __device__ __inline__ void vectorized_layer_norm_kernel_impl( const int N, @@ -247,7 +231,7 @@ __device__ __inline__ void vectorized_layer_norm_kernel_impl( //as one thread would have to write 3 consecutive floats auto i1 = blockIdx.x; const T * block_row = X + i1 * N; - WelfordDataLN wd = compute_stats(block_row, N, s_data); + WelfordDataLN wd = compute_stats(block_row, N, s_data); using vec_t = aligned_vector; const vec_t * X_vec = reinterpret_cast(block_row); @@ -270,48 +254,34 @@ __device__ __inline__ void vectorized_layer_norm_kernel_impl( if (gamma_vec != nullptr && beta_vec != nullptr) { #pragma unroll for (int ii=0; ii < vec_size; ii++){ - if constexpr (!rms_norm){ - out.val[ii] = static_cast(gamma_vec[i].val[ii]) * (rstd_val * (static_cast(data.val[ii]) - wd.mean)) - + static_cast(beta_vec[i].val[ii]); - } else { - out.val[ii] = static_cast(gamma_vec[i].val[ii]) * (rstd_val * static_cast(data.val[ii])); - } + out.val[ii] = static_cast(gamma_vec[i].val[ii]) * (rstd_val * (static_cast(data.val[ii]) - wd.mean)) + + static_cast(beta_vec[i].val[ii]); } } else if (gamma_vec != nullptr) { #pragma unroll for (int ii=0; ii < vec_size; ii++){ - if constexpr (!rms_norm){ - out.val[ii] = static_cast(gamma_vec[i].val[ii]) * (rstd_val * (static_cast(data.val[ii]) - wd.mean)); - } else { - out.val[ii] = static_cast(gamma_vec[i].val[ii]) * (rstd_val * static_cast(data.val[ii])); - } + out.val[ii] = static_cast(gamma_vec[i].val[ii]) * (rstd_val * (static_cast(data.val[ii]) - wd.mean)); } } else if (beta_vec != nullptr) { #pragma unroll for (int ii=0; ii < vec_size; ii++){ - out.val[ii] = (rstd_val * (static_cast(data.val[ii]) - wd.mean)) + static_cast(beta_vec[i].val[ii]); + out.val[ii] = (rstd_val * (static_cast(data.val[ii]) - wd.mean)) + static_cast(beta_vec[i].val[ii]); } } else { #pragma unroll for (int ii=0; ii < vec_size; ii++){ - if constexpr (!rms_norm){ - out.val[ii] = rstd_val * (static_cast(data.val[ii]) - wd.mean); - } else { - out.val[ii] = rstd_val * static_cast(data.val[ii]); - } + out.val[ii] = rstd_val * (static_cast(data.val[ii]) - wd.mean); } } Y_vec[i] = out; } if (thrx == 0) { - if constexpr (!rms_norm){ - mean[i1] = wd.mean; - } + mean[i1] = wd.mean; rstd[i1] = rstd_val; } } -template , int> = 0> __device__ __inline__ void vectorized_layer_norm_kernel_impl( const int /*N*/, @@ -326,7 +296,7 @@ __device__ __inline__ void vectorized_layer_norm_kernel_impl( } //to avoid windows SFINAE errors -template +template __global__ void vectorized_layer_norm_kernel( const int N, T_ACC eps, @@ -336,11 +306,11 @@ __global__ void vectorized_layer_norm_kernel( T_ACC* mean, T_ACC* rstd, T* Y){ - vectorized_layer_norm_kernel_impl(N, eps, X, gamma, beta, mean, rstd, Y); + vectorized_layer_norm_kernel_impl(N, eps, X, gamma, beta, mean, rstd, Y); } -template +template __device__ __inline__ void compute_gI( const T* __restrict__ dY, const T* __restrict__ X, @@ -351,10 +321,7 @@ __device__ __inline__ void compute_gI( const int N, T_ACC * buf){ const auto i1 = blockIdx.x; - T_ACC mean_val = 0; - if constexpr (!rms_norm){ - mean_val = mean[i1]; - } + const T_ACC mean_val = mean[i1]; const T_ACC rstd_val = rstd[i1]; T_ACC stats_x1{0}, stats_x2{0}; constexpr int unroll = 4; @@ -370,39 +337,26 @@ __device__ __inline__ void compute_gI( const auto gamma_val = (gamma != nullptr) ? static_cast(gamma[l+k]) : T_ACC(1); const auto c_h = static_cast(X_i[l+k]); const auto c_loss = static_cast(dY_i[l+k]); - if constexpr (!rms_norm){ - stats_x1 += c_loss * gamma_val; - stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; - } else { - stats_x2 += c_loss * gamma_val * (c_h) * rstd_val; - } + stats_x1 += c_loss * gamma_val; + stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; } } for (; l < N; l ++) { const auto gamma_val = (gamma != nullptr) ? static_cast(gamma[l]) : T_ACC(1); const auto c_h = static_cast(X_i[l]); const auto c_loss = static_cast(dY_i[l]); - if constexpr (!rms_norm){ - stats_x1 += c_loss * gamma_val; - stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; - } else { - stats_x2 += c_loss * gamma_val * (c_h) * rstd_val; - } - } - if constexpr (!rms_norm){ - stats_x1 = cuda_utils::BlockReduceSum(stats_x1, buf); + stats_x1 += c_loss * gamma_val; + stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; } + + stats_x1 = cuda_utils::BlockReduceSum(stats_x1, buf); stats_x2 = cuda_utils::BlockReduceSum(stats_x2, buf); if (threadIdx.x == 0) { - if constexpr (!rms_norm){ - buf[0] = stats_x1; - } + buf[0] = stats_x1; buf[1] = stats_x2; } __syncthreads(); - if constexpr (!rms_norm){ - stats_x1 = buf[0]; - } + stats_x1 = buf[0]; stats_x2 = buf[1]; T_ACC fH = N; T_ACC term1 = (T_ACC(1) / fH) * rstd_val; @@ -413,20 +367,15 @@ __device__ __inline__ void compute_gI( const auto gamma_val = (gamma != nullptr) ? static_cast(gamma[l]) : T_ACC(1); T_ACC f_grad_input = fH * gamma_val * dy; - if constexpr (!rms_norm){ - f_grad_input -= (x - mean_val) * rstd_val * stats_x2; - f_grad_input -= stats_x1; - } else { - f_grad_input -= (x) * rstd_val * stats_x2; - } - + f_grad_input -= (x - mean_val) * rstd_val * stats_x2; + f_grad_input -= stats_x1; f_grad_input *= term1; dX_i[l] = f_grad_input; } } -template +template __global__ void layer_norm_grad_input_kernel( const T* __restrict__ dY, const T* __restrict__ X, @@ -438,7 +387,7 @@ __global__ void layer_norm_grad_input_kernel( alignas(sizeof(double)) extern __shared__ char s_data1[]; T_ACC * buf = reinterpret_cast(&s_data1); - compute_gI(dY, X, mean, rstd, gamma, dX, N, buf); + compute_gI(dY, X, mean, rstd, gamma, dX, N, buf); } @@ -447,7 +396,7 @@ __global__ void layer_norm_grad_input_kernel( // faster measured at PT operator level, with cases seeing a 2X speedup (where N >> M). // There are no noticeable regressions on the rest of the sizes. -template +template __global__ void layer_norm_grad_input_kernel_vectorized( const T* __restrict__ dY, const T* __restrict__ X, @@ -460,10 +409,7 @@ __global__ void layer_norm_grad_input_kernel_vectorized( T_ACC* reduce_buf = reinterpret_cast(&shared_data); const auto bIdx = blockIdx.x; - T_ACC mean_val = 0; - if constexpr (!rms_norm){ - mean_val = mean[bIdx]; - } + const T_ACC mean_val = mean[bIdx]; const T_ACC rstd_val = rstd[bIdx]; const T* X_i = X + bIdx * N; const T* dY_i = dY + bIdx * N; @@ -495,12 +441,8 @@ __global__ void layer_norm_grad_input_kernel_vectorized( const auto gamma_val = static_cast(gamma_vec_reg.val[k]); const auto c_h = static_cast(X_i_vec_reg.val[k]); const auto c_loss = static_cast(dY_i_vec_reg.val[k]); - if constexpr (!rms_norm){ - stats_x1 += c_loss * gamma_val; - stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; - } else { - stats_x2 += c_loss * gamma_val * (c_h) * rstd_val; - } + stats_x1 += c_loss * gamma_val; + stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; } } @@ -509,29 +451,19 @@ __global__ void layer_norm_grad_input_kernel_vectorized( const auto gamma_val = (gamma != nullptr) ? static_cast(gamma[l]) : T_ACC(1); const auto c_h = static_cast(X_i[l]); const auto c_loss = static_cast(dY_i[l]); - if constexpr (!rms_norm){ - stats_x1 += c_loss * gamma_val; - stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; - } else{ - stats_x2 += c_loss * gamma_val * (c_h) * rstd_val; - } + stats_x1 += c_loss * gamma_val; + stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; } // Reduction in Shared Memory - if constexpr (!rms_norm){ - stats_x1 = cuda_utils::BlockReduceSum(stats_x1, reduce_buf); - } + stats_x1 = cuda_utils::BlockReduceSum(stats_x1, reduce_buf); stats_x2 = cuda_utils::BlockReduceSum(stats_x2, reduce_buf); if (threadIdx.x == 0) { - if constexpr (!rms_norm){ - reduce_buf[0] = stats_x1; - } + reduce_buf[0] = stats_x1; reduce_buf[1] = stats_x2; } __syncthreads(); - if constexpr (!rms_norm){ - stats_x1 = reduce_buf[0]; - } + stats_x1 = reduce_buf[0]; stats_x2 = reduce_buf[1]; T_ACC fH = N; @@ -553,12 +485,8 @@ __global__ void layer_norm_grad_input_kernel_vectorized( const auto dy = static_cast(dY_i_vec_reg.val[k]); T_ACC f_grad_input = fH * gamma_val * dy; - if constexpr (!rms_norm){ - f_grad_input -= (x - mean_val) * rstd_val * stats_x2; - f_grad_input -= stats_x1; - } else { - f_grad_input -= (x) * rstd_val * stats_x2; - } + f_grad_input -= (x - mean_val) * rstd_val * stats_x2; + f_grad_input -= stats_x1; f_grad_input *= term1; dX_i_vec_reg.val[k] = f_grad_input; } @@ -573,19 +501,15 @@ __global__ void layer_norm_grad_input_kernel_vectorized( const auto gamma_val = (gamma != nullptr) ? static_cast(gamma[l]) : T_ACC(1); T_ACC f_grad_input = fH * gamma_val * dy; - if constexpr (!rms_norm){ - f_grad_input -= (x - mean_val) * rstd_val * stats_x2; - f_grad_input -= stats_x1; - } else { - f_grad_input -= (x) * rstd_val * stats_x2; - } + f_grad_input -= (x - mean_val) * rstd_val * stats_x2; + f_grad_input -= stats_x1; f_grad_input *= term1; dX_i[l] = f_grad_input; } } -template +template __global__ void GammaBetaBackwardSimpleCUDAKernel( int64_t M, int64_t N, @@ -601,25 +525,17 @@ __global__ void GammaBetaBackwardSimpleCUDAKernel( T_ACC sum2 = 0; for (int64_t i = 0; i < M; ++i) { const int64_t index = i * N + j; - if constexpr (!rms_norm){ - sum1 += dg == nullptr ? T_ACC(0) - : static_cast(dY[index]) * - (static_cast(X[index]) - static_cast(mean[i])) * - static_cast(rstd[i]); - sum2 += db == nullptr ? T_ACC(0) : static_cast(dY[index]); - } else { - sum1 += dg == nullptr ? T_ACC(0) - : static_cast(dY[index]) * - (static_cast(X[index])) * static_cast(rstd[i]); - } + sum1 += dg == nullptr ? T_ACC(0) + : static_cast(dY[index]) * + (static_cast(X[index]) - static_cast(mean[i])) * + static_cast(rstd[i]); + sum2 += db == nullptr ? T_ACC(0) : static_cast(dY[index]); } if (dg != nullptr) { dg[j] = sum1; } if (db != nullptr) { - if constexpr (!rms_norm){ - db[j] = sum2; - } + db[j] = sum2; } } } @@ -629,8 +545,7 @@ unsigned int block_dim_x, unsigned int block_dim_y, unsigned int rows_per_block_y, bool check_x, -bool check_y, -bool rms_norm> +bool check_y> __device__ __forceinline__ void @@ -654,9 +569,7 @@ blockReduceGammaBetaBackwardsHelper( int64_t mean_index = M_start + threadIdx.y * rows_per_thread_y; T_ACC warp_mean = 0, warp_rstd = 0; if (lane_id < rows_per_thread_y && mean_index + lane_id < M) { - if constexpr (!rms_norm){ - warp_mean = mean[mean_index + lane_id]; - } + warp_mean = mean[mean_index + lane_id]; warp_rstd = rstd[mean_index + lane_id]; } // We do a WARP_SYNC() here because we use WARP_SHFL below to access @@ -683,14 +596,10 @@ blockReduceGammaBetaBackwardsHelper( #pragma unroll for (int i = 0; i < rows_per_thread_y; ++i) { + T_ACC mean_reg = WARP_SHFL(warp_mean, i, kWarpSize); T_ACC rstd_reg = WARP_SHFL(warp_rstd, i, kWarpSize); - if constexpr (!rms_norm){ - T_ACC mean_reg = WARP_SHFL(warp_mean, i, kWarpSize); - dg_sum += dY_regs[i] * (X_regs[i] - mean_reg) * rstd_reg; - db_sum += dY_regs[i]; - } else{ - dg_sum += dY_regs[i] * (X_regs[i]) * rstd_reg; - } + dg_sum += dY_regs[i] * (X_regs[i] - mean_reg) * rstd_reg; + db_sum += dY_regs[i]; } } @@ -699,8 +608,7 @@ unsigned int block_dim_x, unsigned int block_dim_y, unsigned int rows_per_block_y, bool check_x, -bool check_y, -bool rms_norm> +bool check_y> __device__ __forceinline__ void @@ -721,10 +629,10 @@ blockReduceGammaBetaBackwardsWithChecks( M_start += rows_per_block_y * gridDim.y) { int64_t M_end = M_start + rows_per_block_y - 1; if (!check_y || M_end < M) { - blockReduceGammaBetaBackwardsHelper + blockReduceGammaBetaBackwardsHelper (M_start, M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); } else { - blockReduceGammaBetaBackwardsHelper + blockReduceGammaBetaBackwardsHelper (M_start, M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); } } @@ -746,8 +654,7 @@ template __global__ void @@ -772,7 +679,7 @@ __launch_bounds__(block_dim_x * block_dim_y) // When N and M align perfectly with block_dim_x and block_dim_y, we // can skip boundary condition checks that waste instruction issue slots. blockReduceGammaBetaBackwardsWithChecks - + (M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); } else { // In the general case we need to check boundary conditions in the M @@ -780,11 +687,11 @@ __launch_bounds__(block_dim_x * block_dim_y) // for the inner blocks. So try to avoid those checks when possible. if (blockIdx.x * block_dim_x + block_dim_x - 1 < N) { blockReduceGammaBetaBackwardsWithChecks - + (M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); } else { blockReduceGammaBetaBackwardsWithChecks - + (M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); } } @@ -799,7 +706,7 @@ __launch_bounds__(block_dim_x * block_dim_y) if (dg) { dg[thread_y * N + thread_x] = dg_sum; } - if (db && !rms_norm) { + if (db) { db[thread_y * N + thread_x] = db_sum; } } @@ -845,7 +752,7 @@ __launch_bounds__(block_dim_x * block_dim_y) if (dg) { dg[out_index] = reg_dg; } - if (db && !rms_norm) { + if (db) { db[out_index] = reg_db; } } @@ -856,8 +763,7 @@ __launch_bounds__(block_dim_x * block_dim_y) template +bool partial_reduction> void LaunchAndCheckGammaBetaBackwardKernel( bool aligned_grid, dim3 blocks, @@ -873,7 +779,7 @@ void LaunchAndCheckGammaBetaBackwardKernel( T* dgamma_data, T* dbeta_data) { if (aligned_grid) { - GammaBetaBackwardCUDAKernelTemplate + GammaBetaBackwardCUDAKernelTemplate <<>>( M, N, @@ -884,7 +790,7 @@ if (aligned_grid) { dgamma_data, dbeta_data); } else { - GammaBetaBackwardCUDAKernelTemplate + GammaBetaBackwardCUDAKernelTemplate <<>>( M, N, @@ -900,7 +806,7 @@ if (aligned_grid) { template +int rows_per_block_y> void ConfigureAndLaunchGammaBetaBackwardKernel( const T* dY_data, const T* X_data, @@ -923,16 +829,16 @@ void ConfigureAndLaunchGammaBetaBackwardKernel( if (blocks.y == 1 && threads.y == 1) { // Optimization: since there is just one thread doing all the summation, we don't need a reduction // across threads. So we set partial_reduction to true. - LaunchAndCheckGammaBetaBackwardKernel( + LaunchAndCheckGammaBetaBackwardKernel( aligned_grid, blocks, threads, shmem_sz, cuda_stream, dY_data, X_data, mean_data, rstd_data, M, N, dgamma_data, dbeta_data); } else { - LaunchAndCheckGammaBetaBackwardKernel( + LaunchAndCheckGammaBetaBackwardKernel( aligned_grid, blocks, threads, shmem_sz, cuda_stream, dY_data, X_data, mean_data, rstd_data, M, N, dgamma_data, dbeta_data); } } -template +template void LaunchGammaBetaBackwardCUDAKernel( const T* dY_data, const T* X_data, @@ -970,21 +876,19 @@ void LaunchGammaBetaBackwardCUDAKernel( dgamma_blocks = at::empty({blocks.y * threads.y, dgamma->size(-1)}, options); dgamma_blocks_ptr = dgamma_blocks.data_ptr(); } - if (dbeta->defined() && !rms_norm) { + if (dbeta->defined()) { auto options = dbeta->options(); dbeta_blocks = at::empty({blocks.y * threads.y, dgamma->size(-1)}, options); dbeta_blocks_ptr = dbeta_blocks.data_ptr(); } - LaunchAndCheckGammaBetaBackwardKernel( + LaunchAndCheckGammaBetaBackwardKernel( aligned_grid, blocks, threads, 0, cuda_stream, dY_data, X_data, mean_data, rstd_data, M, N, dgamma_blocks_ptr, dbeta_blocks_ptr); if (dgamma_blocks.defined()) { *dgamma = dgamma_blocks.sum(0); } - if constexpr (!rms_norm){ - if (dbeta_blocks.defined()) { - *dbeta = dbeta_blocks.sum(0); - } + if (dbeta_blocks.defined()) { + *dbeta = dbeta_blocks.sum(0); } } else { // We are in the normal case where M is not that large. @@ -992,18 +896,18 @@ void LaunchGammaBetaBackwardCUDAKernel( // For small M it is faster to have a smaller tile, otherwise we could have idle threads. // For larger M we use a bigger tile size. if (M < 64) { - ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); + ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); } else if (M < 128) { - ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); + ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); } else if (M < 256) { - ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); + ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); } else { - ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); + ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); } } } -template +template void launch_vectorized_layer_norm_kernel( int N, int64_t M, @@ -1032,7 +936,7 @@ void launch_vectorized_layer_norm_kernel( TORCH_INTERNAL_ASSERT_DEBUG_ONLY(threads.y % 2 == 0 || threads.y == 1); int nshared = threads.y > 1 ? threads.y * 3/2 *sizeof(T_ACC) : 0; - vectorized_layer_norm_kernel<<>>(N, eps, X_data, + vectorized_layer_norm_kernel<<>>(N, eps, X_data, gamma_data, beta_data, mean_data, rstd_data, Y_data); C10_CUDA_KERNEL_LAUNCH_CHECK(); @@ -1054,7 +958,7 @@ void launch_vectorized_layer_norm_kernel( blocks.x = (remaining > blocks.x) ? blocks.x : remaining; - vectorized_layer_norm_kernel<<>>(N, eps, X_data2, + vectorized_layer_norm_kernel<<>>(N, eps, X_data2, gamma_data, beta_data, mean_data2, rstd_data2, Y_data2); C10_CUDA_KERNEL_LAUNCH_CHECK(); @@ -1064,7 +968,7 @@ void launch_vectorized_layer_norm_kernel( } -template +template void LayerNormKernelImplInternal( const Tensor& X, const Tensor& gamma, @@ -1083,7 +987,7 @@ void LayerNormKernelImplInternal( const T* gamma_data = gamma.defined() ? gamma.const_data_ptr() : nullptr; const T* beta_data = beta.defined() ? beta.const_data_ptr() : nullptr; T* Y_data = Y->data_ptr(); - T_ACC* mean_data = !rms_norm ? mean->data_ptr() : nullptr; + T_ACC* mean_data = mean->data_ptr(); T_ACC* rstd_data = rstd->data_ptr(); // check if can take fast path - all tensors are properly aligned, N is less than 2^24 (to use float count), @@ -1098,14 +1002,14 @@ void LayerNormKernelImplInternal( if ((std::is_same_v || std::is_same_v || std::is_same_v) && N <= static_cast(1ULL << std::numeric_limits::digits) && N % num_vec_elems == 0 && can_vec_X && can_vec_Y && can_vec_gamma && can_vec_beta) { - launch_vectorized_layer_norm_kernel(static_cast(N), M, eps, X_data, gamma_data, beta_data, Y_data, mean_data, rstd_data); + launch_vectorized_layer_norm_kernel(static_cast(N), M, eps, X_data, gamma_data, beta_data, Y_data, mean_data, rstd_data); } else { cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream(); - RowwiseMomentsCUDAKernel + RowwiseMomentsCUDAKernel <<>>( N, eps, X_data, mean_data, rstd_data); C10_CUDA_KERNEL_LAUNCH_CHECK(); - LayerNormForwardCUDAKernel<<>>( + LayerNormForwardCUDAKernel<<>>( N, X_data, mean_data, rstd_data, gamma_data, beta_data, Y_data); C10_CUDA_KERNEL_LAUNCH_CHECK(); } @@ -1133,29 +1037,7 @@ void LayerNormKernelImpl( }); } -void RmsNormKernelImpl( - const Tensor& X, - const Tensor& gamma, - int64_t M, - int64_t N, - double eps, - Tensor* Y, - Tensor* rstd) { -AT_DISPATCH_FLOATING_TYPES_AND2( - at::ScalarType::Half, - at::ScalarType::BFloat16, - X.scalar_type(), - "LayerNormKernelImpl", - [&]() { - using acc_t = acc_type; - // rms_norm = true - LayerNormKernelImplInternal( - // pass in at::Tensor() for gamma and nullptr for mean, it won't be accessed with rms_norm = True - X, gamma, at::Tensor(), M, N, static_cast(eps), Y, nullptr, rstd); - }); -} - -template __device__ +template __device__ void cuLoadWriteStridedInputs( const int i1_block, const int thr_load_row_off, @@ -1173,10 +1055,7 @@ void cuLoadWriteStridedInputs( { int i1 = i1_block+thr_load_row_off; if (i1 < i1_end) { - T_ACC curr_mean = 0; - if constexpr (!rms_norm){ - curr_mean = mean[i1]; - } + T_ACC curr_mean = mean[i1]; T_ACC curr_rstd = rstd[i1]; for (int k = 0; k < blockDim.y; ++k) { int i2 = i2_off + k; @@ -1201,7 +1080,7 @@ void cuLoadWriteStridedInputs( } } -template __device__ +template __device__ void cuLoadAddStridedInputs( const int i1_block, const int thr_load_row_off, @@ -1219,11 +1098,7 @@ void cuLoadAddStridedInputs( { int i1 = i1_block+thr_load_row_off; if (i1 < i1_end) { - - T_ACC curr_mean = 0; - if constexpr (!rms_norm){ - curr_mean = mean[i1]; - } + T_ACC curr_mean = mean[i1]; T_ACC curr_rstd = rstd[i1]; for (int k = 0; k < blockDim.y; ++k) { int i2 = i2_off + k; @@ -1239,7 +1114,7 @@ void cuLoadAddStridedInputs( } } -template __global__ +template __global__ void cuComputePartGradGammaBeta( const T* __restrict__ dout, const T* __restrict__ input, @@ -1265,9 +1140,9 @@ void cuComputePartGradGammaBeta( T_ACC* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride; // compute partial sums from strided inputs // do this to increase number of loads in flight - cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,N,mean,rstd); + cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,N,mean,rstd); for (int i1_block = i1_beg+blockDim.y*blockDim.y; i1_block < i1_end; i1_block+=blockDim.y*blockDim.y) { - cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,N,mean,rstd); + cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,N,mean,rstd); } __syncthreads(); // inter-warp reductions @@ -1306,7 +1181,7 @@ void cuComputePartGradGammaBeta( } } -template __global__ +template __global__ void cuComputeGradGammaBeta( const T_ACC* part_grad_gamma, const T_ACC* part_grad_beta, @@ -1331,9 +1206,7 @@ void cuComputeGradGammaBeta( if (i2 < N) { for (int warp_offset = 0; warp_offset < num_warp_reductions; ++warp_offset) { sum_gamma += part_grad_gamma_ptr[warp_offset*N]; - if constexpr (!rms_norm){ - sum_beta += part_grad_beta_ptr[warp_offset*N]; - } + sum_beta += part_grad_beta_ptr[warp_offset*N]; } } @@ -1351,9 +1224,7 @@ void cuComputeGradGammaBeta( if (threadIdx.y < offset) { const int read_idx = threadIdx.y * blockDim.x + threadIdx.x; sum_gamma += buf[read_idx]; - if constexpr (!rms_norm){ - sum_beta += buf[read_idx+nbsize3]; - } + sum_beta += buf[read_idx+nbsize3]; } __syncthreads(); } @@ -1364,14 +1235,12 @@ void cuComputeGradGammaBeta( grad_gamma[i2] = sum_gamma; } if (grad_beta) { - if constexpr (!rms_norm){ - grad_beta[i2] = sum_beta; - } + grad_beta[i2] = sum_beta; } } } -template __global__ +template __global__ void cuComputeGradInput( const T* __restrict__ dout, const T* __restrict__ input, @@ -1385,10 +1254,7 @@ void cuComputeGradInput( for (int i1=blockIdx.y; i1 < M; i1 += gridDim.y) { T_ACC sum_loss1 = T_ACC(0); T_ACC sum_loss2 = T_ACC(0); - T_ACC c_mean = 0; - if constexpr (!rms_norm){ - c_mean = mean[i1]; - } + T_ACC c_mean = mean[i1]; const T_ACC c_rstd = rstd[i1]; const T* k_input = input + i1*N; const T* k_dout = dout + i1*N; @@ -1401,31 +1267,21 @@ void cuComputeGradInput( const T_ACC gamma_idx = static_cast((idx((idx((idx((idx((idx 0; mask /= 2) { - if constexpr (!rms_norm){ - sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask); - } + sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask); sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask); } // inter-warp reductions @@ -1436,33 +1292,25 @@ void cuComputeGradInput( // upper half of warps write to shared if (threadIdx.y >= offset && threadIdx.y < 2*offset) { const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x; - if constexpr (!rms_norm){ - buf[2*wrt_i] = sum_loss1; - } + buf[2*wrt_i] = sum_loss1; buf[2*wrt_i+1] = sum_loss2; } __syncthreads(); // lower half merges if (threadIdx.y < offset) { const int read_i = threadIdx.y * blockDim.x + threadIdx.x; - if constexpr (!rms_norm){ - sum_loss1 += buf[2*read_i]; - } + sum_loss1 += buf[2*read_i]; sum_loss2 += buf[2*read_i+1]; } __syncthreads(); } if (threadIdx.y == 0) { - if constexpr (!rms_norm){ - buf[2*threadIdx.x] = sum_loss1; - } + buf[2*threadIdx.x] = sum_loss1; buf[2*threadIdx.x+1] = sum_loss2; } __syncthreads(); if (threadIdx.y !=0) { - if constexpr (!rms_norm){ - sum_loss1 = buf[2*threadIdx.x]; - } + sum_loss1 = buf[2*threadIdx.x]; sum_loss2 = buf[2*threadIdx.x+1]; } } @@ -1475,12 +1323,8 @@ void cuComputeGradInput( const T_ACC c_h = static_cast(k_input[l]); const T_ACC c_loss = static_cast(k_dout[l]); T_ACC f_grad_input = fH * c_loss * gamma[l]; - if constexpr (!rms_norm){ - f_grad_input -= sum_loss1; - f_grad_input -= (c_h - c_mean) * c_rstd * sum_loss2; - } else { - f_grad_input -= (c_h) * c_rstd * sum_loss2; - } + f_grad_input -= sum_loss1; + f_grad_input -= (c_h - c_mean) * c_rstd * sum_loss2; f_grad_input *= term1; k_grad_input[l] = static_cast(f_grad_input); } @@ -1489,12 +1333,8 @@ void cuComputeGradInput( const T_ACC c_h = static_cast(k_input[l]); const T_ACC c_loss = static_cast(k_dout[l]); T_ACC f_grad_input = fH * c_loss; - if constexpr (!rms_norm){ - f_grad_input -= sum_loss1; - f_grad_input -= (c_h - c_mean) * c_rstd * sum_loss2; - } else { - f_grad_input -= (c_h) * c_rstd * sum_loss2; - } + f_grad_input -= sum_loss1; + f_grad_input -= (c_h - c_mean) * c_rstd * sum_loss2; f_grad_input *= term1; k_grad_input[l] = static_cast(f_grad_input); } @@ -1504,7 +1344,7 @@ void cuComputeGradInput( } } -template +template void LayerNormBackwardKernelImplInternal( const Tensor& dY, const Tensor& X, @@ -1518,9 +1358,7 @@ void LayerNormBackwardKernelImplInternal( Tensor* dbeta) { using T_ACC = acc_type; TORCH_CHECK(dY.numel() == M * N); - if constexpr (!rms_norm){ - TORCH_CHECK(mean.numel() == M); - } + TORCH_CHECK(mean.numel() == M); TORCH_CHECK(rstd.numel() == M); TORCH_CHECK(M <= at::cuda::getCurrentDeviceProperties()->maxGridSize[0], "M should be less than maximum CUDA grid size, \ file a support request to support bigger batches"); @@ -1546,7 +1384,7 @@ void LayerNormBackwardKernelImplInternal( threads1.y > 1 ? threads1.y*threads1.x*sizeof(T_ACC) : 0; - cuComputeGradInput<<>>( + cuComputeGradInput<<>>( dY_data, X_data, M, N, @@ -1558,7 +1396,7 @@ void LayerNormBackwardKernelImplInternal( } else { const dim3 blocks(M); int nshared = (num_threads()/warp_size) * sizeof(T_ACC); - layer_norm_grad_input_kernel<<>>(dY_data, + layer_norm_grad_input_kernel<<>>(dY_data, X_data, mean_data, rstd_data, gamma_data, dX_data, N); C10_CUDA_KERNEL_LAUNCH_CHECK(); } @@ -1572,12 +1410,13 @@ void LayerNormBackwardKernelImplInternal( const unsigned int alignment = sizeof(T) * vec_size; bool bAlignedBuffers = can_vectorize(dY_data, alignment) && can_vectorize(X_data, alignment) && can_vectorize(gamma_data, alignment) && can_vectorize(dX_data, alignment); + if (bAlignedBuffers && bTargetDataTypes && bVectorSizeMultiple) { - layer_norm_grad_input_kernel_vectorized<<>>(dY_data, + layer_norm_grad_input_kernel_vectorized<<>>(dY_data, X_data, mean_data, rstd_data, gamma_data, dX_data, N); C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { - layer_norm_grad_input_kernel<<>>(dY_data, + layer_norm_grad_input_kernel<<>>(dY_data, X_data, mean_data, rstd_data, gamma_data, dX_data, N); C10_CUDA_KERNEL_LAUNCH_CHECK(); } @@ -1593,7 +1432,7 @@ void LayerNormBackwardKernelImplInternal( if (M < 128) { // For small batch size, do colwise reduce directly. const int64_t B = (N + kCUDANumThreads - 1) / kCUDANumThreads; - GammaBetaBackwardSimpleCUDAKernel + GammaBetaBackwardSimpleCUDAKernel <<>>( M, N, @@ -1617,7 +1456,7 @@ void LayerNormBackwardKernelImplInternal( Tensor part_grad_gamma = at::empty({part_size,N}, gamma.options().dtype(part_grad_dtype)); Tensor part_grad_beta = at::native::empty_like(part_grad_gamma); - cuComputePartGradGammaBeta<<>>( + cuComputePartGradGammaBeta<<>>( dY_data, X_data, M,N, @@ -1631,7 +1470,7 @@ void LayerNormBackwardKernelImplInternal( const dim3 blocks3((N + threads3.x - 1) / threads3.x, 1, 1); const int nshared3 = threads3.x * threads3.y * sizeof(T_ACC); - cuComputeGradGammaBeta<<>>( + cuComputeGradGammaBeta<<>>( part_grad_gamma.template data_ptr(), part_grad_beta.template data_ptr(), part_size, @@ -1641,7 +1480,7 @@ void LayerNormBackwardKernelImplInternal( C10_CUDA_KERNEL_LAUNCH_CHECK(); } #else - LaunchGammaBetaBackwardCUDAKernel( + LaunchGammaBetaBackwardCUDAKernel( dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); #endif } @@ -1669,29 +1508,8 @@ void LayerNormBackwardKernelImpl( }); } -void RMSNormBackwardKernelImpl( - const Tensor& dY, - const Tensor& X, - const Tensor& rstd, - const Tensor& gamma, - int64_t M, - int64_t N, - Tensor* dX, - Tensor* dgamma) { - AT_DISPATCH_FLOATING_TYPES_AND2( - at::ScalarType::Half, - at::ScalarType::BFloat16, - X.scalar_type(), - "LayerNormBackwardKernelImpl", - [&]() { - LayerNormBackwardKernelImplInternal( - dY.contiguous(), X, rstd, rstd, gamma, M, N, dX, dgamma, dgamma); - }); -} - } // namespace - std::tuple layer_norm_cuda( const Tensor& input, IntArrayRef normalized_shape, @@ -1820,108 +1638,6 @@ std::tuple layer_norm_backward_cuda( return std::make_tuple(std::move(dX), std::move(dgamma), std::move(dbeta)); } -/* RMSNorm is implemented by reusing layer_norm's kernels */ -std::tuple _fused_rms_norm_cuda( - const Tensor& input, - IntArrayRef normalized_shape, - const std::optional& weight_opt /* optional */, - std::optional eps){ - - c10::MaybeOwned weight_maybe_owned = - at::borrow_from_optional_tensor(weight_opt); - const Tensor& weight = *weight_maybe_owned; - auto M_N = _check_layer_norm_inputs(input, normalized_shape, weight, weight); - auto M = M_N.first; - auto N = M_N.second; - auto X = input.expect_contiguous(); - auto gamma = weight.expect_contiguous(); - - double eps_val = eps.value_or(std::numeric_limits::epsilon()); - - Tensor Y = at::native::empty_like( - *X, - std::nullopt /* dtype */, - std::nullopt /* layout */, - std::nullopt /* device */, - std::nullopt /* pin_memory */, - LEGACY_CONTIGUOUS_MEMORY_FORMAT); - auto acc_type = at::toAccumulateType(input.scalar_type(), /*is_cuda=*/true); - Tensor rstd = at::empty({M}, X->options().dtype(acc_type)); - - if (M > 0) { - RmsNormKernelImpl(*X, *gamma, M, N, eps_val, &Y, &rstd); - } - - const auto input_shape = input.sizes(); - const size_t axis = input.dim() - normalized_shape.size(); - - std::vector stat_shape; - for (const auto idx: c10::irange(axis)) { - stat_shape.push_back(input_shape[idx]); - } - for ([[maybe_unused]] const auto idx : c10::irange(axis, input.dim())) { - stat_shape.push_back(1); - } - - rstd = rstd.view(stat_shape); - - return std::make_tuple(std::move(Y), std::move(rstd)); -} - - -std::tuple _fused_rms_norm_backward_cuda( - const Tensor& dY, - const Tensor& input, - IntArrayRef normalized_shape, - const Tensor& rstd, - const std::optional& weight_opt /* optional */, - std::array grad_input_mask) { - - c10::MaybeOwned weight_maybe_owned = - at::borrow_from_optional_tensor(weight_opt); - const Tensor& weight = *weight_maybe_owned; - - auto M_N = _check_layer_norm_inputs(input, normalized_shape, weight, weight); - auto M = M_N.first; - auto N = M_N.second; - auto X = input.expect_contiguous(); - auto gamma = weight.expect_contiguous(); - - Tensor dX; - Tensor dgamma; - if (grad_input_mask[0]) { - dX = at::native::empty_like( - *X, - std::nullopt /* dtype */, - std::nullopt /* layout */, - std::nullopt /* device */, - std::nullopt /* pin_memory */, - LEGACY_CONTIGUOUS_MEMORY_FORMAT); - } - if (grad_input_mask[1]) { - dgamma = M > 0 ? at::native::empty_like( - *gamma, - std::nullopt /* dtype */, - std::nullopt /* layout */, - std::nullopt /* device */, - std::nullopt /* pin_memory */, - LEGACY_CONTIGUOUS_MEMORY_FORMAT) - : at::native::zeros_like( - *gamma, - std::nullopt /* dtype */, - std::nullopt /* layout */, - std::nullopt /* device */, - std::nullopt /* pin_memory */, - LEGACY_CONTIGUOUS_MEMORY_FORMAT); - } - - if (M > 0 && N > 0) { - RMSNormBackwardKernelImpl( - dY, *X, rstd, *gamma, M, N, &dX, &dgamma); - } - return std::make_tuple(std::move(dX), std::move(dgamma)); -} - REGISTER_DISPATCH(LayerNormKernel, &LayerNormKernelImpl) REGISTER_DISPATCH(LayerNormBackwardKernel, &LayerNormBackwardKernelImpl) diff --git a/aten/src/ATen/native/layer_norm.cpp b/aten/src/ATen/native/layer_norm.cpp index 207f092a676a7..da6bb5fec39e8 100644 --- a/aten/src/ATen/native/layer_norm.cpp +++ b/aten/src/ATen/native/layer_norm.cpp @@ -261,11 +261,30 @@ std::tuple math_native_layer_norm( return outputs; } -std::tuple rms_norm_composite( +Tensor rms_norm_symint( const Tensor& input, - IntArrayRef normalized_shape, + c10::SymIntArrayRef normalized_shape, const std::optional& weight_opt /* optional */, std::optional eps) { + // See [Note: hacky wrapper removal for optional tensor] + c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); + const Tensor& weight = *weight_maybe_owned; + _check_rms_norm_inputs_symint(input, normalized_shape, weight); + +#ifdef USE_MPS + if (input.device().type() == DeviceType::MPS && weight_opt.has_value()) { + const Tensor weight = weight_opt.value(); + const bool any_nested = input.is_nested() || weight.is_nested(); + const bool any_inputs_require_grad = input.requires_grad() || weight.requires_grad(); + const bool is_input_fp = isFloatingType(input.scalar_type()); + const bool is_weight_fp = isFloatingType(weight.scalar_type()); + + if (!(GradMode::is_enabled() && any_inputs_require_grad) && !any_nested && is_input_fp && is_weight_fp) { + auto eps_val = eps.value_or(std::numeric_limits::epsilon()); + return at::_fused_rms_norm(input.contiguous(), normalized_shape.size(), weight.contiguous(), eps_val); + } + } +#endif std::vector dims_to_reduce; for (const auto i : c10::irange(normalized_shape.size())) { @@ -302,60 +321,10 @@ std::tuple rms_norm_composite( upcasted_result = upcasted_result.mul(weight_opt.value()); } - // if nested do not make contiguous - if(input.is_nested() || (weight_opt.has_value() && weight_opt.value().is_nested())){ - return std::make_tuple(upcasted_result, rqrst_input); - } - - if(input.suggest_memory_format() == c10::MemoryFormat::ChannelsLast || input.suggest_memory_format() == c10::MemoryFormat::ChannelsLast3d){ - return std::make_tuple(upcasted_result, rqrst_input); - } - - return std::make_tuple(upcasted_result.contiguous(), rqrst_input.contiguous()); + return upcasted_result; }); - return std::make_tuple( - std::get<0>(result).type_as(input), // Cast normalized result to original input type - std::get<1>(result) // rsqrt_val - ); -} - -Tensor rms_norm_symint( - const Tensor& input, - c10::SymIntArrayRef normalized_shape, - const std::optional& weight_opt /* optional */, - const std::optional eps) { - - c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); - const Tensor& weight = *weight_maybe_owned; - _check_rms_norm_inputs_symint(input, normalized_shape, weight); - - // composite fallback for channels last - if(input.suggest_memory_format() == c10::MemoryFormat::ChannelsLast || input.suggest_memory_format() == c10::MemoryFormat::ChannelsLast3d){ - return std::get<0>(rms_norm_composite(input, IntArrayRef(reinterpret_cast(normalized_shape.data()), normalized_shape.size()), weight_opt, eps)); - } - - // composite fallback for complex datatypes - if(input.is_complex()){ - return std::get<0>(rms_norm_composite(input, IntArrayRef(reinterpret_cast(normalized_shape.data()), normalized_shape.size()), weight_opt, eps)); - } - - #ifdef USE_MPS - if (input.device().type() == DeviceType::MPS && weight_opt.has_value()) { - const Tensor weight = weight_opt.value(); - const bool any_inputs_require_grad = input.requires_grad() || weight.requires_grad(); + return result.type_as(input); - if (!(GradMode::is_enabled() && any_inputs_require_grad)) { - return std::get<0>(at::_fused_rms_norm(input.contiguous(), IntArrayRef(reinterpret_cast(normalized_shape.data()), normalized_shape.size()), weight_opt, eps)); - } - } - - if (input.device().type() == DeviceType::MPS){ - return std::get<0>(rms_norm_composite(input, IntArrayRef(reinterpret_cast(normalized_shape.data()), normalized_shape.size()), weight_opt, eps)); - } - #endif - - return std::get<0>(at::_fused_rms_norm(input, IntArrayRef(reinterpret_cast(normalized_shape.data()), normalized_shape.size()), weight_opt, eps)); } - } // namespace at::native diff --git a/aten/src/ATen/native/layer_norm.h b/aten/src/ATen/native/layer_norm.h index 0debe942dd0a6..0181f35fd6ed4 100644 --- a/aten/src/ATen/native/layer_norm.h +++ b/aten/src/ATen/native/layer_norm.h @@ -106,12 +106,6 @@ void layer_norm_cpu_out( int64_t M, int64_t N); -std::tuple rms_norm_composite( - const Tensor& input, - IntArrayRef normalized_shape, - const std::optional& weight_opt /* optional */, - std::optional eps); - Tensor rms_norm_symint( const Tensor& input, c10::SymIntArrayRef normalized_shape, diff --git a/aten/src/ATen/native/mps/operations/RMSNorm.mm b/aten/src/ATen/native/mps/operations/RMSNorm.mm index 7948b5acd8e93..71128297d5bfc 100644 --- a/aten/src/ATen/native/mps/operations/RMSNorm.mm +++ b/aten/src/ATen/native/mps/operations/RMSNorm.mm @@ -19,14 +19,7 @@ #include #endif -std::tuple _fused_rms_norm_mps(const Tensor& input, - IntArrayRef normalized_shape, - const std::optional& weight_opt, - const std::optional eps) { - const Tensor weight = weight_opt.value().contiguous(); - const int64_t normalized_ndim = normalized_shape.size(); - auto eps_val = eps.value_or(std::numeric_limits::epsilon()); - +Tensor _fused_rms_norm_mps(const Tensor& input, const int64_t normalized_ndim, const Tensor& weight, const double eps) { TORCH_CHECK(input.is_contiguous() && weight.is_contiguous(), "Expected contiguous input and weight tensors"); auto output = at::empty_like(input); const auto input_shape = input.sizes(); @@ -48,7 +41,7 @@ const std::string kernel = fmt::format("{}_{}", name, scalarToMetalTypeString(output)); id rms_norm_pso = lib.getPipelineStateForFunc(kernel); [computeEncoder setComputePipelineState:rms_norm_pso]; - mtl_setArgs(computeEncoder, input, weight, output, eps_val, N, 1); + mtl_setArgs(computeEncoder, input, weight, output, eps, N, 1); const auto maxThreadsPerGroup = static_cast([rms_norm_pso maxTotalThreadsPerThreadgroup]); size_t threadgroup_size = maxThreadsPerGroup; @@ -65,7 +58,7 @@ } }); - return std::make_tuple(output, Tensor()); + return output; } } // namespace at::native diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index ce13e03fb9f6c..79b7e07e2284b 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -3314,15 +3314,9 @@ dispatch: CompositeImplicitAutograd: rms_norm_symint -- func: _fused_rms_norm(Tensor input, int[] normalized_shape, Tensor? weight, float? eps) -> (Tensor, Tensor) +- func: _fused_rms_norm(Tensor input, int normalized_shape_ndim, Tensor weight, float eps) -> Tensor dispatch: - CUDA: _fused_rms_norm_cuda MPS: _fused_rms_norm_mps - CompositeImplicitAutograd: rms_norm_composite - -- func: _fused_rms_norm_backward(Tensor grad_out, Tensor input, int[] normalized_shape, Tensor rstd, Tensor? weight, bool[2] output_mask) -> (Tensor, Tensor) - dispatch: - CUDA: _fused_rms_norm_backward_cuda - func: nan_to_num(Tensor self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor variants: function, method diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index a590713ad0f83..042959c22cd4a 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -374,6 +374,7 @@ aten::_fused_adamw_.tensor_lr aten::_fused_moving_avg_obs_fq_helper aten::_fused_moving_avg_obs_fq_helper.out aten::_fused_moving_avg_obs_fq_helper_functional +aten::_fused_rms_norm aten::_fused_sdp_choice aten::_fused_sgd aten::_fused_sgd.out diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py index 5a962dfa57c05..d6cf2df4343ff 100644 --- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py +++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py @@ -139,8 +139,6 @@ # These ops are defined in torch/csrc/distributed/c10d/Ops.cpp # TODO: add back restriction when c10d ops can be exported ("c10d::.*", datetime.date(9999, 1, 1)), - # Previously MPS_only did not support backward - ("aten::_fused_rms_norm", datetime.date(2025, 12, 30)), ] ALLOW_LIST_COMPILED = [ diff --git a/test/test_decomp.py b/test/test_decomp.py index dcd6e69af997c..5d641e32e422e 100644 --- a/test/test_decomp.py +++ b/test/test_decomp.py @@ -15,7 +15,7 @@ from torch._export.utils import _is_cia_op from torch._ops import DispatchKey from torch.testing import make_tensor -from torch.testing._internal.common_cuda import SM70OrLater, tf32_off +from torch.testing._internal.common_cuda import tf32_off from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, onlyCPU, @@ -1226,33 +1226,6 @@ def f(x, w, b): for o_ref, o in zip(out_ref, out): self.assertEqual(o_ref.dtype, o.dtype) - @onlyCUDA - @unittest.skipIf(not SM70OrLater, "triton") - def test_rms_norm_decomp_cuda(self, device): - @torch.compile - def rms_norm_sinh(a, b, c): - output = torch.nn.functional.rms_norm(a, b, c) - return torch.sinh(output) - - normalized_shape_arg = (3, 3, 3) - input_tensor = torch.randn(3, 3, 3, device=device, requires_grad=True) - weight_tensor = torch.randn(3, 3, 3, device=device, requires_grad=True) - - def forward_pass_fn(): - return rms_norm_sinh(input_tensor, normalized_shape_arg, weight_tensor) - - model_output, generated_codes = torch._inductor.utils.run_fw_bw_and_get_code( - forward_pass_fn - ) - - # check RMSNorm was fused with sinh - self.assertTrue( - "triton_per_fused_add_mean_mul_pow_rsqrt_sinh" in generated_codes[0] - ) - self.assertTrue( - "triton_per_fused__fused_rms_norm_backward_cosh_mul" in generated_codes[1] - ) - instantiate_device_type_tests(DecompOneOffTests, globals()) diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index f0349c2484b61..e2419aab268b1 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -1267,11 +1267,6 @@ mean: not_implemented("native_layer_norm_backward mean") rstd: not_implemented("native_layer_norm_backward rstd") -- name: _fused_rms_norm(Tensor input, int[] normalized_shape, Tensor? weight, float? eps) -> (Tensor, Tensor) - input, weight: "GradMode::is_enabled() || grads[1].defined() ? infinitely_differentiable_native_rms_norm_backward(grads[0], grads[1], input, normalized_shape, result1, weight, grad_input_mask) : (grads[0].defined() ? _fused_rms_norm_backward(grads[0], input, normalized_shape, result1, weight, grad_input_mask) : std::tuple())" - result0: rms_norm_jvp(input_p, input_t, weight_p, weight_t, result1, normalized_shape) - result1: rms_norm_rstd_jvp(input_p, input_t, result1, normalized_shape) - - name: native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor) input, weight, bias: "GradMode::is_enabled() || grads[1].defined() || grads[2].defined() ? infinitely_differentiable_native_group_norm_backward(grads[0], grads[1], grads[2], input, result1, result2, weight, N, C, HxW, group, eps, grad_input_mask) : (grads[0].defined() ? native_group_norm_backward_symint(grads[0].device().is_xpu() ? grads[0] : grads[0].contiguous(grads[0].device().is_cpu() ? input.suggest_memory_format() : c10::MemoryFormat::Contiguous), input.device().is_xpu() ? input : input.contiguous(input.device().is_cpu() ? input.suggest_memory_format() : c10::MemoryFormat::Contiguous), result1, result2, weight, N, C, HxW, group, grad_input_mask) : std::tuple())" result0: group_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, result1, result2, group) diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py index 8e9796d2f7c1b..abb94b109cc0c 100644 --- a/torch/_decomp/__init__.py +++ b/torch/_decomp/__init__.py @@ -418,7 +418,6 @@ def _core_aten_decompositions_post_autograd() -> dict[ aten.native_dropout_backward, aten.native_group_norm_backward, aten.native_layer_norm_backward, - aten._fused_rms_norm_backward, aten.new_empty, aten.new_full, aten.new_ones, diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index 832928ebf8aee..f93a0bf84fb4b 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -1743,81 +1743,6 @@ def native_layer_norm_backward_out( return grad_input -@register_decomposition(aten._fused_rms_norm_backward.default) -def _fused_rms_norm_backward( - grad_out: Tensor, - input: Tensor, - normalized_shape: list[int], - rstd: Tensor, - weight: Optional[Tensor], - output_mask: list[bool], -) -> tuple[Optional[Tensor], Optional[Tensor]]: - input_shape = input.shape - input_ndim = input.dim() - computation_dtype = utils.get_computation_dtype(input.dtype) - - grad_out_cast = grad_out.to( - computation_dtype, memory_format=torch.contiguous_format - ) - input_cast = input.to(computation_dtype, memory_format=torch.contiguous_format) - weight_cast = ( - weight.to(computation_dtype, memory_format=torch.contiguous_format) - if weight is not None - else None - ) - assert grad_out_cast is not None - - axis = input_ndim - len(normalized_shape) - inner_dims = input_shape[axis:] - outer_dims = input_shape[:axis] - inner_dim_indices: list[int] = [] - outer_dim_indices: list[int] = [] - for i in range(input_ndim): - if i >= axis: - inner_dim_indices.append(i) - else: - outer_dim_indices.append(i) - - N = prod(inner_dims) # type: ignore[arg-type] - M = prod(outer_dims) # type: ignore[arg-type] - from torch.fx.experimental.symbolic_shapes import guard_size_oblivious - - if guard_size_oblivious(M <= 0) or guard_size_oblivious(N <= 0): - return ( - input.new_zeros(input_shape) if output_mask[0] else None, - input.new_zeros(input_shape[axis:]) if output_mask[1] else None, - ) - - rstd = _unsqueeze_to_dim(rstd, input_cast.dim()) # type: ignore[union-attr] - if weight_cast is not None: - grad_x_hat = grad_out_cast * weight_cast - else: - grad_x_hat = grad_out_cast - - d_input: Optional[Tensor] = None - d_weight: Optional[Tensor] = None - - x_hat = input_cast * rstd - - if output_mask[0]: - sum_val = torch.sum(x_hat * grad_x_hat, dim=inner_dim_indices, keepdim=True) - d_input = (grad_x_hat - (x_hat / N) * sum_val) * rstd - - if output_mask[1] and weight_cast is not None: - d_weight_full_shape = grad_out_cast * x_hat - if len(outer_dim_indices) > 0: - d_weight = torch.sum( - d_weight_full_shape, dim=outer_dim_indices, keepdim=False - ) - else: - d_weight = d_weight_full_shape - - return ( - _maybe_cast(d_input, input.dtype), - _maybe_cast(d_weight, input.dtype), - ) - - def native_batch_norm_helper( input: Tensor, weight: Optional[Tensor], diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index 8e13d4267edb5..908a980cfee9c 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -5023,103 +5023,6 @@ std::tuple layer_norm_double_backward( return std::tuple{gI, gG, ggO}; } -std::tuple infinitely_differentiable_native_rms_norm_backward( - const Tensor& dY, - const Tensor& drstd, - const Tensor& input, - IntArrayRef normalized_shape, - const Tensor& rstd, - const std::optional& weight_opt, - std::array grad_input_mask) { - c10::MaybeOwned weight_maybe_owned = - at::borrow_from_optional_tensor(weight_opt); - const Tensor& weight = *weight_maybe_owned; - - const auto input_shape = input.sizes(); - const auto input_ndim = input.dim(); - const int normalized_ndim = normalized_shape.size(); - const int axis = input_ndim - normalized_ndim; - - int64_t N_rms = 1; - for (int i = 0; i < normalized_ndim; ++i) { - N_rms *= input_shape[axis + i]; - } - - Tensor dX; - Tensor dgamma; - - std::vector rstd_view_shape = rstd.sizes().vec(); - for (int i = 0; - i < std::max(static_cast(normalized_ndim - rstd.dim()), 0); - ++i) { - rstd_view_shape.push_back(1); - } - Tensor rstd_broadcast = rstd.view(rstd_view_shape); - Tensor rstd_pow3 = rstd_broadcast.pow(3); - Tensor grad_x_hat; - - if (dY.defined()) { - if (weight.defined()) { - grad_x_hat = dY * weight; - } else { - grad_x_hat = dY; - } - } - - if (grad_input_mask[0]) { - Tensor dX_from_dY_path; - Tensor dX_from_drstd_path; - - std::vector inner_sum_dims; - inner_sum_dims.reserve(normalized_ndim); - for (int i = 0; i < normalized_ndim; ++i) { - inner_sum_dims.push_back(axis + i); - } - - if (dY.defined() && grad_x_hat.defined()) { - Tensor sum_input_times_grad_x_hat = - sum(input * grad_x_hat, inner_sum_dims, /*keepdim=*/true); - dX_from_dY_path = rstd_broadcast * grad_x_hat - - (input * rstd_pow3 / static_cast(N_rms)) * - sum_input_times_grad_x_hat; - } - - if (drstd.defined()) { - Tensor drstd_broadcast = drstd.view(rstd_view_shape); - dX_from_drstd_path = - -(input * rstd_pow3 / static_cast(N_rms)) * drstd_broadcast; - } - - if (dX_from_dY_path.defined() && dX_from_drstd_path.defined()) { - dX = dX_from_dY_path + dX_from_drstd_path; - } else if (dX_from_dY_path.defined()) { - dX = dX_from_dY_path; - } else if (dX_from_drstd_path.defined()) { - dX = dX_from_drstd_path; - } - } - - if (grad_input_mask[1] && weight.defined()) { - if (dY.defined()) { - Tensor x_hat = input * rstd_broadcast; - Tensor dgamma_full_shape = dY * x_hat; - - if (axis > 0) { - std::vector outer_sum_dims; - outer_sum_dims.reserve(axis); - for (int i = 0; i < axis; ++i) { - outer_sum_dims.push_back(i); - } - dgamma = sum(dgamma_full_shape, outer_sum_dims, /*keepdim=*/false); - } else { - dgamma = dgamma_full_shape; - } - } - } - - return std::make_tuple(dX, dgamma); -} - std::tuple infinitely_differentiable_native_group_norm_backward( const Tensor& dY, @@ -6474,98 +6377,6 @@ Tensor layer_norm_jvp( bias_t.defined() ? bias_t.view(view_size_affine) : bias_t); } -Tensor rms_norm_jvp( - const Tensor& input_p, - const Tensor& input_t, - const Tensor& weight_p, - const Tensor& weight_t, - const Tensor& saved_rstd, - IntArrayRef normalized_shape) { - auto dims = std::vector{}; - auto view_size = input_t.sizes().vec(); - auto view_size_affine = input_t.sizes().vec(); - - int64_t numel = 1; - for (const auto i : c10::irange(view_size.size())) { - if (i < view_size.size() - normalized_shape.size()) { - view_size_affine[i] = 1; - } else { - numel *= input_t.size(static_cast(i)); - view_size[i] = 1; - dims.push_back(static_cast(i)); - } - } - - auto rstd_p = saved_rstd.view(view_size); - - Tensor rstd_t; - if (areAnyTensorSubclassLike({input_t, input_p, rstd_p}) || - input_t._is_zerotensor()) { - rstd_t = -rstd_p.pow(3) * (input_t) * (input_p); - } else { - rstd_t = input_t * input_p; - rstd_t *= -rstd_p.pow(3); - } - rstd_t = rstd_t.sum(dims, true); - rstd_t /= numel; - - Tensor result_t; - if (areAnyTensorSubclassLike({input_t, input_p, rstd_p}) || - input_t._is_zerotensor()) { - result_t = (input_t)*rstd_p + (input_p)*rstd_t; - } else { - result_t = input_t * rstd_p; - auto temp = input_p * rstd_t; - result_t += temp; - } - - std::optional result_p = std::nullopt; - if (weight_p.defined()) { - result_p = std::optional(input_p * rstd_p); - } - - return _affine_jvp( - result_p, - result_t, - weight_p.defined() ? weight_p.view(view_size_affine) : weight_p, - weight_t.defined() ? weight_t.view(view_size_affine) : weight_t, - Tensor()); -} - -Tensor rms_norm_rstd_jvp( - const Tensor& input_p, - const Tensor& input_t, - const Tensor& saved_rstd, - IntArrayRef normalized_shape) { - auto dims = std::vector{}; - auto view_size = input_t.sizes().vec(); - auto view_size_affine = input_t.sizes().vec(); - - int64_t numel = 1; - for (const auto i : c10::irange(view_size.size())) { - if (i < view_size.size() - normalized_shape.size()) { - view_size_affine[i] = 1; - } else { - numel *= input_t.size(static_cast(i)); - view_size[i] = 1; - dims.push_back(static_cast(i)); - } - } - - auto rstd_p = saved_rstd.view(view_size); - Tensor rstd_t; - if (areAnyTensorSubclassLike({input_t, input_p, rstd_p}) || - input_t._is_zerotensor()) { - rstd_t = -rstd_p.pow(3) * (input_t) * (input_p); - } else { - rstd_t = input_t * input_p; - rstd_t *= -rstd_p.pow(3); - } - rstd_t = rstd_t.sum(dims, true); - rstd_t /= numel; - return rstd_t; -} - Tensor group_norm_jvp( const Tensor& input_p, const Tensor& input_t, diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h index 96864e165a95a..0b659973ec345 100644 --- a/torch/csrc/autograd/FunctionsManual.h +++ b/torch/csrc/autograd/FunctionsManual.h @@ -826,15 +826,6 @@ std::tuple layer_norm_double_backward( c10::SymIntArrayRef normalized_shape, std::array output_mask); -std::tuple infinitely_differentiable_native_rms_norm_backward( - const Tensor& dY, - const Tensor& drstd, - const Tensor& input, - IntArrayRef normalized_shape, - const Tensor& rstd, - const std::optional& weight_opt, - std::array grad_input_mask); - std::tuple householder_product_backward( const Tensor& grad, const Tensor& result, @@ -974,20 +965,6 @@ Tensor layer_norm_jvp( const Tensor& saved_invstd, c10::SymIntArrayRef normalized_shape); -Tensor rms_norm_jvp( - const Tensor& input_p, - const Tensor& input_t, - const Tensor& weight_p, - const Tensor& weight_t, - const Tensor& saved_rstd, - IntArrayRef normalized_shape); - -Tensor rms_norm_rstd_jvp( - const Tensor& input_p, - const Tensor& input_t, - const Tensor& saved_rstd, - IntArrayRef normalized_shape); - Tensor group_norm_jvp( const Tensor& input_p, const Tensor& input_t, diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h index aced2b2f539de..2aa09cb802ecd 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h @@ -29,7 +29,6 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__fft_c2c(AtenTensorHandle self, AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__fft_r2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t onesided, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__fused_moving_avg_obs_fq_helper(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5); -AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__fused_rms_norm(AtenTensorHandle input, const int64_t* normalized_shape, int64_t normalized_shape_len_, AtenTensorHandle* weight, double* eps, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__histogramdd_from_bin_cts(AtenTensorHandle self, const int64_t* bins, int64_t bins_len_, const double** range, int64_t range_len_, AtenTensorHandle* weight, int32_t density, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__int_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__pdist_backward(AtenTensorHandle grad, AtenTensorHandle self, double p, AtenTensorHandle pdist, AtenTensorHandle* ret0); diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h index 92d30ded855f8..e0607f984b3d0 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h @@ -32,7 +32,6 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__flash_attention_backward(AtenT AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__flash_attention_forward(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* cum_seq_q, AtenTensorHandle* cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, int64_t* window_size_left, int64_t* window_size_right, AtenTensorHandle* seqused_k, AtenTensorHandle* alibi_slopes, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__fused_moving_avg_obs_fq_helper(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5); -AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__fused_rms_norm(AtenTensorHandle input, const int64_t* normalized_shape, int64_t normalized_shape_len_, AtenTensorHandle* weight, double* eps, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__int_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__pdist_backward(AtenTensorHandle grad, AtenTensorHandle self, double p, AtenTensorHandle pdist, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__pdist_forward(AtenTensorHandle self, double p, AtenTensorHandle* ret0); diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_mps.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_mps.h index c76ee685c25da..a5d654c518840 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_mps.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_mps.h @@ -18,7 +18,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__efficientzerotensor(const int64 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fft_c2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t forward, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fft_r2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t onesided, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5); -AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fused_rms_norm(AtenTensorHandle input, const int64_t* normalized_shape, int64_t normalized_shape_len_, AtenTensorHandle* weight, double* eps, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fused_rms_norm(AtenTensorHandle input, int64_t normalized_shape_ndim, AtenTensorHandle weight, double eps, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__histogramdd_from_bin_cts(AtenTensorHandle self, const int64_t* bins, int64_t bins_len_, const double** range, int64_t range_len_, AtenTensorHandle* weight, int32_t density, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_mask, double dropout_p, int32_t is_causal, AtenTensorHandle* dropout_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8); diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h index 6fc51bd0c8f8d..243bfb5fc87aa 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h @@ -13,7 +13,6 @@ extern "C" { AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__addmm_activation(AtenTensorHandle self, AtenTensorHandle mat1, AtenTensorHandle mat2, double beta, double alpha, int32_t use_gelu, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5); -AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__fused_rms_norm(AtenTensorHandle input, const int64_t* normalized_shape, int64_t normalized_shape_len_, AtenTensorHandle* weight, double* eps, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0); diff --git a/torch/overrides.py b/torch/overrides.py index 046171ef6c5c6..cb67931fab691 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -820,7 +820,6 @@ def get_testing_overrides() -> dict[Callable, Callable]: torch._native_batch_norm_legit: lambda input, weight, bias, training, momentum, eps: -1, torch.native_dropout: lambda input, p, train: -1, torch.native_layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05: -1, - torch._fused_rms_norm: lambda input, normalized_shape, weight=None, eps=1e-05: -1, torch.native_group_norm: lambda input, weight, bias, N, C, HxW, group, eps: -1, torch.native_norm: lambda input, p=2, dim=None, keepdim=False, dtype=None: -1, torch.native_channel_shuffle: lambda input, groups: -1, From 7205458b85aca1377f73bdbacd8a9cd83b2eebbc Mon Sep 17 00:00:00 2001 From: FFFrog Date: Mon, 21 Jul 2025 09:52:44 +0800 Subject: [PATCH 321/457] [Easy] Show some clear error when torch.ops.load_library fails. (#157524) **Background**: ```Shell torch 2.5.1+cpu torchvision 0.20.1 ``` ```Python import torch import torchvision Traceback (most recent call last): File "", line 1, in File "/usr/local/anaconda3/envs/test/lib/python3.10/site-packages/torchvision/__init__.py", line 10, in from torchvision import _meta_registrations, datasets, io, models, ops, transforms, utils # usort:skip File "/usr/local/anaconda3/envs/test/lib/python3.10/site-packages/torchvision/_meta_registrations.py", line 164, in def meta_nms(dets, scores, iou_threshold): File "/usr/local/anaconda3/envs/test/lib/python3.10/site-packages/torch/library.py", line 795, in register use_lib._register_fake(op_name, func, _stacklevel=stacklevel + 1) File "/usr/local/anaconda3/envs/test/lib/python3.10/site-packages/torch/library.py", line 184, in _register_fake handle = entry.fake_impl.register(func_to_register, source) File "/usr/local/anaconda3/envs/test/lib/python3.10/site-packages/torch/_library/fake_impl.py", line 31, in register if torch._C._dispatch_has_kernel_for_dispatch_key(self.qualname, "Meta"): RuntimeError: operator torchvision::nms does not exist ``` **Cause**: ``` torchvision's .so file lacks some symbol definitions, because these symbols come from CUDA, but the current environment does not have CUDA and GPU. The above error message is very confusing. ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/157524 Approved by: https://github.com/ezyang --- torch/_ops.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torch/_ops.py b/torch/_ops.py index 4ccb9d7be6550..83a5dc0e57a5e 100644 --- a/torch/_ops.py +++ b/torch/_ops.py @@ -1483,7 +1483,10 @@ def load_library(self, path): # Import the shared library into the process, thus running its # static (global) initialization code in order to register custom # operators with the JIT. - ctypes.CDLL(path) + try: + ctypes.CDLL(path) + except Exception as e: + raise OSError(f"Could not load this library: {path}") from e self.loaded_libraries.add(path) From a78fb63dbdf98a1db219095293de1a11005e0390 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Mon, 21 Jul 2025 14:26:45 +0800 Subject: [PATCH 322/457] [build] pin `setuptools>=77` to enable PEP 639 (#158104) For reference here is the link PEP 639: [peps.python.org/pep-0639](https://peps.python.org/pep-0639/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158104 Approved by: https://github.com/rgommers, https://github.com/Skylion007, https://github.com/atalman --- .ci/aarch64_linux/aarch64_ci_setup.sh | 2 +- .ci/docker/manywheel/Dockerfile_2_28 | 2 +- .ci/docker/manywheel/Dockerfile_s390x | 5 ++--- .ci/docker/requirements-ci.txt | 11 ++++++----- .ci/pytorch/build.sh | 3 +++ .ci/pytorch/test.sh | 2 +- .ci/pytorch/win-test-helpers/build_pytorch.bat | 5 +++++ .ci/pytorch/win-test.sh | 2 +- .ci/pytorch/windows/internal/install_python.bat | 2 +- .ci/pytorch/windows/setup_build.bat | 5 ++++- .ci/wheel/build_wheel.sh | 14 +++++++------- .github/requirements-gha-cache.txt | 2 +- .github/requirements/pip-requirements-macOS.txt | 8 ++++---- .github/scripts/lintrunner.sh | 2 +- .github/scripts/windows/build_triton.bat | 2 +- .github/workflows/_mac-test.yml | 5 +++++ pyproject.toml | 11 +++-------- requirements-build.txt | 4 ++-- test/dynamo/test_exc.py | 16 ++++++++-------- 19 files changed, 57 insertions(+), 46 deletions(-) diff --git a/.ci/aarch64_linux/aarch64_ci_setup.sh b/.ci/aarch64_linux/aarch64_ci_setup.sh index 8ffba65d7fedd..b18d27f2793fc 100755 --- a/.ci/aarch64_linux/aarch64_ci_setup.sh +++ b/.ci/aarch64_linux/aarch64_ci_setup.sh @@ -12,7 +12,7 @@ fi SCRIPTPATH="$( cd "$(dirname "$0")" ; pwd -P )" source $SCRIPTPATH/../manywheel/set_desired_python.sh -pip install -q numpy==${NUMPY_VERSION} pyyaml==6.0.2 scons==4.7.0 ninja==1.11.1 patchelf==0.17.2 +pip install -q numpy==${NUMPY_VERSION} pyyaml==6.0.2 scons==4.7.0 ninja==1.11.1.4 patchelf==0.17.2 for tool in python python3 pip pip3 ninja scons patchelf; do ln -sf ${DESIRED_PYTHON_BIN_DIR}/${tool} /usr/local/bin; diff --git a/.ci/docker/manywheel/Dockerfile_2_28 b/.ci/docker/manywheel/Dockerfile_2_28 index b150423e99544..7f279a1c1a735 100644 --- a/.ci/docker/manywheel/Dockerfile_2_28 +++ b/.ci/docker/manywheel/Dockerfile_2_28 @@ -128,7 +128,7 @@ ENV PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH ENV LD_LIBRARY_PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/lib64:/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/lib:$LD_LIBRARY_PATH # Install setuptools and wheel for python 3.12/3.13 RUN for cpython_version in "cp312-cp312" "cp313-cp313" "cp313-cp313t"; do \ - /opt/python/${cpython_version}/bin/python -m pip install setuptools wheel; \ + /opt/python/${cpython_version}/bin/python -m pip install "setuptools>=77.0.0" "packaging>=24.2" wheel; \ done; diff --git a/.ci/docker/manywheel/Dockerfile_s390x b/.ci/docker/manywheel/Dockerfile_s390x index 46ec7f77ae8ba..335488b88f122 100644 --- a/.ci/docker/manywheel/Dockerfile_s390x +++ b/.ci/docker/manywheel/Dockerfile_s390x @@ -124,10 +124,9 @@ RUN python3 -mpip install cmake==3.28.0 # install newest flatbuffers version first: # for some reason old version is getting pulled in otherwise. # packaging package is required for onnxruntime wheel build. -RUN pip3 install flatbuffers && \ - pip3 install cython 'pkgconfig>=1.5.5' 'setuptools>=77' 'numpy<2.3.0' && \ +RUN pip3 install 'setuptools>=77.0' 'packaging>=24.2' && \ + pip3 install flatbuffers cython 'pkgconfig>=1.5.5' 'numpy<2.3.0' && \ pip3 install --no-build-isolation h5py==3.11.0 && \ - pip3 install packaging && \ git clone https://github.com/microsoft/onnxruntime && \ cd onnxruntime && git checkout v1.21.0 && \ git submodule update --init --recursive && \ diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt index fb773ff324af8..a7486c40b121d 100644 --- a/.ci/docker/requirements-ci.txt +++ b/.ci/docker/requirements-ci.txt @@ -50,7 +50,7 @@ flatbuffers==24.12.23 hypothesis==5.35.1 # Pin hypothesis to avoid flakiness: https://github.com/pytorch/pytorch/issues/31136 #Description: advanced library for generating parametrized tests -#Pinned versions: 3.44.6, 4.53.2 +#Pinned versions: 5.35.1 #test that import: test_xnnpack_integration.py, test_pruning_op.py, test_nn.py junitparser==2.1.1 @@ -104,10 +104,10 @@ networkx==2.8.8 #Pinned versions: 2.8.8 #test that import: functorch -ninja==1.11.1.3 +ninja==1.11.1.4 #Description: build system. Used in some tests. Used in build to generate build #time tracing information -#Pinned versions: 1.11.1.3 +#Pinned versions: 1.11.1.4 #test that import: run_test.py, test_cpp_extensions_aot.py,test_determination.py numba==0.49.0 ; python_version < "3.9" @@ -307,7 +307,7 @@ pytest-cpp==2.3.0 #Pinned versions: 2.3.0 #test that import: -z3-solver==4.12.6.0 +z3-solver==4.15.1.0 #Description: The Z3 Theorem Prover Project #Pinned versions: #test that import: @@ -363,9 +363,10 @@ pwlf==2.2.1 # To build PyTorch itself +packaging>=24.2 pyyaml pyzstd -setuptools>=70.1.0 +setuptools>=77.0.0 six scons==4.5.2 ; platform_machine == "aarch64" diff --git a/.ci/pytorch/build.sh b/.ci/pytorch/build.sh index 58454bcb108a7..f2b8998a6f6cd 100755 --- a/.ci/pytorch/build.sh +++ b/.ci/pytorch/build.sh @@ -269,6 +269,9 @@ if [[ "$BUILD_ENVIRONMENT" == *-bazel-* ]]; then tools/bazel build --config=no-tty "${BAZEL_MEM_LIMIT}" "${BAZEL_CPU_LIMIT}" //... fi else + # install build-system requirements before running setup.py commands + python -m pip install -r requirements-build.txt + # check that setup.py would fail with bad arguments echo "The next three invocations are expected to fail with invalid command error messages." ( ! get_exit_code python setup.py bad_argument ) diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index ad6a48b2528e4..2e7cc84138cee 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -201,7 +201,7 @@ fi if [[ "$BUILD_ENVIRONMENT" != *-bazel-* ]] ; then # JIT C++ extensions require ninja. - pip_install "ninja==1.10.2" + pip_install "ninja==1.11.1.4" # ninja is installed in $HOME/.local/bin, e.g., /var/lib/jenkins/.local/bin for CI user jenkins # but this script should be runnable by any user, including root export PATH="$HOME/.local/bin:$PATH" diff --git a/.ci/pytorch/win-test-helpers/build_pytorch.bat b/.ci/pytorch/win-test-helpers/build_pytorch.bat index 7ceb425ce2d1a..74c9183f2abb0 100644 --- a/.ci/pytorch/win-test-helpers/build_pytorch.bat +++ b/.ci/pytorch/win-test-helpers/build_pytorch.bat @@ -126,6 +126,11 @@ if "%USE_CUDA%"=="1" ( set CMAKE_CUDA_COMPILER_LAUNCHER=%TMP_DIR%/bin/randomtemp.exe;%TMP_DIR%\bin\sccache.exe ) +:: Install build-system requirements before running setup.py commands +python -m pip install -r requirements-build.txt +if errorlevel 1 goto fail +if not errorlevel 0 goto fail + :: Print all existing environment variable for debugging set diff --git a/.ci/pytorch/win-test.sh b/.ci/pytorch/win-test.sh index b61dd06ef562c..be7f3e4bb35cc 100755 --- a/.ci/pytorch/win-test.sh +++ b/.ci/pytorch/win-test.sh @@ -41,7 +41,7 @@ fi python -m pip install pytest-rerunfailures==10.3 pytest-cpp==2.3.0 tensorboard==2.13.0 protobuf==5.29.4 pytest-subtests==0.13.1 # Install Z3 optional dependency for Windows builds. -python -m pip install z3-solver==4.12.2.0 +python -m pip install z3-solver==4.15.1.0 # Install tlparse for test\dynamo\test_structured_trace.py UTs. python -m pip install tlparse==0.3.30 diff --git a/.ci/pytorch/windows/internal/install_python.bat b/.ci/pytorch/windows/internal/install_python.bat index 73622bd736edd..65405a875b6b8 100644 --- a/.ci/pytorch/windows/internal/install_python.bat +++ b/.ci/pytorch/windows/internal/install_python.bat @@ -18,5 +18,5 @@ start /wait "" python-amd64.exe /quiet InstallAllUsers=1 PrependPath=0 Include_t if errorlevel 1 exit /b 1 set "PATH=%CD%\Python\Scripts;%CD%\Python;%PATH%" -%PYTHON_EXEC% -m pip install --upgrade pip setuptools packaging wheel +%PYTHON_EXEC% -m pip install --upgrade pip "setuptools>=77.0.0" "packaging>=24.2" wheel if errorlevel 1 exit /b 1 diff --git a/.ci/pytorch/windows/setup_build.bat b/.ci/pytorch/windows/setup_build.bat index 9b492eef664d7..df925b4ba90bc 100644 --- a/.ci/pytorch/windows/setup_build.bat +++ b/.ci/pytorch/windows/setup_build.bat @@ -7,6 +7,9 @@ call "internal\install_python.bat" %PYTHON_EXEC% --version set "PATH=%CD%\Python\Lib\site-packages\cmake\data\bin;%CD%\Python\Scripts;%CD%\Python;%PATH%" + +%PYTHON_EXEC% -m pip install "setuptools>=77.0.0" "packaging>=24.2" + if "%DESIRED_PYTHON%" == "3.13t" %PYTHON_EXEC% -m pip install numpy==2.2.1 cmake if "%DESIRED_PYTHON%" == "3.13" %PYTHON_EXEC% -m pip install numpy==2.1.2 cmake if "%DESIRED_PYTHON%" == "3.12" %PYTHON_EXEC% -m pip install numpy==2.0.2 cmake @@ -16,7 +19,7 @@ if "%DESIRED_PYTHON%" == "3.9" %PYTHON_EXEC% -m pip install numpy==2.0.2 cmake %PYTHON_EXEC% -m pip install pyyaml %PYTHON_EXEC% -m pip install mkl-include mkl-static -%PYTHON_EXEC% -m pip install boto3 ninja typing_extensions setuptools==72.1.0 +%PYTHON_EXEC% -m pip install boto3 ninja typing-extensions where cmake.exe diff --git a/.ci/wheel/build_wheel.sh b/.ci/wheel/build_wheel.sh index 878d6595c84c0..dc44f8ccc2922 100755 --- a/.ci/wheel/build_wheel.sh +++ b/.ci/wheel/build_wheel.sh @@ -127,7 +127,7 @@ export INSTALL_TEST=0 # dont install test binaries into site-packages export MACOSX_DEPLOYMENT_TARGET=10.15 export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"} -SETUPTOOLS_PINNED_VERSION="==70.1.0" +SETUPTOOLS_PINNED_VERSION="==77.0.0" PYYAML_PINNED_VERSION="=5.3" EXTRA_CONDA_INSTALL_FLAGS="" CONDA_ENV_CREATE_FLAGS="" @@ -135,7 +135,7 @@ RENAME_WHEEL=true case $desired_python in 3.13t) echo "Using 3.13 deps" - SETUPTOOLS_PINNED_VERSION=">=70.1.0" + SETUPTOOLS_PINNED_VERSION=">=77.0.0" PYYAML_PINNED_VERSION=">=6.0.1" NUMPY_PINNED_VERSION="=2.1.0" CONDA_ENV_CREATE_FLAGS="python-freethreading" @@ -145,31 +145,31 @@ case $desired_python in ;; 3.13) echo "Using 3.13 deps" - SETUPTOOLS_PINNED_VERSION=">=70.1.0" + SETUPTOOLS_PINNED_VERSION=">=77.0.0" PYYAML_PINNED_VERSION=">=6.0.1" NUMPY_PINNED_VERSION="=2.1.0" ;; 3.12) echo "Using 3.12 deps" - SETUPTOOLS_PINNED_VERSION=">=70.1.0" + SETUPTOOLS_PINNED_VERSION=">=77.0.0" PYYAML_PINNED_VERSION=">=6.0.1" NUMPY_PINNED_VERSION="=2.0.2" ;; 3.11) echo "Using 3.11 deps" - SETUPTOOLS_PINNED_VERSION=">=70.1.0" + SETUPTOOLS_PINNED_VERSION=">=77.0.0" PYYAML_PINNED_VERSION=">=5.3" NUMPY_PINNED_VERSION="=2.0.2" ;; 3.10) echo "Using 3.10 deps" - SETUPTOOLS_PINNED_VERSION=">=70.1.0" + SETUPTOOLS_PINNED_VERSION=">=77.0.0" PYYAML_PINNED_VERSION=">=5.3" NUMPY_PINNED_VERSION="=2.0.2" ;; 3.9) echo "Using 3.9 deps" - SETUPTOOLS_PINNED_VERSION=">=70.1.0" + SETUPTOOLS_PINNED_VERSION=">=77.0.0" PYYAML_PINNED_VERSION=">=5.3" NUMPY_PINNED_VERSION="=2.0.2" ;; diff --git a/.github/requirements-gha-cache.txt b/.github/requirements-gha-cache.txt index 5c691e4bf9b31..381bccbee847d 100644 --- a/.github/requirements-gha-cache.txt +++ b/.github/requirements-gha-cache.txt @@ -8,7 +8,7 @@ boto3==1.35.42 jinja2==3.1.6 lintrunner==0.10.7 -ninja==1.10.0.post1 +ninja==1.11.1.4 nvidia-ml-py==11.525.84 pyyaml==6.0 requests==2.32.4 diff --git a/.github/requirements/pip-requirements-macOS.txt b/.github/requirements/pip-requirements-macOS.txt index 9c72c71523b7d..ea005956aefa5 100644 --- a/.github/requirements/pip-requirements-macOS.txt +++ b/.github/requirements/pip-requirements-macOS.txt @@ -7,12 +7,12 @@ hypothesis==6.56.4 librosa>=0.6.2 mpmath==1.3.0 networkx==2.8.7 -ninja==1.10.2.4 +ninja==1.11.1.4 numba==0.59.0 numpy==1.26.4 opt-einsum>=3.3 optree==0.13.0 -packaging==23.1 +packaging==25.0 parameterized==0.8.1 pillow==10.3.0 protobuf==5.29.4 @@ -26,11 +26,11 @@ pytest-xdist==3.3.1 pytest==7.3.2 pyyaml==6.0.2 scipy==1.12.0 -setuptools==72.1.0 +setuptools==80.9.0 sympy==1.13.3 tlparse==0.3.30 tensorboard==2.13.0 typing-extensions==4.12.2 unittest-xml-reporting<=3.2.0,>=2.0.0 xdoctest==1.1.0 -z3-solver==4.12.2.0 +z3-solver==4.15.1.0 diff --git a/.github/scripts/lintrunner.sh b/.github/scripts/lintrunner.sh index ef4741444f942..1411ff0397b53 100755 --- a/.github/scripts/lintrunner.sh +++ b/.github/scripts/lintrunner.sh @@ -2,7 +2,7 @@ set -ex # Use uv to speed up lintrunner init -python3 -m pip install uv==0.1.45 setuptools +python3 -m pip install -U uv setuptools CACHE_DIRECTORY="/tmp/.lintbin" # Try to recover the cached binaries diff --git a/.github/scripts/windows/build_triton.bat b/.github/scripts/windows/build_triton.bat index 97cd535a49889..da2e86b40432a 100644 --- a/.github/scripts/windows/build_triton.bat +++ b/.github/scripts/windows/build_triton.bat @@ -10,7 +10,7 @@ if "%PY_VERS%" == "3.13t" ( call conda create -n %PYTHON_PREFIX% -y -c=conda-forge python=%PY_VERS% ) :: Fix cmake version for issue https://github.com/pytorch/pytorch/issues/150480 -call conda run -n %PYTHON_PREFIX% pip install wheel pybind11 certifi cython cmake==3.31.6 setuptools==72.1.0 ninja +call conda run -n %PYTHON_PREFIX% pip install wheel pybind11 certifi cython cmake==3.31.6 setuptools==78.1.1 ninja dir "%VC_INSTALL_PATH%" diff --git a/.github/workflows/_mac-test.yml b/.github/workflows/_mac-test.yml index 063c97e449c75..8822aaf7df418 100644 --- a/.github/workflows/_mac-test.yml +++ b/.github/workflows/_mac-test.yml @@ -80,6 +80,11 @@ jobs: run: | sysctl machdep.cpu.brand_string kern.osproductversion + - name: Install build toolchain + run: | + brew update --quiet + brew install --formula cmake ninja + - name: Clean up leftover processes on MacOS pet runner continue-on-error: true run: | diff --git a/pyproject.toml b/pyproject.toml index b41ae87621f0f..133da9289f5c9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,13 +2,12 @@ [build-system] requires = [ - # 70.1.0: min version for integrated bdist_wheel command from wheel package # 77.0.0: min version for SPDX expression support for project.license - "setuptools>=70.1.0,<80.0", + "setuptools>=77.0.0,<80.0", "cmake>=3.27", "ninja", "numpy", - "packaging", + "packaging>=24.2", "pyyaml", "requests", "six", # dependency chain: NNPACK -> PeachPy -> six @@ -21,11 +20,7 @@ name = "torch" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" readme = "README.md" requires-python = ">=3.9,<3.14" -# TODO: change to `license = "BSD-3-Clause"` and enable PEP 639 after pinning setuptools>=77 -# FIXME: As of 2025.06.20, it is hard to ensure the minimum version of setuptools in our CI environment. -# TOML-table-based license deprecated in setuptools>=77, and the deprecation warning will be changed -# to an error on 2026.02.18. See also: https://github.com/pypa/setuptools/issues/4903 -license = { text = "BSD-3-Clause" } +license = "BSD-3-Clause" authors = [{ name = "PyTorch Team", email = "packages@pytorch.org" }] keywords = ["pytorch", "machine learning"] classifiers = [ diff --git a/requirements-build.txt b/requirements-build.txt index be19d987f73db..12332b0e1af01 100644 --- a/requirements-build.txt +++ b/requirements-build.txt @@ -1,9 +1,9 @@ # Build System requirements -setuptools>=70.1.0,<80.0 # setuptools develop deprecated on 80.0 +setuptools>=77.0.0,<80.0 # setuptools develop deprecated on 80.0 cmake>=3.27 ninja numpy -packaging +packaging>=24.2 pyyaml requests six # dependency chain: NNPACK -> PeachPy -> six diff --git a/test/dynamo/test_exc.py b/test/dynamo/test_exc.py index acc3fd55f6fb0..c340a2882d471 100644 --- a/test/dynamo/test_exc.py +++ b/test/dynamo/test_exc.py @@ -251,13 +251,13 @@ def fn(x, shape): Model: ==> L['shape'][0]: 0 - ==> L['shape'][1]: 1 - ==> L['shape'][2]: 1 + ==> L['shape'][1]: 0 + ==> L['shape'][2]: 0 ==> L['x'].size()[0]: 3 ==> L['x'].storage_offset(): 0 ==> L['x'].stride()[0]: 1 - ==> s3: 1 - ==> s52: 1 + ==> s3: 0 + ==> s52: 0 ==> s77: 3 ==> s86: 0 @@ -315,16 +315,16 @@ def fn(x, shape): %split : [num_users=3] = call_method[target=split](args = (%l_x_, (%l_shape_0_, %l_shape_1_, %l_shape_2_)), kwargs = {}) Model: - ==> L['shape'][0]: 1 - ==> L['shape'][1]: 1 + ==> L['shape'][0]: 0 + ==> L['shape'][1]: 0 ==> L['shape'][2]: 0 ==> L['x'].size()[0]: 3 ==> L['x'].storage_offset(): 0 ==> L['x'].stride()[0]: 1 ==> s3: 0 - ==> s52: 1 + ==> s52: 0 ==> s77: 3 - ==> s86: 1 + ==> s86: 0 Assertions: ==> (== 0 L['x'].storage_offset()) From 637e75433cb4ffc61a057fbdd5597db501cb05a7 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Tue, 15 Jul 2025 13:42:34 +0800 Subject: [PATCH 323/457] [BE] always use `uv pip` if possible in `pip_init.py` for `lintrunner init` (#157199) Pull Request resolved: https://github.com/pytorch/pytorch/pull/157199 Approved by: https://github.com/ezyang, https://github.com/ZainRizvi --- tools/linter/adapters/pip_init.py | 53 ++++++++++++++++++++----------- tools/nightly.py | 4 +-- 2 files changed, 37 insertions(+), 20 deletions(-) diff --git a/tools/linter/adapters/pip_init.py b/tools/linter/adapters/pip_init.py index fbf9808e9b267..137e4637bdb44 100644 --- a/tools/linter/adapters/pip_init.py +++ b/tools/linter/adapters/pip_init.py @@ -13,17 +13,20 @@ import time -def run_command(args: list[str]) -> subprocess.CompletedProcess[bytes]: +def run_command( + args: list[str], + env: dict[str, str] | None = None, +) -> subprocess.CompletedProcess[str]: logging.debug("$ %s", " ".join(args)) start_time = time.monotonic() try: - return subprocess.run(args, check=True) + return subprocess.run(args, env=env, text=True, encoding="utf-8", check=True) finally: end_time = time.monotonic() logging.debug("took %dms", (end_time - start_time) * 1000) -if __name__ == "__main__": +def main() -> None: parser = argparse.ArgumentParser(description="pip initializer") parser.add_argument( "packages", @@ -52,17 +55,16 @@ def run_command(args: list[str]) -> subprocess.CompletedProcess[bytes]: stream=sys.stderr, ) - uv_available = ( - any(prefix in sys.base_prefix for prefix in ["uv/python", "uv\\python"]) - and shutil.which("uv") is not None - ) - - if uv_available: - pip_args = ["uv", "pip", "install"] - elif sys.executable: - pip_args = [sys.executable, "-mpip", "install"] - else: - pip_args = ["pip3", "install"] + env: dict[str, str] = { + **os.environ, + "UV_PYTHON": sys.executable, + "UV_PYTHON_DOWNLOADS": "never", + "FORCE_COLOR": "1", + "CLICOLOR_FORCE": "1", + } + uv_index = env.get("UV_INDEX", env.get("PIP_EXTRA_INDEX_URL")) + if uv_index: + env["UV_INDEX"] = uv_index # If we are in a global install, use `--user` to install so that you do not # need root access in order to initialize linters. @@ -70,9 +72,20 @@ def run_command(args: list[str]) -> subprocess.CompletedProcess[bytes]: # However, `pip install --user` interacts poorly with virtualenvs (see: # https://bit.ly/3vD4kvl) and conda (see: https://bit.ly/3KG7ZfU). So in # these cases perform a regular installation. - in_conda = os.environ.get("CONDA_PREFIX") is not None - in_virtualenv = os.environ.get("VIRTUAL_ENV") is not None - if not in_conda and not in_virtualenv: + in_conda = env.get("CONDA_PREFIX") is not None + in_virtualenv = env.get("VIRTUAL_ENV") is not None + need_user_flag = not in_conda and not in_virtualenv + + uv: str | None = shutil.which("uv") + is_uv_managed_python = "uv/python" in sys.base_prefix.replace("\\", "/") + if uv and (is_uv_managed_python or not need_user_flag): + pip_args = [uv, "pip", "install"] + elif sys.executable: + pip_args = [sys.executable, "-mpip", "install"] + else: + pip_args = ["pip3", "install"] + + if need_user_flag: pip_args.append("--user") pip_args.extend(args.packages) @@ -92,4 +105,8 @@ def run_command(args: list[str]) -> subprocess.CompletedProcess[bytes]: print(f"Would have run: {pip_args}") sys.exit(0) - run_command(pip_args) + run_command(pip_args, env=env) + + +if __name__ == "__main__": + main() diff --git a/tools/nightly.py b/tools/nightly.py index 8409173e8b5b6..0ed8cfe165aa9 100755 --- a/tools/nightly.py +++ b/tools/nightly.py @@ -250,6 +250,7 @@ def __init__( self._env = { "PIP_EXTRA_INDEX_URL": self.pip_source.index_url, "UV_INDEX": self.pip_source.index_url, + "UV_PYTHON_DOWNLOADS": "never", "FORCE_COLOR": "1", "CLICOLOR_FORCE": "1", } @@ -475,13 +476,12 @@ def uv( cmd = [str(self.bindir / "uv"), *args] env = popen_kwargs.pop("env", None) or {} check = popen_kwargs.pop("check", True) - env["UV_PYTHON"] = str(python) return subprocess.run( cmd, check=check, text=True, encoding="utf-8", - env={**self._env, **env}, + env={**self._env, **env, "UV_PYTHON": str(python)}, **popen_kwargs, ) From 9285b8245c0ce17cd4e3b7e09e5f908693fb93ca Mon Sep 17 00:00:00 2001 From: Sam Larsen Date: Fri, 18 Jul 2025 15:13:09 -0700 Subject: [PATCH 324/457] [BE][testing] fix test_cat_max_autotune_triton (#158589) Summary: This test often fails internally -- looks like it's because autotuning sometimes chooses not to do the epilog tuning. Turning off `benchmark_epilogue_fusion` seems to fix. Test Plan: `buck test '@fbcode//mode/opt' fbcode//caffe2/test/inductor:max_autotune -- --exact 'caffe2/test/inductor:max_autotune - test_cat_max_autotune_triton (caffe2.test.inductor.test_max_autotune.TestMaxAutotune)' --run-disabled` Pull Request resolved: https://github.com/pytorch/pytorch/pull/158589 Approved by: https://github.com/eellison --- test/inductor/test_max_autotune.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py index 9f40c7d3d23e6..6245b89f4eca3 100644 --- a/test/inductor/test_max_autotune.py +++ b/test/inductor/test_max_autotune.py @@ -752,7 +752,12 @@ def test_cat_max_autotune_extern(self): @skipIfXpu( msg="The fusion not happened because it do not speedup on XPU, see issue #146568" ) - @config.patch(max_autotune_gemm_backends="TRITON") + @config.patch( + { + "max_autotune_gemm_backends": "TRITON", + "benchmark_epilogue_fusion": False, + } + ) def test_cat_max_autotune_triton(self): self._test_cat_max_autotune_impl(using_triton_mm=True) From 393fecb2cc43c03b54ade0c11078dd4e353d8b2f Mon Sep 17 00:00:00 2001 From: Menglu Yu Date: Mon, 21 Jul 2025 18:05:05 +0000 Subject: [PATCH 325/457] [Optimus][Unit test] clean up the unit test (#158696) Summary: We should only patch the specific pattern(s) for each unit test. Test Plan: ``` buck2 test 'fbcode//mode/dev-nosan' fbcode//caffe2/test/inductor:group_batch_fusion ``` Buck UI: https://www.internalfb.com/buck2/f8d37674-91c4-4244-90fa-f24fc3f91e4b Test UI: https://www.internalfb.com/intern/testinfra/testrun/2533275088644915 Network: Up: 100KiB Down: 233KiB (reSessionID-92039f44-bc6f-4e78-87b1-93bca1bd1c66) Analyzing targets. Remaining 0/296 Executing actions. Remaining 0/20196 5.8s exec time total Command: test. Finished 2 local, 2 cache (50% hit) 4.6s exec time cached (79%) Time elapsed: 3:55.1s Tests finished: Pass 13. Fail 0. Fatal 0. Skip 0. Build failure 0 Rollback Plan: Differential Revision: D78598127 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158696 Approved by: https://github.com/Skylion007, https://github.com/masnesral --- test/inductor/test_group_batch_fusion.py | 68 +++++++++++++++--------- 1 file changed, 42 insertions(+), 26 deletions(-) diff --git a/test/inductor/test_group_batch_fusion.py b/test/inductor/test_group_batch_fusion.py index 516120da29865..090a7e8e29d3f 100644 --- a/test/inductor/test_group_batch_fusion.py +++ b/test/inductor/test_group_batch_fusion.py @@ -286,24 +286,6 @@ def forward(self, x): return torch.stack((stack_input, stack_other), dim=0) -@requires_gpu() -@torch._inductor.config.patch( - pre_grad_fusion_options={ - "batch_linear": {}, - "batch_linear_lhs": {}, - "batch_layernorm": {}, - "batch_tanh": {}, - "batch_relu": {}, - "batch_sigmoid": {}, - }, - post_grad_fusion_options={ - "batch_aten_add": {}, - "batch_aten_mul": {}, - "batch_aten_sub": {}, - "batch_aten_div": {}, - "group_linear": {"require_fbgemm": True}, - }, -) class TestGroupBatchFusion(TestCase): def compare_dict_tensors(self, ref_dict, res_dict, rtol=1e-3, atol=1e-3): if len(set(ref_dict.keys())) != len(set(res_dict.keys())): @@ -332,7 +314,14 @@ def compare_gradients(self, module, traced, rtol=1e-3, atol=1e-3): self.compare_dict_tensors(ref_grad, res_grad, rtol=rtol, atol=atol) ) + @requires_gpu() @unittest.skipIf(not has_fbgemm, "requires fbgemm") + @torch._inductor.config.patch( + pre_grad_fusion_options={}, + post_grad_fusion_options={ + "group_linear": {"require_fbgemm": True}, + }, + ) def test_group_linear_fusion(self): z = 10 for has_bias in [True, False]: @@ -355,13 +344,16 @@ def test_group_linear_fusion(self): counters["inductor"]["group_linear"], 4, ) - self.assertEqual( - counters["inductor"]["batch_aten_add"], - 0, - ) counters.clear() + @requires_gpu() @unittest.skipIf(not has_fbgemm, "requires fbgemm") + @torch._inductor.config.patch( + pre_grad_fusion_options={}, + post_grad_fusion_options={ + "group_linear": {"require_fbgemm": True}, + }, + ) def test_group_linear_fusion_different_shapes(self): counters.clear() module = MyModule2().eval().to(GPU_TYPE) @@ -386,13 +378,14 @@ def test_group_linear_fusion_different_shapes(self): counters["inductor"]["group_linear"], 2, ) - self.assertEqual( - counters["inductor"]["batch_aten_mul"], - 1, - ) counters.clear() + @requires_gpu() @unittest.skipIf(GPU_TYPE == "mps", "welford_reduce is yet not implemented for MPS") + @torch._inductor.config.patch( + pre_grad_fusion_options={"batch_layernorm": {}}, + post_grad_fusion_options={}, + ) def test_batch_layer_norm_fusion(self): for has_weight in [True, False]: for has_bias in [True, False]: @@ -410,6 +403,11 @@ def test_batch_layer_norm_fusion(self): self.compare_gradients(module, traced, rtol=1e-8, atol=1e-8) counters.clear() + @requires_gpu() + @torch._inductor.config.patch( + pre_grad_fusion_options={"batch_linear_lhs": {}}, + post_grad_fusion_options={}, + ) def test_batch_linear_lhs_fusion(self): z = 10 for has_bias in [True, False]: @@ -427,6 +425,11 @@ def test_batch_linear_lhs_fusion(self): self.compare_gradients(module, traced, rtol=1e-8, atol=1e-8) counters.clear() + @requires_gpu() + @torch._inductor.config.patch( + pre_grad_fusion_options={"batch_linear": {}}, + post_grad_fusion_options={}, + ) def test_batch_linear_pre_grad_fusion(self): for has_bias in [True, False]: counters.clear() @@ -443,6 +446,19 @@ def test_batch_linear_pre_grad_fusion(self): self.compare_gradients(module, traced, rtol=1e-8, atol=1e-8) counters.clear() + @requires_gpu() + @torch._inductor.config.patch( + pre_grad_fusion_options={ + "batch_relu": {}, + "batch_sigmoid": {}, + }, + post_grad_fusion_options={ + "batch_aten_add": {}, + "batch_aten_mul": {}, + "batch_aten_sub": {}, + "batch_aten_div": {}, + }, + ) def test_pointwise_op_fusion(self): counters.clear() module = TestPoitwiseOps(GPU_TYPE) From 8ed5e1844c77d952bcea89ca7d0225d876fec4e8 Mon Sep 17 00:00:00 2001 From: Benjamin Glass Date: Mon, 21 Jul 2025 15:18:55 +0000 Subject: [PATCH 326/457] [AOTI] Convert C-struct zip handling to RAII container (#158687) Attempts to fix a memory leak reported in #158614 by wrapping manually managed MiniZ C-structs in an RAII container. I have been unable to reproduce the reported leak, but this seems like the most likely candidate. Fixes #158614 (hopefully) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158687 Approved by: https://github.com/desertfire --- test/cpp/aoti_inference/test.cpp | 2 + .../aoti_package/model_package_loader.cpp | 112 +++++++++++------- 2 files changed, 69 insertions(+), 45 deletions(-) diff --git a/test/cpp/aoti_inference/test.cpp b/test/cpp/aoti_inference/test.cpp index 59d575b2cc2bb..bff3827f8e8ac 100644 --- a/test/cpp/aoti_inference/test.cpp +++ b/test/cpp/aoti_inference/test.cpp @@ -144,6 +144,8 @@ void test_aoti_package_loader_multi_gpu( const std::string& device, bool use_runtime_constant_folding) { torch::NoGradGuard no_grad; + // Ensure that this test will reset the default CUDA device on exit. + torch::DeviceGuard device_guard(c10::Device("cuda")); std::string data_path = (std::filesystem::path(STRINGIZE(CMAKE_CURRENT_BINARY_DIR)) / "data.pt") diff --git a/torch/csrc/inductor/aoti_package/model_package_loader.cpp b/torch/csrc/inductor/aoti_package/model_package_loader.cpp index 66568025718af..bc7ee87e10233 100644 --- a/torch/csrc/inductor/aoti_package/model_package_loader.cpp +++ b/torch/csrc/inductor/aoti_package/model_package_loader.cpp @@ -405,6 +405,69 @@ void AOTIModelPackageLoader::load_metadata(const std::string& cpp_filename) { } } +class RAIIMinizArchive { + public: + RAIIMinizArchive(const std::string& zip_path) { + mz_zip_zero_struct(&_zip_archive); + if (!mz_zip_reader_init_file(&_zip_archive, zip_path.c_str(), 0)) { + throw std::runtime_error(fmt::format( + "Failed to initialize zip archive: {}", + mz_zip_get_error_string(mz_zip_get_last_error(&_zip_archive)))); + } + } + RAIIMinizArchive(const RAIIMinizArchive&) = delete; + RAIIMinizArchive& operator=(const RAIIMinizArchive&) = delete; + RAIIMinizArchive(RAIIMinizArchive&&) noexcept = delete; + RAIIMinizArchive& operator=(RAIIMinizArchive&&) noexcept = delete; + ~RAIIMinizArchive() { + // Unconditionally close the file. We can't handle any errors here without + // terminating the program. + mz_zip_reader_end(&_zip_archive); + } + + std::vector get_filenames() { + const unsigned num_zip_files{mz_zip_reader_get_num_files(&_zip_archive)}; + std::vector zip_filenames{}; + zip_filenames.reserve(num_zip_files); + + for (unsigned i{0}; i < num_zip_files; ++i) { + // filename_buf_size == 0 returns the filename length, including null + // terminator + const auto zip_filename_len{ + mz_zip_reader_get_filename(&_zip_archive, i, nullptr, 0)}; + if (!zip_filename_len) { + throw std::runtime_error( + fmt::format("Failed to read zip filename length at index {}", i)); + } + // std::string implicitly appends a character for the null terminator + std::string zip_filename(zip_filename_len - 1, '\0'); + if (!mz_zip_reader_get_filename( + &_zip_archive, i, zip_filename.data(), zip_filename_len)) { + throw std::runtime_error( + fmt::format("Failed to read zip filename at index {}", i)); + } + zip_filenames.emplace_back(zip_filename); + } + + return zip_filenames; + } + + void extract_file( + const std::string& zip_filename, + const std::string& dest_filename) { + if (!mz_zip_reader_extract_file_to_file( + &_zip_archive, zip_filename.c_str(), dest_filename.c_str(), 0)) { + throw std::runtime_error(fmt::format( + "Failed to extract zip file {} to destination file {}", + zip_filename, + dest_filename)); + } + } + + private: + mz_zip_archive _zip_archive{}; +}; + AOTIModelPackageLoader::AOTIModelPackageLoader( const std::string& model_package_path, const std::string& model_name, @@ -424,34 +487,8 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( } // Extract all files within the zipfile to a temporary directory - mz_zip_archive zip_archive; - memset(&zip_archive, 0, sizeof(zip_archive)); - - if (!mz_zip_reader_init_file(&zip_archive, model_package_path.c_str(), 0)) { - throw std::runtime_error( - std::string("Failed to initialize zip archive: ") + - mz_zip_get_error_string(mz_zip_get_last_error(&zip_archive))); - } - - std::vector found_filenames; - for (uint32_t i = 0; i < zip_archive.m_total_files; i++) { - uint32_t zip_filename_len = - mz_zip_reader_get_filename(&zip_archive, i, nullptr, 0); - if (zip_filename_len == 0) { - throw std::runtime_error("Failed to read filename"); - } - // zip_filename_len returned by mz_zip_reader_get_filename includes the null - // terminator, so we need to subtract 1 here. - std::string zip_filename_str(zip_filename_len - 1, '\0'); - // zip_filename_str can't be normalize_path_separator, because it should be - // as index for mz_zip_reader_extract_file_to_file. - if (!mz_zip_reader_get_filename( - &zip_archive, i, zip_filename_str.data(), zip_filename_len)) { - throw std::runtime_error("Failed to read filename"); - } - found_filenames.push_back(zip_filename_str); - } - + RAIIMinizArchive zip_archive{model_package_path}; + auto found_filenames{zip_archive.get_filenames()}; if (found_filenames.empty()) { throw std::runtime_error("No files found in zip archive."); } @@ -486,7 +523,7 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( // zip_filename_str can't be normalize_path_separator, because it should be // as index for mz_zip_reader_extract_file_to_file. - for (auto zip_filename_str : found_filenames) { + for (const auto& zip_filename_str : found_filenames) { auto cur_filename = normalize_path_separator(zip_filename_str); // Only compile files in the specified model directory if (c10::starts_with(cur_filename, model_directory) || @@ -529,14 +566,7 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( } // Extracts file to the temp directory - mz_bool b_extract = mz_zip_reader_extract_file_to_file( - &zip_archive, zip_filename_str.c_str(), output_file_path.c_str(), 0); - if (b_extract == MZ_FALSE) { - throw std::runtime_error(fmt::format( - "Failed to extract file {} to {}", - zip_filename_str, - output_file_path)); - } + zip_archive.extract_file(zip_filename_str, output_file_path); // Save the file for bookkeeping size_t extension_idx = output_file_path.find_last_of('.'); @@ -553,14 +583,6 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( } } - // Close the zip archive as we have extracted all files to the temp - // directory - if (!mz_zip_reader_end(&zip_archive)) { - throw std::runtime_error( - std::string("Failed to close zip archive: {}") + - mz_zip_get_error_string(mz_zip_get_last_error(&zip_archive))); - } - if (cpp_filename.empty() && so_filename.empty()) { std::string found_filenames_str; for (const std::string& filename : found_filenames) { From 72db0a98a34a9f8982f7cf83145bf57b85e36817 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 21 Jul 2025 18:54:04 +0000 Subject: [PATCH 327/457] Revert "[DTensor] Assert DTensorSpec has valid placements (#158133)" This reverts commit 1839e8d04b81ee6eda0cff6fbfc218a7a600f6f7. Reverted https://github.com/pytorch/pytorch/pull/158133 on behalf of https://github.com/ZainRizvi due to Sorry but this is breaking internally. See D78496151 for details. To validate your fixes internally, you can follow the instructions here: https://fburl.com/fixing-ghfirst-reverts ([comment](https://github.com/pytorch/pytorch/pull/158133#issuecomment-3097994857)) --- test/distributed/tensor/test_dtensor_compile.py | 2 +- torch/distributed/tensor/_dtensor_spec.py | 4 ---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/test/distributed/tensor/test_dtensor_compile.py b/test/distributed/tensor/test_dtensor_compile.py index 5041a0d6de54d..86f1e9d8fb479 100644 --- a/test/distributed/tensor/test_dtensor_compile.py +++ b/test/distributed/tensor/test_dtensor_compile.py @@ -343,7 +343,7 @@ def test_dtensor_constructor_w_graph_break(self): x = torch.randn(64, 32, requires_grad=True) spec = DTensorSpec( mesh, - (Replicate(),), + (Replicate(), Shard(0)), tensor_meta=TensorMeta( shape=torch.Size([128, 32]), stride=(32, 1), dtype=x.dtype ), diff --git a/torch/distributed/tensor/_dtensor_spec.py b/torch/distributed/tensor/_dtensor_spec.py index c450720357ba8..eb528ee4f9af1 100644 --- a/torch/distributed/tensor/_dtensor_spec.py +++ b/torch/distributed/tensor/_dtensor_spec.py @@ -32,10 +32,6 @@ class DTensorSpec: def __post_init__(self) -> None: if not isinstance(self.placements, tuple): self.placements = tuple(self.placements) - if not len(self.placements) == self.mesh.ndim: - raise ValueError( - f"DTensorSpec requires one placement per mesh dim (mesh.ndim={self.mesh.ndim}), got {self.placements=}" - ) self._hash: Optional[int] = None def __setattr__(self, attr: str, value: Any) -> None: From 662dd7db5b60dde71ef249041b9970ca36ef79e7 Mon Sep 17 00:00:00 2001 From: henrylhtsang Date: Sun, 20 Jul 2025 17:38:37 -0700 Subject: [PATCH 328/457] [cutlass backend] cache maybe_append_choices (#156781) This PR attempts to cache: * codegen for cutlass backend for the same kernel. Even if runtime params are different. From some profiling, most of the time spent is on render. So we only target to cache that part for now. The output of render is `code`, and we are able to cache that easily. Also, I have to cache size_args, since it depends on `kernel.get_dynamic_shape_args()`, which depends on the state of self when we call render. make_key is doing most of the work here: We are hashing on input node layouts, output node layout and op.configuration_name() (this is what hash(op) would do anyway). Pull Request resolved: https://github.com/pytorch/pytorch/pull/156781 Approved by: https://github.com/ColinPeppler --- test/inductor/test_cutlass_backend.py | 178 ++++++++++++++++++ torch/_inductor/codegen/cuda/cuda_template.py | 102 +++++++--- torch/_inductor/config.py | 3 + 3 files changed, 259 insertions(+), 24 deletions(-) diff --git a/test/inductor/test_cutlass_backend.py b/test/inductor/test_cutlass_backend.py index bb27b3d1a68c4..3b230865bcd9d 100644 --- a/test/inductor/test_cutlass_backend.py +++ b/test/inductor/test_cutlass_backend.py @@ -1555,6 +1555,178 @@ def mm(a, b): num_ops = int(match.group(1)) self.assertTrue(num_ops > 0, "The number of ops should be greater than 0") + @unittest.skipIf(not SM90OrLater, "need sm_90") + def test_maybe_append_choice_caching(self): + """ + Test if maybe_append_choice's caching leads to correct results and + shorter maybe_append_choice time. + """ + + NUM_ITERATIONS = 10 + + class TestModule(torch.nn.Module): + def forward(self, A, B): + for _ in range(NUM_ITERATIONS): + A = A @ B / 32 + return A + + model = TestModule().cuda() + A = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda") + B = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda").t() + + expected = model(A, B) + + # Track render calls + from torch._inductor.codegen.cuda.gemm_template import CUTLASSGemmTemplate + + original_render = CUTLASSGemmTemplate.render + render_call_count = 0 + + def counting_render(self, *args, **kwargs): + nonlocal render_call_count + render_call_count += 1 + return original_render(self, *args, **kwargs) + + with mock.patch.object(CUTLASSGemmTemplate, "render", counting_render): + with config.patch( + { + "max_autotune": True, + "max_autotune_gemm_backends": "CUTLASS", + "fx_graph_cache": False, + "fx_graph_remote_cache": False, + "cuda.enable_caching_codegen": True, + } + ): + compiled_model = torch.compile(model, fullgraph=True) + actual = compiled_model(A, B) + + torch.testing.assert_close(actual, expected) + + # Check render call count: render is called uniquely for each codegen + # and for each finalized codegen. + self.assertEqual( + render_call_count, NUM_ITERATIONS + DEFAULT_INST_LEVEL_MM_CONFIG + ) + + @unittest.skipIf(not SM90OrLater, "need sm_90") + @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) + def test_multiple_mm(self): + """ + Test multiple matrix multiplications with different shapes in a single nn.Module. + """ + + class MultipleMMModel(torch.nn.Module): + def forward(self, a, b, c, d): + # First mm with shape (128, 64) @ (64, 32) -> (128, 32) + mm1 = a @ b + # Second mm with shape (256, 128) @ (128, 64) -> (256, 64) + mm2 = c @ d + return mm1, mm2 + + model = MultipleMMModel().cuda() + + # Create tensors with different shapes + a = torch.randn(128, 64).cuda().half() + b = torch.randn(32, 64).cuda().half().t() + c = torch.randn(256, 128).cuda().half() + d = torch.randn(64, 128).cuda().half().t() + + # Track render calls + from torch._inductor.codegen.cuda.gemm_template import CUTLASSGemmTemplate + + original_render = CUTLASSGemmTemplate.render + render_call_count = 0 + + def counting_render(self, *args, **kwargs): + nonlocal render_call_count + render_call_count += 1 + return original_render(self, *args, **kwargs) + + with mock.patch.object(CUTLASSGemmTemplate, "render", counting_render): + with config.patch( + { + "max_autotune": True, + "max_autotune_gemm_backends": "CUTLASS", + "cuda.cutlass_max_profiling_configs": 2, + "fx_graph_cache": False, + "fx_graph_remote_cache": False, + "cuda.enable_caching_codegen": True, + } + ): + # Get expected results + expected = model(a, b, c, d) + + # Compile and run + compiled_model = torch.compile(model) + actual = compiled_model(a, b, c, d) + + # Verify results + torch.testing.assert_close(actual, expected) + + num_matmuls = 2 + self.assertEqual(render_call_count, num_matmuls + num_matmuls * 2) + + @unittest.skipIf(not SM90OrLater, "need sm_90") + @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) + def test_multiple_mm_with_dynamic_shape(self): + """ + Test multiple matrix multiplications where one has dynamic shapes. + """ + + class MultipleMMDynamicModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.c = torch.randn(64, 256).cuda().half() + self.d = torch.randn(128, 256).cuda().half().t() + + def forward(self, a, b): + # dynamic shape matmul + mm1 = a @ b + # static shape matmul + mm2 = self.c @ self.d + return mm1, mm2 + + model = MultipleMMDynamicModel().cuda() + + # Create tensors with different shapes + a = torch.randn(128, 64).cuda().half() + b = torch.randn(32, 64).cuda().half().t() + + # Track render calls + from torch._inductor.codegen.cuda.gemm_template import CUTLASSGemmTemplate + + original_render = CUTLASSGemmTemplate.render + render_call_count = 0 + + def counting_render(self, *args, **kwargs): + nonlocal render_call_count + render_call_count += 1 + return original_render(self, *args, **kwargs) + + with mock.patch.object(CUTLASSGemmTemplate, "render", counting_render): + with config.patch( + { + "max_autotune": True, + "max_autotune_gemm_backends": "CUTLASS", + "cuda.cutlass_max_profiling_configs": 2, + "fx_graph_cache": False, + "fx_graph_remote_cache": False, + "cuda.enable_caching_codegen": True, + } + ): + # Get expected results + expected = model(a, b) + + # Compile and run + compiled_model = torch.compile(model, dynamic=True) + actual = compiled_model(a, b) + + # Verify results + torch.testing.assert_close(actual, expected) + + num_matmuls = 2 + self.assertEqual(render_call_count, num_matmuls + num_matmuls * 2) + @unittest.skipIf(not SM90OrLater, "need sm_90") @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) def test_cutlass_backend_matmul_same_tensor(self): @@ -1849,8 +2021,14 @@ def test_gemm_operation_serialization(self, arch: str, cuda_version: str): """ full_ops = _gen_ops_cached(arch, cuda_version) ops = pytree.tree_flatten(full_ops)[0] + + # sanity check self.assertGreater(len(ops), 1000, "Too few ops generated") + # test if configuration name is unique + op_config_names = [op.configuration_name() for op in ops] + self.assertEqual(len(op_config_names), len(set(op_config_names))) + serializer = get_cutlass_operation_serializer() self.assertIsNotNone(serializer) diff --git a/torch/_inductor/codegen/cuda/cuda_template.py b/torch/_inductor/codegen/cuda/cuda_template.py index 07ee9f127580f..2156369d56a58 100644 --- a/torch/_inductor/codegen/cuda/cuda_template.py +++ b/torch/_inductor/codegen/cuda/cuda_template.py @@ -10,7 +10,9 @@ import sympy import torch -from torch._inductor.utils import Placeholder +from torch._inductor import config +from torch._inductor.select_algorithm import create_inputs_key +from torch._inductor.utils import clear_on_fresh_cache, Placeholder from torch._logging import getArtifactLogger from ...autotune_process import CUDABenchmarkRequest, TensorMeta @@ -38,8 +40,12 @@ class ArgInfo: ty: str +@clear_on_fresh_cache class CUDATemplate(KernelTemplate): index_counter = itertools.count() + # dict of cache key to (code, size_args) + code_cache: dict[str, tuple[str, tuple[int, ...]]] = {} + cache_clear = staticmethod(code_cache.clear) def __init__( self, @@ -49,15 +55,15 @@ def __init__( input_reorder: Optional[list[int]] = None, ) -> None: """ - - Baseclass for CUDA C++ Templates, derived from KernelTemplate. Not to be instantiated directly. + Baseclass for CUDA C++ Templates, derived from KernelTemplate. + Not to be instantiated directly. Args: name (str): The name of the CUDATemplate object. input_nodes (List[IRNode]): A list of input IRNodes. layout (Layout): The layout of the output buffer / tensor. - input_reorder (Optional[List[int]]): An optional list that specifies the order of the input nodes. - + input_reorder (Optional[List[int]]): An optional list that specifies + the order of the input nodes. """ super().__init__(name) self.input_nodes = input_nodes @@ -74,30 +80,51 @@ def _template_from_string(cls, source: str) -> Any: def supports_epilogue_fusion(op: GemmOperation) -> bool: return False - def generate( # type: ignore[override] - self, - description, - **kwargs, - ) -> CUDATemplateCaller: + def make_key(self, op: "GemmOperation") -> str: """ - Generates the CUDA template caller object for the given GEMM template and operation. This CUDATemplateCaller - may be used to call and benchmark the generated CUDA kernel in a standalone manner to enable Autotuning. + Make a key for the code cache. The idea of the method is to cache + everything that matters but doesn't include runtime param values, i.e., + self.get_runtime_arg_values(). Args: - kwargs: Additional keyword arguments. - - Returns: - A CUDATemplateCaller object representing the generated CUDA template caller. + kwargs: Additional keyword arguments. Including op (GemmOperation). + """ + return hashlib.sha256( + str( + ( + create_inputs_key(self.input_nodes), + self.input_reorder, + # output layout, same as self.output_node.get_layout() + self.layout, + self.get_runtime_arg_info(), + op.configuration_name(), + ) + ).encode("utf-8") + ).hexdigest() + + def generate_code_and_args(self, **kwargs) -> tuple[str, tuple[int, ...]]: """ + Generate code and args with caching. We cache the code even if runtime + args are different. + """ + key: Optional[str] = None + if config.cuda.enable_caching_codegen: + op = kwargs.get("op") + assert op is not None, "op is required for caching" + key = self.make_key(op) + + if key is not None and key in self.code_cache: + code, size_args = self.code_cache[key] + extra_args = tuple(list(size_args) + self.get_runtime_arg_values(**kwargs)) + return code, extra_args + kernel_name = str(Placeholder.KERNEL_NAME) - with ( - patch.object(V.graph, "get_dtype", self._fake_get_dtype(self.output_node)), - CUDATemplateKernel( - kernel_name=kernel_name, - runtime_arg_info=self.get_runtime_arg_info(), - runtime_arg_values=self.get_runtime_arg_values(**kwargs), - ) as kernel, - ): + kernel = CUDATemplateKernel( + kernel_name=kernel_name, + runtime_arg_info=self.get_runtime_arg_info(), + runtime_arg_values=self.get_runtime_arg_values(**kwargs), + ) + with patch.object(V.graph, "get_dtype", self._fake_get_dtype(self.output_node)): code = self.render(kernel=kernel, **kwargs) _, call_args, _, _ = kernel.args.python_argdefs() autotuning_log.debug("Generated Code:\n%s", code) @@ -122,8 +149,35 @@ def generate( # type: ignore[override] ) V.graph.sizevars.size_hints(map(sympy.expand, call_args[len(expected_args) :])) size_args = V.graph.sizevars.size_hints(kernel.get_dynamic_shape_args()) + + if key is not None: + self.code_cache[key] = code, size_args + + # extra args has runtime params, which shouldn't be cached extra_args = tuple(list(size_args) + self.get_runtime_arg_values(**kwargs)) + return code, extra_args + + def generate( # type: ignore[override] + self, + description: str, + **kwargs, + ) -> CUDATemplateCaller: + """ + Generates the CUDA template caller object for the given GEMM template and operation. + This CUDATemplateCaller may be used to call and benchmark the generated CUDA kernel + in a standalone manner to enable Autotuning. + + Args: + description: op name followed by swizzle. + kwargs: Additional keyword arguments. + + Returns: + A CUDATemplateCaller object representing the generated CUDA template caller. + """ + code, extra_args = self.generate_code_and_args(**kwargs) + + # not caching since kernel name is needed below kernel_hash = hashlib.sha256(code.encode("utf-8")).hexdigest()[:8] kernel_name = f"cutlass_{kernel_hash}" code = code.replace(self.name, kernel_name) diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 25bd81ea7f8af..b18e6f45de55f 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -1607,6 +1607,9 @@ class cuda: # Use this to overwrite and handle cache pollution binary_remote_cache_force_write: bool = False + # Enable caching codegen of cuda templates. + enable_caching_codegen: bool = True + class rocm: # Offload arch list for device code compilation, e.g. ["gfx90a", "gfx942"]. From ad2dec1997077e3dc3e0ed8a26abce2261c04f86 Mon Sep 17 00:00:00 2001 From: codingwithsurya Date: Sun, 20 Jul 2025 16:39:36 -0700 Subject: [PATCH 329/457] [SymmMem] Add NVSHMEM alltoall support into Triton (#158513) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements collective alltoall operation for NVSHMEM Triton kernels. Enables data exchange where each PE sends unique data to every other PE in the team. Tests: `python test/distributed/test_nvshmem_triton.py -k test_triton_alltoall`
Quick debug print for sanity check ```markdown ============================================================ [Rank 0] Starting alltoall test with world_size=2 ============================================================ [Rank 0] Configuration: - nelems_per_pe: 2 - dtype: torch.int64, element_size: 8 bytes - nelems_bytes: 16 /dvs/p4/build/sw/rel/gpgpu/toolkit/r12.8/main_nvshmem/src/modules/transport/ibrc/ibrc.cpp:1653: NULL value get_device_list failed /dvs/p4/build/sw/rel/gpgpu/toolkit/r12.8/main_nvshmem/src/modules/transport/ibrc/ibrc.cpp:1653: NULL value get_device_list failed [Rank 0] Preparing source data: [Rank 1] Preparing source data: - Data for PE 0: [0, 0] (indices 0-1) - Data for PE 1: [1, 1] (indices 2-3) [Rank 0] Complete source buffer: [0, 0, 1, 1] - Data for PE 0: [100, 100] (indices 0-1) - Data for PE 1: [101, 101] (indices 2-3) [Rank 1] Complete source buffer: [100, 100, 101, 101] [Rank 1] Initial destination buffer: [-1, -1, -1, -1] [Rank 0] Initial destination buffer: [-1, -1, -1, -1] /data/users/suryasub/pytorch/torch/distributed/distributed_c10d.py:4809: UserWarning: No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user. warnings.warn( # warn only once /data/users/suryasub/pytorch/torch/distributed/distributed_c10d.py:4809: UserWarning: No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user. warnings.warn( # warn only once [rank0]:[W716 15:30:06.215666766 ProcessGroupNCCL.cpp:5064] [PG ID 0 PG GUID 0 Rank 0] using GPU 0 as device used by this process is currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. You can specify device_id in init_process_group() to force use of a particular device. [rank1]:[W716 15:30:06.215752786 ProcessGroupNCCL.cpp:5064] [PG ID 0 PG GUID 0 Rank 1] using GPU 1 as device used by this process is currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. You can specify device_id in init_process_group() to force use of a particular device. NCCL version 2.27.5+cuda12.4 [Rank 1] Executing alltoall operation... [Rank 0] Executing alltoall operation... [Rank 1] alltoall operation completed /data/users/suryasub/pytorch/torch/distributed/distributed_c10d.py:4809: UserWarning: No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user. warnings.warn( # warn only once [Rank 0] alltoall operation completed /data/users/suryasub/pytorch/torch/distributed/distributed_c10d.py:4809: UserWarning: No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user. warnings.warn( # warn only once [Rank 0] Results after alltoall: [Rank 1] Results after alltoall:[Rank 0] Destination buffer: [0, 0, 100, 100] [Rank 0] Verifying results: - From PE 0 (indices 0-1): Expected: [0, 0] Actual: [0, 0] [Rank 1] Destination buffer: [1, 1, 101, 101] [Rank 1] Verifying results: - From PE 0 (indices 0-1): Expected: [1, 1] Actual: [1, 1] Match: ✓ Match: ✓ - From PE 1 (indices 2-3): Expected: [100, 100] - From PE 1 (indices 2-3): Expected: [101, 101] Actual: [100, 100] Actual: [101, 101] Match: ✓ Match: ✓ [Rank 0] ============================================================ [Rank 0] Summary: ALL TESTS PASSED ✓ [Rank 0] Data flow explanation: - Each rank sends 2 elements to every other rank [Rank 1] ============================================================ [Rank 1] Summary: ALL TESTS PASSED ✓ - Rank 0 sent: [0, 0, 1, 1] [Rank 1] Data flow explanation: - Each rank sends 2 elements to every other rank - Rank 0 received: [0, 0, 100, 100] - My data for PE 0 (0) went to PE 0's buffer - I received PE 0's data for me (0) - My data for PE 1 (1) went to PE 1's buffer - Rank 1 sent: [100, 100, 101, 101] - I received PE 1's data for me (100) [Rank 0] ============================================================ - Rank 1 received: [1, 1, 101, 101] - My data for PE 0 (100) went to PE 0's buffer - I received PE 0's data for me (1) - My data for PE 1 (101) went to PE 1's buffer - I received PE 1's data for me (101) [Rank 1] ============================================================ ```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158513 Approved by: https://github.com/fduwjj, https://github.com/mandroid6 ghstack dependencies: #158511, #158512 --- test/distributed/test_nvshmem_triton.py | 58 +++++++++++++++++++ .../_symmetric_memory/_nvshmem_triton.py | 19 ++++++ 2 files changed, 77 insertions(+) diff --git a/test/distributed/test_nvshmem_triton.py b/test/distributed/test_nvshmem_triton.py index 6c7f38686a4c6..992c0895714ba 100644 --- a/test/distributed/test_nvshmem_triton.py +++ b/test/distributed/test_nvshmem_triton.py @@ -210,6 +210,16 @@ def sync_test_kernel( tl.store(p_dst, received + 1) +@triton.jit +def alltoall_kernel( + team_handle, + dest_ptr, + src_ptr, + nelems, +): + nvshmem.alltoall(team_handle, dest_ptr, src_ptr, nelems) + + @instantiate_parametrized_tests @requires_nvshmem() class NVSHMEMTritonTest(MultiProcContinousTest): @@ -831,6 +841,54 @@ def test_triton_sync(self) -> None: dst, torch.tensor([43], device=self.device, dtype=dtype) ) + @skipIfRocm + @requires_triton() + def test_triton_alltoall(self) -> None: + torch.manual_seed(42 + self.rank) + self._init_device() + nvshmem_lib = nvshmem.enable_triton() + group_name = dist.group.WORLD.group_name + symm_mem.enable_symm_mem_for_group(group_name) + world_size = dist.get_world_size() + rank = self.rank + # Each PE will send 2 int64 elements to every other PE + nelems_per_pe = 2 + dtype = torch.int64 + # Source buffer: contains data for all PEs + # Layout: [data_for_pe0, data_for_pe1, ...] + src_size = nelems_per_pe * world_size + src = symm_mem.empty(src_size, dtype=dtype, device=self.device) + # Fill source with rank-specific data + # Formula: rank * 100 + destination_pe + for i in range(world_size): + value = rank * 100 + i + src[i * nelems_per_pe : (i + 1) * nelems_per_pe] = value + # Destination buffer + dst = symm_mem.empty(src_size, dtype=dtype, device=self.device).fill_(-1) + src_hdl = symm_mem.rendezvous(src, group=group_name) + dst_hdl = symm_mem.rendezvous(dst, group=group_name) + # Synchronize before alltoall + dist.barrier() + team_handle = 0 # NVSHMEM_TEAM_WORLD handle is 0 + # Launch the kernel + alltoall_kernel[(1,)]( + team_handle, + dst_hdl.buffer_ptrs[rank], + src_hdl.buffer_ptrs[rank], + nelems_per_pe, + extern_libs=nvshmem_lib, + launch_cooperative_grid=True, + ) + # Synchronize after alltoall + dist.barrier() + # Verify results + for i in range(world_size): + # After alltoall, we should receive data from PE i that was intended for us + # PE i sends (i * 100 + rank) to us + expected = i * 100 + rank + actual = dst[i * nelems_per_pe : (i + 1) * nelems_per_pe] + torch.testing.assert_close(actual, torch.full_like(actual, expected)) + if __name__ == "__main__": run_tests() diff --git a/torch/distributed/_symmetric_memory/_nvshmem_triton.py b/torch/distributed/_symmetric_memory/_nvshmem_triton.py index ed195a4225f13..08124483a9fe6 100644 --- a/torch/distributed/_symmetric_memory/_nvshmem_triton.py +++ b/torch/distributed/_symmetric_memory/_nvshmem_triton.py @@ -241,3 +241,22 @@ def sync_all(_builder=None): # type: ignore[no-untyped-def] is_pure=False, _builder=_builder, ) + + @core.extern + def alltoall(team, dest, source, nelems, _builder=None): # type: ignore[no-untyped-def] + """Perform alltoall operation on NVSHMEM symmetric memory""" + return core.extern_elementwise( + "", + "", + [team, dest, source, nelems], + { + ( + core.dtype("int64"), # team handle + core.dtype("int64"), # dest ptr + core.dtype("int64"), # source ptr + core.dtype("int64"), # nelems + ): ("nvshmem_longlong_alltoall", core.dtype("int32")) + }, + is_pure=False, + _builder=_builder, + ) From 22920c9138fb7a09db325038b70c8cf636b50653 Mon Sep 17 00:00:00 2001 From: Benjamin Glass Date: Mon, 21 Jul 2025 15:42:02 +0000 Subject: [PATCH 330/457] Grab bag of (mostly) typing improvements (#158075) Collects some scattershot improvements made while attempting to enable training for AOTInductor. Non-typing changes are: 1. Swapping a few custom searches for the output node in an FX graph for calling `graph.output_node()`. 2. Removing two unused parameters from `torch.export._unlift._unlift`. 3. Switching handles to constants in `cpp_wrapper_cpu` to use C++ references for memory efficiency. 4. Cleaning out unused, unexported imports from `torch/export/__init__.py`, and adding one missing export to `__all__`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158075 Approved by: https://github.com/Skylion007 --- benchmarks/dynamo/common.py | 22 ++++++-- test/inductor/test_aot_inductor_package.py | 8 ++- torch/_dynamo/utils.py | 13 ++++- torch/_inductor/__init__.py | 6 +-- torch/_inductor/codegen/cpp_wrapper_cpu.py | 4 +- torch/_inductor/compile_fx.py | 14 ++--- torch/_inductor/package/package.py | 2 +- torch/export/__init__.py | 59 +++++++--------------- torch/export/_draft_export.py | 4 +- torch/export/_trace.py | 3 +- torch/export/_unlift.py | 43 +++++++--------- torch/export/experimental/__init__.py | 34 ++++++++----- torch/export/exported_program.py | 24 +++++---- 13 files changed, 124 insertions(+), 112 deletions(-) diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index 1088634ce911e..900a93c552b46 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -22,7 +22,7 @@ import time import weakref from contextlib import contextmanager -from typing import Any, NamedTuple, TYPE_CHECKING +from typing import Any, NamedTuple, Optional, overload, TYPE_CHECKING, TypeVar from unittest.mock import MagicMock import numpy as np @@ -54,6 +54,7 @@ from torch._inductor.utils import fresh_cache except ImportError: from _dynamo.utils import clone_inputs, graph_break_reasons + from _inductor.utils import fresh_cache import torch._functorch.config from torch._functorch.aot_autograd import set_model_name @@ -75,7 +76,10 @@ if TYPE_CHECKING: - from collections.abc import Mapping + from collections.abc import Sequence + +_D = TypeVar("_D", bound=dict[str, Any]) +_T = TypeVar("_T") log = logging.getLogger(__name__) @@ -766,7 +770,17 @@ def vary_batch(t: torch.Tensor, new_batch_size) -> torch.Tensor: return (time_total, result) if return_result else time_total -def _normalize_bench_inputs(example_inputs) -> tuple[tuple[Any], Mapping[str, Any]]: +@overload +def _normalize_bench_inputs(example_inputs: _D) -> tuple[tuple[()], _D]: ... + + +@overload +def _normalize_bench_inputs( + example_inputs: Sequence[_T], +) -> tuple[tuple[_T, ...], dict[str, Any]]: ... + + +def _normalize_bench_inputs(example_inputs): # NOTE(bowbao): For huggingface benchmark, example_inputs are formatted as dictionary, # and consumed like `model(**example_inputs)`. # For other benchmarks, example_inputs are formatted as tuple and consumed @@ -1671,7 +1685,7 @@ def __init__(self): self.grad_scaler = DummyGradScaler() self.autocast = contextlib.nullcontext self.autocast_arg = {} - self.optimizer = None + self.optimizer: Optional[torch.optim.Optimizer] = None self._args = None def setup_amp(self, current_device=None): diff --git a/test/inductor/test_aot_inductor_package.py b/test/inductor/test_aot_inductor_package.py index 94fe620a9f18b..51343b6b1883e 100644 --- a/test/inductor/test_aot_inductor_package.py +++ b/test/inductor/test_aot_inductor_package.py @@ -16,12 +16,16 @@ import torch from torch._inductor.codecache import get_kernel_bin_format -from torch._inductor.package import AOTICompiledModel, load_package, package_aoti +from torch._inductor.package import load_package, package_aoti from torch._inductor.test_case import TestCase from torch._inductor.utils import fresh_cache from torch.export import Dim from torch.export.experimental import _ExportPackage -from torch.export.pt2_archive._package import load_pt2, load_weights_to_pt2_contents +from torch.export.pt2_archive._package import ( + AOTICompiledModel, + load_pt2, + load_weights_to_pt2_contents, +) from torch.testing._internal.common_cuda import _get_torch_cuda_version from torch.testing._internal.common_utils import ( IS_FBCODE, diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 725d46c06ae9b..d54a45f4156a1 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -102,6 +102,7 @@ Iterable, Iterator, KeysView, + Sequence, ValuesView, ) @@ -2137,8 +2138,18 @@ def torch_clone(x): return result +@overload +def clone_inputs( + example_inputs: dict[str, Union[T, tuple[T, ...]]], +) -> dict[str, list[T]]: ... + + +@overload +def clone_inputs(example_inputs: Sequence[T]) -> list[T]: ... + + def clone_inputs(example_inputs): - res: Union[dict[Any, Any], list[Any]] + res: Union[dict[str, Any], list[Any]] if type(example_inputs) is dict: res = dict(example_inputs) for key, value in res.items(): diff --git a/torch/_inductor/__init__.py b/torch/_inductor/__init__.py index 94762a68b3435..f80b71cbe420d 100644 --- a/torch/_inductor/__init__.py +++ b/torch/_inductor/__init__.py @@ -6,7 +6,6 @@ import os from typing import Any, IO, Literal, Optional, TYPE_CHECKING, Union -import torch._inductor.config import torch.fx from .standalone_compile import CompiledArtifact # noqa: TC001 @@ -15,6 +14,7 @@ if TYPE_CHECKING: from torch._inductor.utils import InputType from torch.export import ExportedProgram + from torch.export.pt2_archive._package import AOTICompiledModel from torch.export.pt2_archive._package_weights import Weights from torch.types import FileLike @@ -223,7 +223,7 @@ def _aoti_compile_and_package_inner( not_strict_accuracy = check_accuracy == "accuracy" if not same_two_models( gm, - compiled_model, + compiled_model, # type: ignore[arg-type] args, only_fwd=True, require_fp64=not_strict_accuracy, @@ -238,7 +238,7 @@ def _aoti_compile_and_package_inner( def aoti_load_package( path: FileLike, run_single_threaded: bool = False, device_index: int = -1 -) -> Any: # type: ignore[type-arg] +) -> AOTICompiledModel: """ Loads the model from the PT2 package. diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index cbca6d9fe5d28..56d6f40dade81 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -630,10 +630,10 @@ def write_wrapper_decl(self): ), "Expect all constants to be Tensor" for idx, constants_key in enumerate(V.graph.constants.keys()): if V.graph.aot_mode: - # Weights are stored in constants_ and owned by RAIIAtenTensorHandle there. + # Weights are stored in constants_ and owned by ConstantHandle there. # Don't call std::move here because it will cause constants_ to lose the ownership. self.prefix.writeline( - f"""[[maybe_unused]] auto {constants_key} = constants_->at({idx});""" + f"""[[maybe_unused]] auto& {constants_key} = constants_->at({idx});""" ) else: # Append constants as inputs to the graph diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index 4d02c353693f7..8e712a28a3b0f 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -53,6 +53,7 @@ ) from torch._functorch.aot_autograd import ( aot_export_module, + GraphOutputName, make_boxed_func, SerializableAOTDispatchCompiler, ) @@ -429,7 +430,7 @@ def _unlift_graph( from torch.export._unlift import _unlift - outputs = list(gm.graph.nodes)[-1].args[0] + outputs: tuple[torch.fx.Node, ...] = tuple(gm.graph.output_node().args[0]) # type: ignore[arg-type] mutated_outputs = [] buffer_mutations = graph_signature.buffers_to_mutate user_input_mutations = graph_signature.user_inputs_to_mutate @@ -438,10 +439,11 @@ def _unlift_graph( value: Optional[Union[FQN, GraphInputName]] = None if idx < len(buffer_mutations) + len(user_input_mutations) + len(output_tokens): - if out.name in buffer_mutations: - value = buffer_mutations[out.name] - elif out.name in user_input_mutations: - value = user_input_mutations[out.name] + name = GraphOutputName(out.name) + if name in buffer_mutations: + value = buffer_mutations[name] + elif name in user_input_mutations: + value = user_input_mutations[name] mutated_outputs.append(value) @@ -451,8 +453,6 @@ def _unlift_graph( mutated_outputs, pytree.LeafSpec(), None, - state_dict, - {}, ) return unlifted_gm diff --git a/torch/_inductor/package/package.py b/torch/_inductor/package/package.py index 726b41d972403..bd11d033cadb3 100644 --- a/torch/_inductor/package/package.py +++ b/torch/_inductor/package/package.py @@ -105,7 +105,7 @@ def load_package( run_single_threaded: bool = False, num_runners: int = 1, device_index: int = -1, -) -> AOTICompiledModel: # type: ignore[type-arg] +) -> AOTICompiledModel: try: pt2_contents = load_pt2( path, diff --git a/torch/export/__init__.py b/torch/export/__init__.py index c36056d22dd58..3ed8a6c37883f 100644 --- a/torch/export/__init__.py +++ b/torch/export/__init__.py @@ -1,59 +1,38 @@ -import builtins -import copy -import dataclasses -import inspect import os -import sys -import typing import warnings import zipfile -from collections.abc import Iterator -from enum import auto, Enum -from typing import Any, Callable, Optional, TYPE_CHECKING, Union +from collections.abc import Mapping +from typing import Any, Callable, Optional, Union from typing_extensions import deprecated import torch import torch.utils._pytree as pytree -from torch.fx._compatibility import compatibility from torch.fx.passes.infra.pass_base import PassResult -from torch.fx.passes.infra.pass_manager import PassManager from torch.types import FileLike -from torch.utils._pytree import ( - FlattenFunc, - FromDumpableContextFn, - ToDumpableContextFn, - UnflattenFunc, -) - - -if TYPE_CHECKING: - # Import the following modules during type checking to enable code intelligence features, - # Do not import unconditionally, as they import sympy and importing sympy is very slow - from torch._ops import OpOverload - from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint __all__ = [ + "AdditionalInputs", "Constraint", - "Dim", - "ExportBackwardSignature", - "ExportGraphSignature", - "ExportedProgram", "CustomDecompTable", - "ModuleCallEntry", - "ModuleCallSignature", "default_decompositions", + "Dim", "dims", - "export", + "draft_export", "export_for_training", + "export", + "ExportBackwardSignature", + "ExportedProgram", + "ExportGraphSignature", + "FlatArgsAdapter", "load", + "ModuleCallEntry", + "ModuleCallSignature", "register_dataclass", "save", + "ShapesCollection", "unflatten", - "FlatArgsAdapter", "UnflattenedModule", - "AdditionalInputs", - "draft_export", ] # To make sure export specific custom ops are loaded @@ -82,9 +61,9 @@ def export_for_training( mod: torch.nn.Module, args: tuple[Any, ...], - kwargs: Optional[dict[str, Any]] = None, + kwargs: Optional[Mapping[str, Any]] = None, *, - dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None, + dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any, ...], list[Any]]] = None, strict: bool = False, preserve_module_call_signature: tuple[str, ...] = (), ) -> ExportedProgram: @@ -181,9 +160,9 @@ def export_for_training( def export( mod: torch.nn.Module, args: tuple[Any, ...], - kwargs: Optional[dict[str, Any]] = None, + kwargs: Optional[Mapping[str, Any]] = None, *, - dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None, + dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any, ...], list[Any]]] = None, strict: bool = False, preserve_module_call_signature: tuple[str, ...] = (), ) -> ExportedProgram: @@ -540,9 +519,9 @@ def load( def draft_export( mod: torch.nn.Module, args: tuple[Any, ...], - kwargs: Optional[dict[str, Any]] = None, + kwargs: Optional[Mapping[str, Any]] = None, *, - dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None, + dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any, ...], list[Any]]] = None, preserve_module_call_signature: tuple[str, ...] = (), strict: bool = False, ) -> ExportedProgram: diff --git a/torch/export/_draft_export.py b/torch/export/_draft_export.py index 9a9ed922c83e7..9d2179fcf252b 100644 --- a/torch/export/_draft_export.py +++ b/torch/export/_draft_export.py @@ -4,13 +4,13 @@ import os import re import tempfile +from collections.abc import Mapping from dataclasses import dataclass from enum import IntEnum from typing import Any, Callable, Optional, Union import torch import torch._logging._internal -import torch._logging.structured import torch.utils._pytree as pytree from torch._export.passes.insert_custom_op_guards import ( get_op_profiles, @@ -362,7 +362,7 @@ def _log_expression_created( def draft_export( mod: torch.nn.Module, args: tuple[Any, ...], - kwargs: Optional[dict[str, Any]] = None, + kwargs: Optional[Mapping[str, Any]] = None, *, dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None, preserve_module_call_signature: tuple[str, ...] = (), diff --git a/torch/export/_trace.py b/torch/export/_trace.py index 4183fe22cda85..35be163b7e94a 100644 --- a/torch/export/_trace.py +++ b/torch/export/_trace.py @@ -1723,6 +1723,7 @@ def _is_impure(node): gm.graph.eliminate_dead_code(_is_impure) # create graph signature + assert out_spec.spec is not None, "out_spec.spec is None!" input_names = _graph_input_names(gm) output_names = _graph_output_names(gm) sig = GraphSignature( @@ -1737,7 +1738,7 @@ def _is_impure(node): buffers_to_mutate={}, user_inputs_to_mutate={}, in_spec=in_spec, - out_spec=out_spec, # type: ignore[arg-type] + out_spec=out_spec.spec, backward_signature=None, input_tokens=[], output_tokens=[], diff --git a/torch/export/_unlift.py b/torch/export/_unlift.py index 553d2eb2bf3b1..f7ae6cbf21ac7 100644 --- a/torch/export/_unlift.py +++ b/torch/export/_unlift.py @@ -138,12 +138,7 @@ def _insert_copy_for_mutations( Find the all the buffers and inputs that were mutated and insert copy_ operators to reflect mutations. """ - output_node = None - for node in gm.graph.nodes: - if node.op == "output": - output_node = node - break - assert output_node is not None + output_node = gm.graph.output_node() outputs = pytree.tree_flatten(output_node.args)[0] assert len(outputs) == len(mutated_outputs) @@ -169,13 +164,13 @@ def _insert_copy_for_mutations( ) return_nodes_to_copy[return_node] = copy_node - output_args = [ + output_args = tuple( return_nodes_to_copy[node] if node in return_nodes_to_copy else node for node in user_output_nodes - ] + ) with gm.graph.inserting_before(output_node): # Only return user outputs - new_output = gm.graph.output(tuple(output_args)) + new_output = gm.graph.output(output_args) output_node.replace_all_uses_with(new_output) gm.graph.erase_node(output_node) new_output.name = output_node.name @@ -199,19 +194,18 @@ def _get_codegen( """ if forward_arg_names: names = forward_arg_names + elif ( + in_spec.type == tuple + and in_spec.num_children == 2 + and in_spec.children_specs[0].type == tuple + and in_spec.children_specs[1].type == dict + ): + # if in_spec contains the args (tuple) and kwargs (dict) + names = [f"arg_{i}" for i in range(in_spec.children_specs[0].num_children)] + # add kwarg names + names.extend(in_spec.children_specs[1].context) else: - if ( - in_spec.type == tuple - and in_spec.num_children == 2 - and in_spec.children_specs[0].type == tuple - and in_spec.children_specs[1].type == dict - ): - # if in_spec contains the args (tuple) and kwargs (dict) - names = [f"arg_{i}" for i in range(in_spec.children_specs[0].num_children)] - # add kwarg names - names.extend(in_spec.children_specs[1].context) - else: - names = [f"arg_{i}" for i in range(in_spec.num_children)] + names = [f"arg_{i}" for i in range(in_spec.num_children)] return _PyTreeCodeGen( _PyTreeInfo( @@ -228,8 +222,6 @@ def _unlift( mutated_outputs: Sequence[Optional[str]], in_spec: pytree.TreeSpec, out_spec: Optional[pytree.TreeSpec], - state_dict: dict[str, Any], - constants: dict[str, Any], forward_arg_names: Optional[list[str]] = None, ): """ @@ -427,7 +419,7 @@ def _create_stateful_graph_module( return stateful_gm -def _unlift_exported_program_lifted_states(ep: ExportedProgram) -> torch.nn.Module: +def _unlift_exported_program_lifted_states(ep: ExportedProgram) -> torch.fx.GraphModule: # TODO T206340015 if ep.verifiers[0].dialect != "TRAINING": ep = _remove_effect_tokens(ep) @@ -482,14 +474,13 @@ def _unlift_exported_program_lifted_states(ep: ExportedProgram) -> torch.nn.Modu ) ] + assert ep.call_spec.in_spec is not None new_gm = _unlift( new_gm, lifted_inputs, mutated_outputs, ep.call_spec.in_spec, ep.call_spec.out_spec, - ep.state_dict, - ep.constants, forward_arg_names=forward_arg_names, ) unlift_gm = _create_stateful_graph_module(new_gm, ep.range_constraints, ep) diff --git a/torch/export/experimental/__init__.py b/torch/export/experimental/__init__.py index b34bef61b508b..372eb3a29533d 100644 --- a/torch/export/experimental/__init__.py +++ b/torch/export/experimental/__init__.py @@ -2,7 +2,6 @@ import dataclasses import functools import os -import tempfile import types import typing import typing_extensions @@ -14,11 +13,15 @@ from torch.export.exported_program import _decompose_exported_program +_InputT = typing_extensions.ParamSpec("_InputT") +_RetT = typing.TypeVar("_RetT") + + __all__ = [] # type: ignore[var-annotated] def _copy_graph_module_and_signature( - ep: torch.fx.GraphModule, + ep: torch.export.ExportedProgram, ) -> tuple[torch.fx.GraphModule, torch.export.graph_signature.ExportGraphSignature]: # copy.deepcopy lets the objects override __deepcopy__ methods with graph_copy() and node_copy(), # and this can break placeholder names in some particular cases. @@ -36,7 +39,7 @@ def _copy_graph_module_and_signature( for old_node, new_node in zip(old_phs, new_phs): new_node.name = old_node.name - return gm, new_graph_signature # type: ignore[return-value] + return gm, new_graph_signature def _remove_detach_pass( @@ -81,18 +84,27 @@ def _export_forward_backward( return ep._update(gm, new_graph_signature) -@typing.no_type_check -def _sticky_export(forward_func, dynamic_shapes_callback=None): +def _sticky_export( + forward_func: typing.Callable[_InputT, _RetT], + dynamic_shapes_callback: typing.Optional[ + typing.Callable[ + _InputT, + typing.Union[ + list[typing.Any], dict[str, typing.Any], tuple[typing.Any, ...] + ], + ] + ] = None, +) -> typing.Callable[_InputT, _RetT]: """ Lazily export the model on first forward call. Usage: model.forward = _sticky_export(model.forward, dynamic_shapes_callback=callback) """ - model = forward_func.__self__ - original_forward = forward_func.__func__ + model = forward_func.__self__ # type: ignore[attr-defined] + original_forward = forward_func.__func__ # type: ignore[attr-defined] @functools.wraps(forward_func) - def wrapper(*args, **kwargs): + def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT: # Unpatch forward to avoid recursion during export model.forward = types.MethodType(original_forward, model) @@ -107,7 +119,7 @@ def wrapper(*args, **kwargs): kwargs, dynamic_shapes=dynamic_shapes_spec, ).module() - wrapper._exported_artifact = exported + wrapper._exported_artifact = exported # type: ignore[attr-defined] finally: # Restore the wrapper after export model.forward = wrapper @@ -123,10 +135,6 @@ class _ExportMethod: fallbacks: list[torch.export.ExportedProgram] -_InputT = typing_extensions.ParamSpec("_InputT") -_RetT = typing.TypeVar("_RetT") - - class _ExportPackage: """ An export package is a collection of torch.export()-ed PyTorch models consisting of diff --git a/torch/export/exported_program.py b/torch/export/exported_program.py index bbfb9202c560d..85900dd5e5ea0 100644 --- a/torch/export/exported_program.py +++ b/torch/export/exported_program.py @@ -7,10 +7,10 @@ import operator import types import warnings -from collections import defaultdict, namedtuple +from collections import defaultdict from collections.abc import Iterator from contextlib import contextmanager -from typing import Any, Callable, final, Optional, TYPE_CHECKING, Union +from typing import Any, Callable, final, NamedTuple, Optional, TYPE_CHECKING, Union from torch._guards import tracing, TracingContext from torch._higher_order_ops.utils import autograd_not_implemented @@ -325,7 +325,7 @@ def default_decompositions() -> "CustomDecompTable": def _decompose_and_get_gm_with_new_signature_constants( - ep, + ep: "ExportedProgram", *, cia_to_decomp: dict[torch._ops.OperatorBase, Callable], python_decomp_table: dict[torch._ops.OperatorBase, Callable], @@ -384,9 +384,11 @@ def _is_joint_ir_decomp(ep, joint_loss_index): # Fix the graph output signature to be tuple if scalar out_spec = mod._out_spec + assert isinstance(mod.graph._codegen, _PyTreeCodeGen) orig_arg_names = mod.graph._codegen.pytree_info.orig_args # aot_export expect the return type to always be a tuple. + assert out_spec is not None if out_spec.type not in (list, tuple): out_spec = pytree.TreeSpec(tuple, None, [out_spec]) @@ -610,7 +612,7 @@ def update_arg(old_arg, new_ph): raise RuntimeError(f"Type of old_arg not supported: {type(old_arg)}") new_placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"] - new_outputs = list(gm.graph.nodes)[-1].args[0] + new_outputs: tuple[torch.fx.Node, ...] = tuple(gm.graph.output_node().args[0]) # type: ignore[arg-type] # rename the placeholders assert len(new_placeholders) == len(old_placeholders) @@ -654,9 +656,9 @@ def update_arg(old_arg, new_ph): # update output specs gm.recompile() - for i, name in enumerate(_graph_output_names(gm)): - if isinstance(new_outputs[i], torch.fx.Node): - new_outputs[i].name = name + for output, name in zip(new_outputs, _graph_output_names(gm)): + if name is not None: + output.name = name # To match the output target with correct input for input mutations # need to find the old to new placeholder map @@ -727,7 +729,7 @@ def update_arg(old_arg, new_ph): for i, spec in enumerate(ep.graph_signature.input_specs) if isinstance(spec.arg, TensorArgument) } - for i, node in enumerate(new_outputs[len(output_specs) :]): + for node in new_outputs[len(output_specs) :]: source = gradients[node.name] spec = specs[source] # type: ignore[index] if spec.kind == InputKind.PARAMETER: @@ -1208,7 +1210,9 @@ def example_inputs(self, value): @property @compatibility(is_backward_compatible=False) def call_spec(self): - CallSpec = namedtuple("CallSpec", ["in_spec", "out_spec"]) + class CallSpec(NamedTuple): + in_spec: Optional[pytree.TreeSpec] + out_spec: Optional[pytree.TreeSpec] if len(self.module_call_graph) == 0: return CallSpec(in_spec=None, out_spec=None) @@ -1364,7 +1368,7 @@ def __str__(self) -> str: ) return string - def module(self) -> torch.nn.Module: + def module(self) -> torch.fx.GraphModule: """ Returns a self contained GraphModule with all the parameters/buffers inlined. """ From 25fbf09d5fc14b49a37eba9452db76985c8b4e38 Mon Sep 17 00:00:00 2001 From: Natalia Gimelshein Date: Mon, 21 Jul 2025 19:23:44 +0000 Subject: [PATCH 331/457] Use more fine-grained locks in sym mem kernels (#158523) Summary: Use only acq in the beginning of the kernel, and only release in the end Test Plan: Existing tests Rollback Plan: Differential Revision: D78458020 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158523 Approved by: https://github.com/drisspg, https://github.com/kwen2501 --- .../c10d/symm_mem/CUDASymmetricMemory-inl.h | 48 ++++++++----------- .../c10d/symm_mem/CUDASymmetricMemoryOps.cu | 26 +++++----- 2 files changed, 32 insertions(+), 42 deletions(-) diff --git a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory-inl.h b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory-inl.h index bf5ea9a446bb1..0abbc84ebe52a 100644 --- a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory-inl.h +++ b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory-inl.h @@ -115,54 +115,44 @@ __device__ __forceinline__ void wait_signal(uint32_t* addr) { // Pattern 0: Ensures that all writes to symm_mem buffers from previous // kernels across all devices are visible to the current kernel: // -// sync_remote_blocks(...); +// sync_remote_blocks(...); // __syncthreads(); // // Pattern 1: Ensures that all writes to symm_mem buffers from the current // block are visible to all remote blocks with matching blockIdx: // // __syncthreads(); -// sync_remote_blocks(...); +// sync_remote_blocks(...); // __syncthreads(); // // Pattern 2: Ensures that symm_mem buffers read by the current kernel are safe // for writing by subsequent kernels across all devices. // // __syncthreads(); -// sync_remote_blocks(...); -template +// sync_remote_blocks(...); +template __device__ __forceinline__ void sync_remote_blocks( - uint32_t** signal_pads, - size_t rank, - size_t world_size); - -template <> -__device__ __forceinline__ void sync_remote_blocks( - uint32_t** signal_pads, - size_t rank, - size_t world_size) { - if (threadIdx.x < world_size) { - auto target_rank = threadIdx.x; - put_signal( - signal_pads[target_rank] + blockIdx.x * world_size + rank); - wait_signal( - signal_pads[rank] + blockIdx.x * world_size + target_rank); - } -} - -template <> -__device__ __forceinline__ void sync_remote_blocks( uint32_t** signal_pads, size_t rank, size_t world_size) { if (threadIdx.x < world_size) { auto target_rank = threadIdx.x; - put_signal( - signal_pads[target_rank] + blockIdx.x * world_size + rank); - wait_signal( - signal_pads[rank] + blockIdx.x * world_size + target_rank); + if constexpr (hasPrevMemAccess) { + put_signal( + signal_pads[target_rank] + blockIdx.x * world_size + rank); + } else { + put_signal( + signal_pads[target_rank] + blockIdx.x * world_size + rank); + } + if constexpr (hasSubsequentMemAccess) { + wait_signal( + signal_pads[rank] + blockIdx.x * world_size + target_rank); + } else { + wait_signal( + signal_pads[rank] + blockIdx.x * world_size + target_rank); + } } -} +}; template struct MultimemLdReduce { diff --git a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryOps.cu b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryOps.cu index c4f38e468192d..3a004ae73ce74 100644 --- a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryOps.cu +++ b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryOps.cu @@ -134,7 +134,7 @@ static __global__ void multimem_all_reduce_kernel( static_assert(alignment % sizeof(T) == 0); constexpr size_t numel_per_thread = alignment / sizeof(T); - sync_remote_blocks(signal_pads, rank, world_size); + sync_remote_blocks(signal_pads, rank, world_size); __syncthreads(); const size_t numel_per_rank = @@ -152,7 +152,7 @@ static __global__ void multimem_all_reduce_kernel( } __syncthreads(); - sync_remote_blocks(signal_pads, rank, world_size); + sync_remote_blocks(signal_pads, rank, world_size); } at::Tensor multimem_all_reduce_( @@ -219,7 +219,7 @@ static __global__ void multimem_one_shot_all_reduce_kernel( static_assert(alignment % sizeof(T) == 0); constexpr size_t numel_per_thread = alignment / sizeof(T); - sync_remote_blocks(signal_pads, rank, world_size); + sync_remote_blocks(signal_pads, rank, world_size); __syncthreads(); auto offset = (blockDim.x * blockIdx.x + threadIdx.x) * numel_per_thread; @@ -230,7 +230,7 @@ static __global__ void multimem_one_shot_all_reduce_kernel( } __syncthreads(); - sync_remote_blocks(signal_pads, rank, world_size); + sync_remote_blocks(signal_pads, rank, world_size); } at::Tensor multimem_one_shot_all_reduce_out( @@ -311,7 +311,7 @@ static __global__ void multimem_all_gather_kernel( uint32_t** signal_pads, size_t rank, size_t world_size) { - sync_remote_blocks(signal_pads, rank, world_size); + sync_remote_blocks(signal_pads, rank, world_size); __syncthreads(); const size_t start = bytes_per_rank * rank; @@ -324,7 +324,7 @@ static __global__ void multimem_all_gather_kernel( } __syncthreads(); - sync_remote_blocks(signal_pads, rank, world_size); + sync_remote_blocks(signal_pads, rank, world_size); } at::Tensor multimem_all_gather_out( @@ -425,7 +425,7 @@ static __launch_bounds__(one_shot_all_reduce_max_num_threads) __global__ } } // TODO make it sync with one block for no-copy case - sync_remote_blocks(signal_pads, rank, world_size); + sync_remote_blocks(signal_pads, rank, world_size); __syncthreads(); for (size_t i = offset; i < numel; i += stride) { @@ -435,7 +435,7 @@ static __launch_bounds__(one_shot_all_reduce_max_num_threads) __global__ } __syncthreads(); - sync_remote_blocks(signal_pads, rank, world_size); + sync_remote_blocks(signal_pads, rank, world_size); } at::Tensor one_shot_all_reduce_out_impl( @@ -587,7 +587,7 @@ static __launch_bounds__(two_shot_all_reduce_max_num_threads) __global__ constexpr size_t numel_per_thread = alignment / sizeof(T); int32_t N_last_dim = last_dim_size / world_size; // used only for split_last_dim reduce_scatter - sync_remote_blocks(signal_pads, rank, world_size); + sync_remote_blocks(signal_pads, rank, world_size); __syncthreads(); const size_t numel_per_rank = @@ -619,7 +619,7 @@ static __launch_bounds__(two_shot_all_reduce_max_num_threads) __global__ } __syncthreads(); - sync_remote_blocks(signal_pads, rank, world_size); + sync_remote_blocks(signal_pads, rank, world_size); if constexpr (reduce_scatter) { return; } @@ -654,7 +654,7 @@ static __launch_bounds__(two_shot_all_reduce_max_num_threads) __global__ // need to make sure all blocks exit simultaneously so that the data // is not corrupted by the subsequent kernels __syncthreads(); - sync_remote_blocks(signal_pads, rank, world_size); + sync_remote_blocks(signal_pads, rank, world_size); } template @@ -669,7 +669,7 @@ static __launch_bounds__(two_shot_all_reduce_max_num_threads) __global__ static_assert(alignment % sizeof(T) == 0); constexpr size_t numel_per_thread = alignment / sizeof(T); - sync_remote_blocks(signal_pads, rank, world_size); + sync_remote_blocks(signal_pads, rank, world_size); __syncthreads(); const size_t numel_per_rank = @@ -692,7 +692,7 @@ static __launch_bounds__(two_shot_all_reduce_max_num_threads) __global__ } __syncthreads(); - sync_remote_blocks(signal_pads, rank, world_size); + sync_remote_blocks(signal_pads, rank, world_size); } at::Tensor two_shot_all_reduce_impl( From ea5b06ed5bc0e84f0d6a88a5b2b12ce71db78ac6 Mon Sep 17 00:00:00 2001 From: Lucas Kabela Date: Mon, 21 Jul 2025 19:34:10 +0000 Subject: [PATCH 332/457] [Dynamo][BetterEngineering] Type side_effects.py (#158605) As part of better engineering week, we would like to improve out type support to improve dev experience in dynamo This PR adds strict typing support to a core file, `side_effects.py` Running ``` mypy torch/_dynamo/side_effects.py --linecount-report /tmp/coverage_log ``` | -------- | Lines Unannotated | Lines Total | % lines covered | Funcs Unannotated | Funcs Total | % funcs covered | | -------- | ------- | -------- | ------- | ------- | ------- | ------- | | Main | 365 | 1166 | 31.30% | 16 | 51 | 31.37% | | This PR | 1185 | 1185 | 100.00% | 51 | 51 | 100.00% | | Delta | +820 | +19 | +68.70% | +35 | 0 | +68.63% | Pull Request resolved: https://github.com/pytorch/pytorch/pull/158605 Approved by: https://github.com/StrongerXi --- torch/_dynamo/side_effects.py | 217 ++++++++++++++++++------------ torch/_dynamo/symbolic_convert.py | 4 +- 2 files changed, 134 insertions(+), 87 deletions(-) diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index 8e3c4cd30145c..58ed0da5fb2de 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -1,5 +1,3 @@ -# mypy: allow-untyped-defs - """ Side effect tracking and management for TorchDynamo's compilation system. @@ -28,11 +26,12 @@ import inspect import warnings import weakref -from collections.abc import MutableMapping +from collections.abc import Generator, MutableMapping from types import CellType from typing import Any, Optional, TYPE_CHECKING import torch.nn +from torch._dynamo.variables.misc import AutogradFunctionContextVariable from . import graph_break_hints, utils, variables from .bytecode_transformation import ( @@ -58,21 +57,25 @@ if TYPE_CHECKING: - from torch._dynamo.symbolic_convert import InstructionTranslator + from torch._dynamo.output_graph import OutputGraph + from torch._dynamo.symbolic_convert import InstructionTranslatorBase + from torch._dynamo.variables.lists import ListVariable -def _manual_dict_setitem(dict_from, dict_to, mro_index): +def _manual_dict_setitem( + dict_from: dict[Any, Any], dict_to: dict[Any, Any], mro_index: int +) -> None: # Carefully calls the dict or OrderedDict `clear` or `__setitem__`. We have # to be careful because we don't want to trigger the user defined object # setitem or clear. The mro_index is used to find the dict/OrderedDict from # the class mro. dict_class = type(dict_to).__mro__[mro_index] - dict_class.clear(dict_to) + dict_class.clear(dict_to) # type: ignore[attr-defined] for k, v in dict_from.items(): - dict_class.__setitem__(dict_to, k, v) + dict_class.__setitem__(dict_to, k, v) # type: ignore[index] -def _manual_list_update(list_from, list_to): +def _manual_list_update(list_from: list[Any], list_to: list[Any]) -> None: list.clear(list_to) list.extend(list_to, list_from) @@ -103,13 +106,27 @@ class SideEffects: def __init__( self, - output_graph, - id_to_variable=None, - store_attr_mutations=None, - keepalive=None, - save_for_backward=None, - tensor_hooks=None, - ): + output_graph: "OutputGraph", + id_to_variable: Optional[dict[int, VariableTracker]] = None, + store_attr_mutations: Optional[ + dict[VariableTracker, dict[str, VariableTracker]] + ] = None, + keepalive: Optional[list[Any]] = None, + save_for_backward: Optional[ + list[tuple[AutogradFunctionContextVariable, list[VariableTracker]]] + ] = None, + tensor_hooks: Optional[ + dict[ + int, + tuple[ + "variables.TensorVariable", + VariableTracker, + "variables.RemovableHandleVariable", + str, + ], + ] + ] = None, + ) -> None: super().__init__() self.output_graph_weakref = weakref.ref(output_graph) self.id_to_variable = id_to_variable or {} @@ -122,7 +139,7 @@ def __init__( self._has_existing_dict_mutation = False # Track Compiled Autograd final callbacks that must be called at the end of Compiled Autograd backward graph. # Only applicable if this graph is created from Dynamo tracing in Compiled Autograd. - self.ca_final_callbacks_var = None + self.ca_final_callbacks_var: Optional[ListVariable] = None # Tracks VariableTracker objects whose mutations can be skipped. # For normal mutated variables, Dynamo generates code to replay/reconstruct @@ -131,14 +148,14 @@ def __init__( # execution but don't need to be replayed in the generated code. # Used for temporary mutations in contexts like torch.func.functional_call, # where module parameters/buffers are modified but later restored. - self.ignore_mutation_on_these_variables = set() + self.ignore_mutation_on_these_variables: set[VariableTracker] = set() - def ignore_mutations_on(self, var): + def ignore_mutations_on(self, var: VariableTracker) -> None: """Mutations to this variable will be executed but not not tracked, typically used for temporary mutations that are later restored.""" self.ignore_mutation_on_these_variables.add(var) - def stop_ignoring_mutations_on(self, var): + def stop_ignoring_mutations_on(self, var: VariableTracker) -> None: """Remove a variable from the skip mutation set, restoring normal mutation tracking.""" if var in self.ignore_mutation_on_these_variables: self.ignore_mutation_on_these_variables.remove(var) @@ -175,10 +192,12 @@ def diff(self, other: "SideEffects") -> Optional[str]: else: return None - def clone(self): + def clone(self) -> "SideEffects": """Create a shallow copy""" + ref = self.output_graph_weakref() + assert ref is not None return self.__class__( - output_graph=self.output_graph_weakref(), + output_graph=ref, id_to_variable=dict(self.id_to_variable), store_attr_mutations={ k: dict(v) for k, v in self.store_attr_mutations.items() @@ -188,36 +207,36 @@ def clone(self): tensor_hooks=self.tensor_hooks, ) - def __contains__(self, item): + def __contains__(self, item: Any) -> bool: return id(item) in self.id_to_variable - def __getitem__(self, item): + def __getitem__(self, item: Any) -> VariableTracker: return self.id_to_variable[id(item)] - def should_allow_side_effects_under_checkpoint(self): + def should_allow_side_effects_under_checkpoint(self) -> bool: output_graph = self.output_graph_weakref() - return ( + return bool( output_graph and output_graph.current_tx.output.current_tracer.under_activation_checkpoint and output_graph.current_tx.output.current_tracer.allow_side_effects_under_checkpoint ) - def should_allow_externally_visible_side_effects_in_subtracer(self): + def should_allow_externally_visible_side_effects_in_subtracer(self) -> bool: output_graph = self.output_graph_weakref() - return ( + return bool( output_graph and output_graph.current_tx.output.current_tracer.unsafe_allow_externally_visible_side_effects ) - def is_reconstructing_generator(self): + def is_reconstructing_generator(self) -> bool: output_graph = self.output_graph_weakref() - return ( + return bool( output_graph and output_graph.current_tx.output.current_tracer.is_reconstructing_generator ) - def check_allowed_side_effect(self, item: VariableTracker): + def check_allowed_side_effect(self, item: VariableTracker) -> bool: from torch._dynamo.variables.misc import AutogradFunctionContextVariable # People do things like self.dim = dim inside autograd.Function. @@ -244,15 +263,24 @@ def check_allowed_side_effect(self, item: VariableTracker): explanation="This is not supported.", hints=[], ) + return False - def store_attr(self, item: VariableTracker, name: str, value: VariableTracker): + def store_attr( + self, item: VariableTracker, name: str, value: VariableTracker + ) -> None: assert self.is_attribute_mutation(item) self.check_allowed_side_effect(item) if item not in self.store_attr_mutations: self.store_attr_mutations[item] = {} self.store_attr_mutations[item][name] = value - def load_attr(self, item, name, deleted_ok=False, check=False): + def load_attr( + self, + item: VariableTracker, + name: str, + deleted_ok: bool = False, + check: bool = False, + ) -> VariableTracker: if check: assert self.is_attribute_mutation(item) result = self.store_attr_mutations[item][name] @@ -265,7 +293,7 @@ def load_attr(self, item, name, deleted_ok=False, check=False): ) return result - def store_cell(self, cellvar, value): + def store_cell(self, cellvar: VariableTracker, value: VariableTracker) -> None: if cellvar.is_immutable(): unimplemented_v2( gb_type="Write to immutable cell", @@ -277,7 +305,7 @@ def store_cell(self, cellvar, value): assert isinstance(value, variables.VariableTracker) self.store_attr(cellvar, "cell_contents", value) - def load_cell(self, cellvar): + def load_cell(self, cellvar: VariableTracker) -> VariableTracker: assert isinstance(cellvar, variables.CellVariable) if self.has_pending_mutation_of_attr(cellvar, "cell_contents"): return self.load_attr(cellvar, "cell_contents", check=False) @@ -290,17 +318,19 @@ def load_cell(self, cellvar): hints=[*graph_break_hints.USER_ERROR], ) - def load_global(self, gvar: VariableTracker, name: str): + def load_global(self, gvar: VariableTracker, name: str) -> VariableTracker: assert isinstance(gvar, variables.VariableTracker) return self.load_attr(gvar, name) - def store_global(self, gvar: VariableTracker, name: str, value: VariableTracker): + def store_global( + self, gvar: VariableTracker, name: str, value: VariableTracker + ) -> None: assert isinstance(gvar, variables.VariableTracker) assert isinstance(value, variables.VariableTracker) self.store_attr(gvar, name, value) @staticmethod - def cls_supports_mutation_side_effects(cls): + def cls_supports_mutation_side_effects(cls: type) -> bool: return inspect.getattr_static(cls, "__getattribute__", None) in ( object.__getattribute__, dict.__getattribute__, @@ -313,20 +343,20 @@ def cls_supports_mutation_side_effects(cls): BaseException.__getattribute__, ) - def is_attribute_mutation(self, item): + def is_attribute_mutation(self, item: VariableTracker) -> bool: return isinstance(item.mutation_type, AttributeMutation) - def has_pending_mutation(self, item): + def has_pending_mutation(self, item: VariableTracker) -> bool: return self.is_attribute_mutation(item) and bool( self.store_attr_mutations.get(item) ) - def has_pending_mutation_of_attr(self, item, name): + def has_pending_mutation_of_attr(self, item: VariableTracker, name: str) -> bool: return self.is_attribute_mutation( item ) and name in self.store_attr_mutations.get(item, ()) - def is_modified(self, item): + def is_modified(self, item: VariableTracker) -> bool: if item.is_immutable(): return False if isinstance(item.mutation_type, (AttributeMutationNew, ValueMutationNew)): @@ -341,14 +371,14 @@ def is_modified(self, item): if self.is_attribute_mutation(item): return item in self.store_attr_mutations - return item.mutation_type.is_modified + return item.mutation_type.is_modified # type: ignore[attr-defined] def _track_obj( self, item: Any, variable: VariableTracker, - mutation_type_cls=ValueMutationExisting, - ): + mutation_type_cls: type = ValueMutationExisting, + ) -> VariableTracker: """Start tracking an existing or new variable for mutation""" if id(item) in self.id_to_variable: raise AssertionError( @@ -370,7 +400,7 @@ def track_object_existing( self, item: Any, variable: VariableTracker, - ): + ) -> VariableTracker: return self._track_obj( item, variable, @@ -382,8 +412,8 @@ def track_object_new( cls_source: Source, user_cls: Any, variable_cls: Any, - options, - ): + options: dict[str, Any], + ) -> VariableTracker: if user_cls is torch.autograd.function.FunctionCtx: with warnings.catch_warnings(record=True): obj = torch.autograd.Function() @@ -398,7 +428,7 @@ def track_object_new( self.keepalive.append(obj) return variable - def get_variable_cls(self, user_cls): + def get_variable_cls(self, user_cls: type) -> type: from torch.overrides import TorchFunctionMode from .variables.ctx_manager import GenericContextWrappingVariable @@ -439,11 +469,11 @@ def get_variable_cls(self, user_cls): def get_example_value( self, - base_cls_vt, - cls_vt, - init_args, - ): - user_cls = cls_vt.value + base_cls_vt: VariableTracker, + cls_vt: VariableTracker, + init_args: list[VariableTracker], + ) -> Any: + user_cls = cls_vt.value # type: ignore[attr-defined] if issubclass(user_cls, torch.nn.Module): # TODO(anijain2305) - Is it possible to remove this specialization? obj = nn_module_new(user_cls) @@ -470,10 +500,10 @@ def get_example_value( def track_new_user_defined_object( self, - base_cls_vt, - cls_vt, - init_args, - ): + base_cls_vt: VariableTracker, + cls_vt: VariableTracker, + init_args: list[VariableTracker], + ) -> VariableTracker: """ Creates a UserDefinedObjectVariable (or its subclass) variable tracker and mark it for attribute mutation tracking. @@ -483,7 +513,7 @@ def track_new_user_defined_object( base_cls_vt.__new__(user_cls, *init_args) """ cls_source = cls_vt.source - user_cls = cls_vt.value + user_cls = cls_vt.value # type: ignore[attr-defined] variable_cls = self.get_variable_cls(user_cls) obj = self.get_example_value(base_cls_vt, cls_vt, init_args) @@ -500,7 +530,7 @@ def track_new_user_defined_object( def track_cell_new( self, - ): + ) -> VariableTracker: obj = object() variable = variables.CellVariable( mutation_type=AttributeMutationNew(), @@ -511,7 +541,7 @@ def track_cell_new( def track_cell_existing( self, source: Optional[Source], cell: CellType, contents: VariableTracker - ): + ) -> VariableTracker: variable = variables.CellVariable( # We don't support mutation to cell without source because we need # source to properly codegen the mutations. @@ -523,7 +553,7 @@ def track_cell_existing( self.keepalive.append(cell) return variable - def track_global_existing(self, source: Source, item: Any): + def track_global_existing(self, source: Source, item: Any) -> VariableTracker: variable = variables.NewGlobalVariable( mutation_type=AttributeMutationExisting(), source=source, @@ -532,11 +562,15 @@ def track_global_existing(self, source: Source, item: Any): self.keepalive.append(item) return variable - def track_save_for_backward(self, ctx, args): + def track_save_for_backward( + self, ctx: VariableTracker, args: list[VariableTracker] + ) -> None: assert isinstance(ctx, variables.AutogradFunctionContextVariable) self.save_for_backward.append((ctx, args)) - def track_runahead_tensor_and_symvar_side_effects(self, other): + def track_runahead_tensor_and_symvar_side_effects( + self, other: "SideEffects" + ) -> None: # In higher order ops we want to keep track of tensors seen in the # speculate_subgraph so that we don't lift them again as a new input in # other speculate_subgraph or in the root tracer. @@ -548,12 +582,12 @@ def track_runahead_tensor_and_symvar_side_effects(self, other): ): self.track_object_existing(other_item, other_variable) - def prune_dead_object_new(self, tx): + def prune_dead_object_new(self, tx: "InstructionTranslatorBase") -> None: # Avoid VT cycles from e.g., recursive function. visited: set[VariableTracker] = set() live_new_objects: set[VariableTracker] = set() - def visit(var: VariableTracker): + def visit(var: VariableTracker) -> None: if var in visited: return visited.add(var) @@ -569,7 +603,7 @@ def visit(var: VariableTracker): self.store_attr_mutations[var], ) - def is_live(var: VariableTracker): + def is_live(var: VariableTracker) -> bool: if isinstance(var.mutation_type, AttributeMutationNew): return var in live_new_objects return True @@ -612,7 +646,7 @@ def is_live(var: VariableTracker): k: v for k, v in self.store_attr_mutations.items() if is_live(k) } - def mutation(self, var): + def mutation(self, var: VariableTracker) -> None: if var in self.ignore_mutation_on_these_variables: return @@ -626,13 +660,13 @@ def mutation(self, var): ): self._has_existing_dict_mutation = True - def has_existing_dict_mutation(self): + def has_existing_dict_mutation(self) -> bool: return self._has_existing_dict_mutation - def _get_modified_vars(self): + def _get_modified_vars(self) -> list[VariableTracker]: return [var for var in self.id_to_variable.values() if self.is_modified(var)] - def codegen_save_tempvars(self, cg: PyCodegen): + def codegen_save_tempvars(self, cg: PyCodegen) -> None: # We must codegen modified VT to their source by default, so that # mutation and aliasing are properly accounted for. # @@ -692,7 +726,7 @@ def codegen_save_tempvars(self, cg: PyCodegen): # base_cls.__new__(user_cls, *args) if isinstance(var, variables.UserDefinedObjectVariable): - def load_new_method(): + def load_new_method() -> None: assert var.base_cls_vt is not None cg(var.base_cls_vt) # type: ignore[attr-defined] cg.extend_output([cg.create_load_attr("__new__")]) @@ -706,11 +740,11 @@ def load_new_method(): cg(var.mutation_type.cls_source) # Generate the args to the __new__ method - for arg in var.init_args: + for arg in var.init_args: # type: ignore[attr-defined] cg(arg) # Call the __new__ method - cg.extend_output(create_call_function(1 + len(var.init_args), False)) + cg.extend_output(create_call_function(1 + len(var.init_args), False)) # type: ignore[attr-defined] cg.add_cache(var) var.source = LocalSource(cg.tempvars[var]) @@ -727,7 +761,13 @@ def load_new_method(): ] ) - def register_hook(self, tensor, hook, handle, name): + def register_hook( + self, + tensor: "variables.TensorVariable", + hook: VariableTracker, + handle: "variables.RemovableHandleVariable", + name: str, + ) -> None: assert isinstance(tensor, variables.TensorVariable) assert isinstance(hook, variables.VariableTracker) assert ( @@ -743,10 +783,10 @@ def register_hook(self, tensor, hook, handle, name): assert not handle.idx handle.idx = idx - def remove_hook(self, idx): + def remove_hook(self, idx: int) -> None: del self.tensor_hooks[idx] - def codegen_hooks(self, cg): + def codegen_hooks(self, cg: PyCodegen) -> None: for ( tensor, hook, @@ -788,7 +828,7 @@ def codegen_hooks(self, cg): # - The handle's exact user-specified name, "user_code_variable_name", is discerned and associated during STORE_FAST. assert tensor.source, "Hooks on non input tensors NYI - should not get here" - def gen_fn(): + def gen_fn() -> None: cg(tensor) cg.extend_output([cg.create_load_attr(name)]) @@ -800,16 +840,17 @@ def gen_fn(): # be associated with the return value of register_hook(). This consumes the top of stack. cg.add_cache(handle) - def get_ca_final_callbacks_var(self): + def get_ca_final_callbacks_var(self) -> "variables.ListVariable": from .variables.base import ValueMutationNew if self.ca_final_callbacks_var is None: self.ca_final_callbacks_var = variables.ListVariable( [], mutation_type=ValueMutationNew() ) + return self.ca_final_callbacks_var - def codegen_update_mutated(self, cg: PyCodegen): + def codegen_update_mutated(self, cg: PyCodegen) -> None: suffixes = [] for var in self._get_modified_vars(): if isinstance(var, variables.ListVariable): @@ -1102,7 +1143,7 @@ def codegen_update_mutated(self, cg: PyCodegen): cg.pop_top() elif isinstance(var, variables.RandomVariable): # set correct random seed state - def gen_fn(): + def gen_fn() -> None: cg(var.source) # type: ignore[attr-defined] cg.load_attr("setstate") @@ -1122,7 +1163,7 @@ def gen_fn(): for suffix in reversed(suffixes): cg.extend_output(suffix) - def is_empty(self): + def is_empty(self) -> bool: return not ( any(map(self.is_modified, self.id_to_variable.values())) or self.tensor_hooks @@ -1130,13 +1171,15 @@ def is_empty(self): or self.tensor_hooks ) - def clear(self): + def clear(self) -> None: self.keepalive.clear() self.id_to_variable.clear() @contextlib.contextmanager -def allow_side_effects_under_checkpoint(tx: "InstructionTranslator"): +def allow_side_effects_under_checkpoint( + tx: "InstructionTranslatorBase", +) -> Generator[None, None, None]: assert tx.output.current_tracer.under_activation_checkpoint orig_val = tx.output.current_tracer.allow_side_effects_under_checkpoint try: @@ -1147,7 +1190,9 @@ def allow_side_effects_under_checkpoint(tx: "InstructionTranslator"): @contextlib.contextmanager -def allow_externally_visible_side_effects_in_subtracer(tx: "InstructionTranslator"): +def allow_externally_visible_side_effects_in_subtracer( + tx: "InstructionTranslatorBase", +) -> Generator[None, None, None]: orig_val = tx.output.current_tracer.unsafe_allow_externally_visible_side_effects try: tx.output.current_tracer.unsafe_allow_externally_visible_side_effects = True @@ -1157,7 +1202,9 @@ def allow_externally_visible_side_effects_in_subtracer(tx: "InstructionTranslato @contextlib.contextmanager -def disallow_side_effects_in_generator(tx: "InstructionTranslator"): +def disallow_side_effects_in_generator( + tx: "InstructionTranslatorBase", +) -> Generator[None, None, None]: orig_val = tx.output.current_tracer.is_reconstructing_generator try: tx.output.current_tracer.is_reconstructing_generator = True diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index b7ba37b08d35f..181b8ee9b042a 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -3444,7 +3444,7 @@ def __init__( side_effects.store_cell(cell_var, contents_var) else: cell_var = side_effects.track_cell_new() - cell_var.local_name = name + cell_var.local_name = name # type: ignore[attr-defined] self.symbolic_locals[name] = cell_var # Populate `symbolic_locals` with cells captured by this frame, @@ -3462,7 +3462,7 @@ def __init__( cell_var = side_effects.track_cell_existing( cell_source, cell, contents_var ) - cell_var.local_name = name + cell_var.local_name = name # type: ignore[attr-defined] self.symbolic_locals[name] = cell_var self.symbolic_torch_function_state = SymbolicTorchFunctionState( From b66f4298278c269bdca9a71883cacfa6e975a393 Mon Sep 17 00:00:00 2001 From: zeshengzong Date: Mon, 21 Jul 2025 20:17:23 +0000 Subject: [PATCH 333/457] Fix `torch.randint`, `torch.mul` param missing description (#158731) Wrong separator cause param description truncated. - Change separator of param and its description - Remove quote make `torch.dtype` display as reference to the class ## Test Result ### Before image image ### After image image Pull Request resolved: https://github.com/pytorch/pytorch/pull/158731 Approved by: https://github.com/ngimel --- torch/_torch_docs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 0766bf7742864..8c12d4d689307 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -7605,7 +7605,7 @@ def merge_dicts(*dicts): Args: {input} - other (Tensor or Number) - the tensor or number to multiply input by. + other (Tensor or Number): the tensor or number to multiply input by. Keyword args: {out} @@ -8948,7 +8948,7 @@ def merge_dicts(*dicts): Keyword args: {generator} {out} - dtype (`torch.dtype`, optional) - the desired data type of returned tensor. Default: if ``None``, + dtype (torch.dtype, optional): the desired data type of returned tensor. Default: if ``None``, this function returns a tensor with dtype ``torch.int64``. {layout} {device} From 851e953f68a614921c4315467209b741206774af Mon Sep 17 00:00:00 2001 From: Eli Uriegas Date: Mon, 21 Jul 2025 13:14:38 -0700 Subject: [PATCH 334/457] ci: Only run lint jobs on relevant files (#158773) Conditionally run lint jobs on relevant files, this is mainly targetd at clangtidy since it takes a long time but also includes mypy since that's an additional 4 minutes of runtime that we can save. Signed-off-by: Eli Uriegas Pull Request resolved: https://github.com/pytorch/pytorch/pull/158773 Approved by: https://github.com/malfet --- .github/workflows/lint.yml | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 1a21a68a865da..476195ab5eec7 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -35,6 +35,21 @@ jobs: lintrunner-clang: uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main needs: [get-label-type, get-changed-files] + # Only run if there are changed files relevant to clangtidy / clangformat + if: | + github.repository_owner == 'pytorch' && ( + needs.get-changed-files.outputs.changed-files == '*' || + contains(needs.get-changed-files.outputs.changed-files, '.h') || + contains(needs.get-changed-files.outputs.changed-files, '.cpp') || + contains(needs.get-changed-files.outputs.changed-files, '.cc') || + contains(needs.get-changed-files.outputs.changed-files, '.cxx') || + contains(needs.get-changed-files.outputs.changed-files, '.hpp') || + contains(needs.get-changed-files.outputs.changed-files, '.hxx') || + contains(needs.get-changed-files.outputs.changed-files, '.cu') || + contains(needs.get-changed-files.outputs.changed-files, '.cuh') || + contains(needs.get-changed-files.outputs.changed-files, '.mm') || + contains(needs.get-changed-files.outputs.changed-files, '.metal') + ) with: timeout: 120 runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" @@ -59,6 +74,13 @@ jobs: lintrunner-mypy: uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main needs: [get-label-type, get-changed-files] + # Only run if there are changed files relevant to mypy + if: | + github.repository_owner == 'pytorch' && ( + needs.get-changed-files.outputs.changed-files == '*' || + contains(needs.get-changed-files.outputs.changed-files, '.py') || + contains(needs.get-changed-files.outputs.changed-files, '.pyi') + ) with: timeout: 120 runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" From b1a0c34dd3b581f645842f06f9d0187e7a6562c8 Mon Sep 17 00:00:00 2001 From: Ruben Rodriguez Buchillon Date: Mon, 21 Jul 2025 20:41:03 +0000 Subject: [PATCH 335/457] [pt2 event logging] add configurable prefix (#157678) Summary: # Why make experiments easier to find # What - dynamo config to provide a prefix - use the prefix when sending data to scuba through the self.id_ field Test Plan: ``` # code edited to set the prefix as `coconutruben-02` buck2 run mode/opt scripts/coconutruben/torchmm:experiment 2>&1 | tee /tmp/epx040 ``` on scuba ``` | autotune_dtypes | autotune_offset | autotune_shape | autotune_strides | event | run_id | | -----| -----| -----| -----| -----| ----- | | "torch.float16, torch.float16" | "0, 0" | "4096x3008, 3008x2048" | "[3008, 1], [2048, 1]" | "mm_template_autotuning" | "coconutruben-02-e6bdccc5-6dcf-4d68-9a04-b34f2c6d94fd" | | "torch.float16, torch.float16" | "0, 0" | "4096x3008, 3008x2048" | "[3008, 1], [2048, 1]" | "mm_template_autotuning" | "coconutruben-02-14165153-5842-4eaa-9e6c-3b0cbc016375" | ``` Rollback Plan: Differential Revision: D77837550 Pull Request resolved: https://github.com/pytorch/pytorch/pull/157678 Approved by: https://github.com/stashuk-olek --- torch/_dynamo/config.py | 3 +++ torch/_dynamo/utils.py | 8 +++++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index c7f0fb4adeb1f..5fd37d5392742 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -606,6 +606,9 @@ def default_debug_dir_root() -> str: os.environ.get("UNSAFE_SKIP_FSDP_MODULE_GUARDS", "0") == "1" ) +# Common prefix to append to the id of each compile run to filter out data +pt2_compile_id_prefix: Optional[str] = os.environ.get("PT2_COMPILE_ID_PREFIX", None) + # Run GC at the end of compilation run_gc_after_compile = Config( # type: ignore[var-annotated] default=True, diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index d54a45f4156a1..35f0522453a89 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -1714,9 +1714,15 @@ def get_event_data(self) -> dict[str, Any]: def __init__(self): self.tls = threading.local() + + from . import config + # Generate a unique id for this logger, which we can use in scuba to filter down # to a single python run. - self.id_ = str(uuid.uuid4()) + if config.pt2_compile_id_prefix: + self.id_ = f"{config.pt2_compile_id_prefix}-{uuid.uuid4()}" + else: + self.id_ = str(uuid.uuid4()) # TODO: log to init/id tlparse after I add support for it log.info("ChromiumEventLogger initialized with id %s", self.id_) From bc379aebe2e69d306d1b05938a9e86c80f6b98cb Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 21 Jul 2025 20:45:21 +0000 Subject: [PATCH 336/457] Revert "Still run TritonBundler with BundledAOTAutogradCache, save autotune results (#158048)" This reverts commit 8e57cdb746b4ab28865fdf01532f87b0d21700e9. Reverted https://github.com/pytorch/pytorch/pull/158048 on behalf of https://github.com/jeffdaily due to rocm failures due to unit test introduced in this PR, but no pre-merge signal available ([comment](https://github.com/pytorch/pytorch/pull/158048#issuecomment-3098746624)) --- test/dynamo/test_package.py | 34 ----------------------- torch/_dynamo/precompile_context.py | 9 ++---- torch/_inductor/compile_fx.py | 29 +------------------ torch/_inductor/runtime/autotune_cache.py | 10 ------- 4 files changed, 3 insertions(+), 79 deletions(-) diff --git a/test/dynamo/test_package.py b/test/dynamo/test_package.py index 51f6ca91136c9..3160007774090 100644 --- a/test/dynamo/test_package.py +++ b/test/dynamo/test_package.py @@ -15,7 +15,6 @@ from torch._dynamo.package import CompilePackage, DiskDynamoStore, DynamoCache from torch._dynamo.precompile_context import PrecompileContext from torch._functorch import config as functorch_config -from torch._inductor.mock_cache import global_stats, PatchCaches, Stats from torch._inductor.runtime.runtime_utils import cache_dir from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, @@ -429,39 +428,6 @@ def fn2(x): self.assertEqual(expected, [result1, result2]) self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames) - @parametrize("device", ("cuda", "xpu")) - @torch._dynamo.config.patch(caching_precompile=True) - def test_automatic_dynamo_autotune_cache(self, device): - if device == "cuda" and not HAS_CUDA: - raise unittest.SkipTest("Requires CUDA/Triton") - if device == "xpu" and not HAS_XPU: - raise unittest.SkipTest("Requires XPU/Triton") - - def fn(x, y): - return x.sin() + y - - arg1 = torch.randn(3, 3, device=device) - arg2 = torch.randn(3, 3, device=device) - expected = fn(arg1, arg2).clone() - - with PatchCaches(): - compiled_fn1 = torch.compile(fn, mode="max-autotune") - result = compiled_fn1(arg1, arg2).clone() - self.assertEqual(expected, result) - self.assertEqual(global_stats.autotune_local, Stats(1, 0, 1)) - DynamoCache.clear() - - total_frames = torch._dynamo.convert_frame.FRAME_COUNTER - self._save_and_reload( - expected_backends=1, expected_dynamo=1, expected_autotune=1 - ) - compiled_fn1 = torch.compile(fn, mode="max-autotune") - with torch.compiler.set_stance("fail_on_recompile"): - result1 = compiled_fn1(arg1, arg2).clone() - self.assertEqual(expected, result1) - self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames) - self.assertEqual(global_stats.autotune_local, Stats(2, 1, 1)) - @parametrize("device", ("cpu", "cuda", "xpu")) @torch._dynamo.config.patch(caching_precompile=True) def test_automatic_dynamo_recompiles(self, device): diff --git a/torch/_dynamo/precompile_context.py b/torch/_dynamo/precompile_context.py index 040f54ce70db2..6bb42bb34bc35 100644 --- a/torch/_dynamo/precompile_context.py +++ b/torch/_dynamo/precompile_context.py @@ -70,8 +70,7 @@ class PrecompileContext(CacheArtifactManager): The following artifact types are supported by PrecompileContext: - BundledAOTAutogradCacheArtifact - - DynamoCodeStateArtifact - - AutotuneCacheArtifact (regular autotune results, same as Megacache) + - CodeStateArtifact (from torch._dynamo.package once available) """ # Protected by the compile_lock @@ -150,12 +149,8 @@ def populate_caches(artifacts: CacheArtifactsResult) -> CacheInfo: artifacts_by_key = {} cache_info = CacheInfo() for artifact in chain(*artifacts.values()): - if artifact.type() == "autotune": - # Populate autotune cache artifacts - artifact.populate_cache() - else: - artifacts_by_key[artifact.key] = artifact cache_info.add(artifact) + artifacts_by_key[artifact.key] = artifact from torch._dynamo.package import _BackendId, DynamoCache diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index 8e712a28a3b0f..95c12d12c7850 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -909,37 +909,10 @@ def _compile_fx_inner( else: log.debug("Failed to generate FX cache key") - if torch._functorch.config.bundled_autograd_cache: - assert mb_compiled_graph is None - assert cache_info is None - # When using bundled autograd cache, we still want - # to use the TritonBundler, but we don't want to save - # the results here. The results will get saved directly - # to AOTAutogradCache. - TritonBundler.begin_compile() - try: - mb_compiled_graph = fx_codegen_and_compile( - gm, example_inputs, inputs_to_check, **graph_kwargs - ) - assert mb_compiled_graph is not None - ( - triton_bundle, - triton_bundler_meta, - ) = TritonBundler.collect() - mb_compiled_graph.set_triton_bundle(triton_bundle) - except (ShortenTraceback, SkipFrame): - raise - except Exception as e: - raise InductorError(e, currentframe()).with_traceback( - e.__traceback__ - ) from None - finally: - TritonBundler.end_compile() - # CACHE BYPASS: Compile the graph, don't save it to the cache # (this can happen either because cache was disabled, or we # determined the input is uncacheable) - elif cache_info is None or cache_info["cache_state"] == "bypass": + if cache_info is None or cache_info["cache_state"] == "bypass": assert mb_compiled_graph is None log.debug( "FX cache bypass reason: %s", diff --git a/torch/_inductor/runtime/autotune_cache.py b/torch/_inductor/runtime/autotune_cache.py index 88b9c80c77146..01d038aab8e7b 100644 --- a/torch/_inductor/runtime/autotune_cache.py +++ b/torch/_inductor/runtime/autotune_cache.py @@ -35,7 +35,6 @@ from typing_extensions import override import torch -from torch._dynamo.precompile_context import PrecompileContext from torch._inductor.runtime.runtime_utils import cache_dir from torch.compiler._cache import ( CacheArtifact, @@ -126,7 +125,6 @@ def create( ) -> Optional[AutotuneCache]: cache = AutotuneCache(configs_hash) key = AutotuneCache._prepare_key(filename) - cache._setup_local_cache(inductor_meta, os.path.dirname(filename), key) cache._setup_remote_autotune_cache(inductor_meta, key) if cache.local_cache or cache.remote_cache: @@ -302,10 +300,6 @@ def save( CacheArtifactManager.record_artifact( AutotuneCacheArtifact.type(), autotune_artifact_key, data ) - if torch._dynamo.config.caching_precompile: - PrecompileContext.record_artifact( - AutotuneCacheArtifact.type(), autotune_artifact_key, data - ) if log.isEnabledFor(logging.DEBUG): type_str = "coordesc" if found_by_coordesc else "heuristic" @@ -631,10 +625,6 @@ def _get(self, key: str, sample: Optional[Sample]) -> Optional[JsonDataTy]: CacheArtifactManager.record_artifact( AutotuneCacheArtifact.type(), autotune_artifact_key, result ) - if torch._dynamo.config.caching_precompile: - PrecompileContext.record_artifact( - AutotuneCacheArtifact.type(), autotune_artifact_key, result - ) return result @override From 6b0526a2c47f517a28619a9aa0e8e0260e91ff46 Mon Sep 17 00:00:00 2001 From: Xuan Zhang Date: Mon, 21 Jul 2025 15:45:34 +0000 Subject: [PATCH 337/457] ban fusion of large amount of reads (#158667) This is an reland attempt of https://github.com/pytorch/pytorch/pull/157563, but insteading of introducing the `realize_acc_reads_size_threshold` config and setting to a default value, we set it to `None` for now to unblock an internal use case. Will deep dive into the issue and harden the logic in later PRs. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158667 Approved by: https://github.com/yf225 --- test/inductor/test_memory.py | 53 ++++++++++++++++++++++++++++++++++++ torch/_inductor/choices.py | 11 ++++++++ torch/_inductor/config.py | 3 ++ torch/_inductor/graph.py | 21 ++++++++++++++ torch/_inductor/ir.py | 15 ++++++++++ torch/_inductor/memory.py | 13 +-------- torch/_inductor/scheduler.py | 26 ++++++------------ 7 files changed, 113 insertions(+), 29 deletions(-) diff --git a/test/inductor/test_memory.py b/test/inductor/test_memory.py index eaff539f7a493..3e23442b38ec7 100644 --- a/test/inductor/test_memory.py +++ b/test/inductor/test_memory.py @@ -8,6 +8,7 @@ from torch._inductor import config, memory from torch._inductor.test_case import TestCase from torch._inductor.utils import run_and_get_triton_code +from torch.testing._internal.common_utils import serialTest from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU @@ -306,6 +307,58 @@ def f(a, b, c): expected_bound = a.size(0) * c.size(1) * a.dtype.itemsize * 2 self.assertLess(peak_mem, expected_bound) + @serialTest() + def test_fusion_acc_large_reads(self): + def f(x, y, z): + res = torch.zeros_like(x[0]) + for i in range(4): + temp = torch.matmul(x, y) + z + res = res + temp + return res + + N = 128 + x = torch.rand(N, N, dtype=torch.float32, device=GPU_TYPE) + y = torch.rand(N, N, dtype=torch.float32, device=GPU_TYPE) + z = torch.rand(N, N, dtype=torch.float32, device=GPU_TYPE) + + # CASE 1: no restriction on the amount of accumulation + with config.patch({"realize_acc_reads_size_threshold": float("inf")}): + f_compiled = torch.compile(f) + code = run_and_get_triton_code(f_compiled, x, y, z) + ( + FileCheck() + .check("triton_poi_fused_add_0.run(buf4, arg2_1, buf1, buf2, buf3") + .run(code) + ) + + # CASE 2: for tensors with the same size as x (which is 4 * N**2 bytes) + # at most 12 / 4 = 3 reads can be accumulated during fusion + with config.patch({"realize_acc_reads_size_threshold": 12 * N**2}): + f_compiled = torch.compile(f) + code = run_and_get_triton_code(f_compiled, x, y, z) + ( + FileCheck() + .check("triton_poi_fused_add_0.run(buf3, arg2_1, buf1, buf2,") + .check("triton_poi_fused_add_1.run(buf5, buf4, arg2_1,") + .run(code) + ) + + # CASE 3: no such fusion allowed + with config.patch({"realize_acc_reads_size_threshold": N**2}): + f_compiled = torch.compile(f) + code = run_and_get_triton_code(f_compiled, x, y, z) + ( + FileCheck() + .check("triton_poi_fused_add_0.run(buf1, arg2_1,") + .check("triton_poi_fused_add_0.run(buf3, arg2_1,") + .check("triton_poi_fused_add_0.run(buf4, buf3,") + .check("triton_poi_fused_add_0.run(buf6, arg2_1,") + .check("triton_poi_fused_add_0.run(buf7, buf6,") + .check("triton_poi_fused_add_0.run(buf9, arg2_1,") + .check("triton_poi_fused_add_0.run(buf10, buf9,") + .run(code) + ) + if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/torch/_inductor/choices.py b/torch/_inductor/choices.py index b7bab02da5e4b..689a006eb56b8 100644 --- a/torch/_inductor/choices.py +++ b/torch/_inductor/choices.py @@ -365,6 +365,17 @@ def can_fuse( WhyNoFuse(node1, node2)("Fusion will increase peak memory") return False + if ( + config.realize_acc_reads_size_threshold is not None + and scheduler.fusion_accumulate_large_reads( + node1, + node2, + config.realize_acc_reads_size_threshold, + ) + ): + WhyNoFuse(node1, node2)("Fusion accumulate large amount of reads") + return False + return True @staticmethod diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index b18e6f45de55f..4515e10604a78 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -574,6 +574,9 @@ def use_autoheuristic(name: str) -> bool: # Threshold to prevent excessive accumulation of ops in one buffer during lowering realize_acc_reads_threshold = 8 +realize_acc_reads_size_threshold: Optional[int] = ( + None # TODO(xuanzh): harden this to make it non optional +) # fallback to eager for random/dropout, this is slow but useful for debugging fallback_random = False diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index e2cc101533f28..ac299d5b0c2d0 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -123,6 +123,7 @@ from torch.fx.graph import Graph from .codegen.wrapper import PythonWrapperCodegen + from .dependencies import Dep from .scheduler import BaseSchedulerNode CompiledModule = Union[ModuleType, FileBackedGraphModule] @@ -485,6 +486,9 @@ def __init__( self.bw_donated_idxs = get_donated_idxs() + # Cache for dep size hints to avoid expensive recomputation + self.dep_size_hint_cache: dict[Dep, int] = {} + def freeze_runtime_asserts(self) -> None: self._shape_env.freeze_runtime_asserts() @@ -570,6 +574,23 @@ def has_feature( assert isinstance(feature, BackendFeature), feature return feature in self.get_backend_features(get_device_type(device)) + def get_dep_size_hint(self, dep: Dep) -> int: + """ + Get the size hint for a dependency with caching to avoid expensive recomputation. + """ + if dep not in self.dep_size_hint_cache: + res = 0 + try: + if not dep.has_unbacked_symbols(): + res = dep.numbytes_hint() + except KeyError: + # In at least one test (test/inductor/test_torchbind.py) we + # create a StarDep that doesn't exist in the graph and calling + # `has_unbacked_symbols()` throws an error. + pass + self.dep_size_hint_cache[dep] = res + return self.dep_size_hint_cache[dep] + def get_current_device_or_throw(self) -> torch.device: if device := self.current_device: return device diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index e0b3481473323..3ddfdc4be768c 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -7877,6 +7877,10 @@ def create(data: IRNode) -> Union[TensorBox, ShapeAsConstantBuffer]: class StorageBox(MutableBox): + """ + StorageBox allow in-place mutation of Tensors + """ + def is_input_buffer(self) -> bool: if isinstance(self.data, (InputBuffer, ReinterpretView)): return self.data.get_name() in V.graph.graph_inputs @@ -7926,10 +7930,21 @@ def realize_hint(self) -> None: ): self.realize() + def has_accumulated_enough_reads_by_size(self, threshold: int) -> bool: + return ( + sum(V.graph.get_dep_size_hint(dep) for dep in self.get_reads()) > threshold + ) + def has_exceeded_max_reads(self) -> bool: return isinstance(self.data, Pointwise) and ( self.num_reads() > config.realize_acc_reads_threshold or self.has_large_inner_fn() + or ( + config.realize_acc_reads_size_threshold is not None + and self.has_accumulated_enough_reads_by_size( + config.realize_acc_reads_size_threshold + ) + ) ) def should_realize_on_reuse(self, users: int) -> bool: diff --git a/torch/_inductor/memory.py b/torch/_inductor/memory.py index 5601bc4adcee4..d287208419a9f 100644 --- a/torch/_inductor/memory.py +++ b/torch/_inductor/memory.py @@ -78,19 +78,8 @@ def get_freeable_input_buf( A dictionary containing all freeble input buffers, keyed by their names. """ - # this function is copied from torch/_inductor/scheduler.py - # TODO: would be nice to remove the try/except block for both places def _dep_size_hint(dep: Dep) -> int: - res = 0 - try: - if not dep.has_unbacked_symbols(): - res = dep.numbytes_hint() - except KeyError: - # In at least one test (test/inductor/test_torchbind.py) we - # create a StarDep that doesn't exist in the graph and calling - # `has_unbacked_symbols()` throws an error. - pass - return res + return V.graph.get_dep_size_hint(dep) # get freeable input buffers' successor nodes and their sizes # note that different deps can have the same name, so we use name as keys diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 5c7a16d25bc64..f3986b897161a 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -2051,15 +2051,12 @@ class Scheduler: optimizations such as fusion, reorder, and graph partition. """ - __dep_size_hint_cache: dict[Dep, int] - def __init__(self, nodes: list[ir.Operation]) -> None: with dynamo_timed("Scheduler.__init__"): self._init(nodes) def _init(self, nodes: list[ir.Operation]) -> None: super().__init__() - self.__dep_size_hint_cache = {} V.graph.scheduler = self self.backends: dict[torch.device, BaseScheduling] = {} self.post_grad_graph_id = next(_post_grad_graph_counter) @@ -3505,6 +3502,14 @@ def _find_single_user_inputs( return True return False + def fusion_accumulate_large_reads( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode, threshold: int + ) -> bool: + all_reads = (node1.read_writes.reads | node2.read_writes.reads) - ( + node1.read_writes.writes | node2.read_writes.writes + ) + return sum(self.dep_size_hint(dep) for dep in all_reads) > threshold + def are_long_distant_nodes( self, node1: BaseSchedulerNode, node2: BaseSchedulerNode ) -> bool: @@ -4010,20 +4015,7 @@ def fusable_read_and_write(self, read: Dep, write: MemoryDep) -> bool: return False def dep_size_hint(self, dep: Dep) -> int: - res = 0 - if dep not in self.__dep_size_hint_cache: - try: - if not dep.has_unbacked_symbols(): - res = dep.numbytes_hint() - except KeyError: - # In at least one test (test/inductor/test_torchbind.py) we - # create a StarDep that doesn't exist in the graph and calling - # `has_unbacked_symbols()` throws an error. - pass - self.__dep_size_hint_cache[dep] = res - else: - res = self.__dep_size_hint_cache[dep] - return res + return V.graph.get_dep_size_hint(dep) def score_fusion_memory( self, node1: BaseSchedulerNode, node2: BaseSchedulerNode From 5e17932c226b79d8510fb8c76babbf898a68ae33 Mon Sep 17 00:00:00 2001 From: Howard Huang Date: Mon, 21 Jul 2025 10:59:50 -0700 Subject: [PATCH 338/457] [DCP] Add support for ShardedTensor to PgTransport (#158573) Add support for ShardedTensors in when PGTransport is used for send/recv checkpoints Test is pulled from https://github.com/pytorch/pytorch/pull/157963 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158573 Approved by: https://github.com/meetv18 --- .../checkpoint/test_pg_transport.py | 115 ++++++++++++++++-- torch/distributed/checkpoint/_pg_transport.py | 82 ++++++++++++- 2 files changed, 184 insertions(+), 13 deletions(-) diff --git a/test/distributed/checkpoint/test_pg_transport.py b/test/distributed/checkpoint/test_pg_transport.py index df64e9451b467..baa2eb54b0548 100644 --- a/test/distributed/checkpoint/test_pg_transport.py +++ b/test/distributed/checkpoint/test_pg_transport.py @@ -1,13 +1,17 @@ # Owner(s): ["oncall: distributed"] import logging -import os from datetime import timedelta from typing import Optional from unittest.mock import MagicMock, patch import torch import torch.nn as nn +from torch.distributed._shard.sharded_tensor import ( + init_from_local_shards, + Shard as ShardedTensorShard, + ShardMetadata, +) from torch.distributed.checkpoint._pg_transport import ( _cast_tensor, _prepare_state_dict, @@ -34,9 +38,56 @@ logger = logging.getLogger(__name__) +def _create_sharded_tensor_state_dict( + rank: int, world_size: int, device: torch.device +) -> dict: + """ + Create state_dict with ShardedTensor for deterministic testing. + Args: + rank: Current rank + world_size: Total world size + device: Device to create tensors on + Returns: + dict: State dictionary with ShardedTensor + """ + # Create deterministic local shard for this rank + global_size = 64 + shard_size = global_size // world_size + start_idx = rank * shard_size + end_idx = (rank + 1) * shard_size + + # Create local tensor with deterministic values + local_tensor = torch.arange( + start_idx * 8, end_idx * 8, dtype=torch.float32, device=device + ).reshape(shard_size, 8) + + # Create ShardedTensor using init_from_local_shards + sharded_tensor = init_from_local_shards( + [ + ShardedTensorShard( + tensor=local_tensor, + metadata=ShardMetadata( + shard_offsets=[start_idx, 0], + shard_sizes=[shard_size, 8], + placement=f"rank:{rank}/{device}", + ), + ) + ], + global_size, + 8, + ) + + return { + "sharded_tensor": sharded_tensor, + "rank_scalar": torch.tensor(float(rank), device=device), + } + + class SimpleModel(nn.Module): - def __init__(self): + def __init__(self, seed: int = 42): super().__init__() + # Set seed for deterministic initialization + torch.manual_seed(seed) self.net1 = nn.Linear(10, 10) self.relu = nn.ReLU() self.net2 = nn.Linear(10, 10) @@ -50,6 +101,7 @@ def ring_send_recv_checkpoint( ): """ Use the transport to send to rank + 1 and receive from rank - 1. + Each rank exchanges its own state_dict with the previous rank. """ next_rank = (rank + 1) % world_size prev_rank = (rank - 1) % world_size @@ -58,15 +110,11 @@ def ring_send_recv_checkpoint( received_checkpoint = transport.recv_checkpoint(prev_rank) else: received_checkpoint = transport.recv_checkpoint(prev_rank) - transport.send_checkpoint([next_rank], received_checkpoint) + transport.send_checkpoint([next_rank], state_dict) return received_checkpoint def _test_pg_transport(self, device) -> None: - # python test/distributed/checkpoint/test_pg_transport.py -k test_pg_transport - print(f"{self.rank=} pid: {os.getpid()} {device=}") - print("in test") - model = SimpleModel().to(device) transport = PGTransport(_get_default_group(), timedelta(seconds=10), device) original_state_dict = model.state_dict() @@ -111,6 +159,48 @@ def _test_pg_transport_with_mixed_content(self, device) -> None: self.assertEqual(state_dict, received_checkpoint) +def _test_pg_transport_with_sharded_tensor(self, device) -> None: + # Set current CUDA device for NCCL + if device.type == "cuda": + torch.cuda.set_device(device) + + state_dict = _create_sharded_tensor_state_dict(self.rank, self.world_size, device) + transport = PGTransport(_get_default_group(), timedelta(seconds=10), device) + print(state_dict) + received_checkpoint = ring_send_recv_checkpoint( + transport=transport, + state_dict=state_dict, + rank=self.rank, + world_size=self.world_size, + ) + print("finished comms") + print(received_checkpoint) + + # Validate that received checkpoint matches what we expect from rank - 1 + prev_rank = (self.rank - 1) % self.world_size + + # Compare rank_scalar (should be from previous rank) + # Note: PGTransport moves received tensors to CPU when no state_dict callback is provided + expected_rank_scalar = torch.tensor(float(prev_rank), device="cpu") + received_rank_scalar = received_checkpoint["rank_scalar"] # type: ignore[index] + print(f"{expected_rank_scalar=} {received_rank_scalar=}") + torch.testing.assert_close(expected_rank_scalar, received_rank_scalar) + + # For ShardedTensor, validate the local shard data matches what prev_rank would have + received_st = received_checkpoint["sharded_tensor"] # type: ignore[index] + global_size = 64 + shard_size = global_size // self.world_size + prev_start_idx = prev_rank * shard_size + prev_end_idx = (prev_rank + 1) * shard_size + expected_local_tensor = torch.arange( + prev_start_idx * 8, prev_end_idx * 8, dtype=torch.float32, device="cpu" + ).reshape(shard_size, 8) + + # Compare the actual tensor data + received_local_tensor = received_st.local_shards()[0].tensor + torch.testing.assert_close(expected_local_tensor, received_local_tensor) + + class PgTransportCPU(MultiProcContinousTest): world_size = 8 timeout: timedelta = timedelta(seconds=20) @@ -133,6 +223,9 @@ def test_pg_transport(self) -> None: def test_pg_transport_with_mixed_content(self) -> None: _test_pg_transport_with_mixed_content(self, self.device) + def test_pg_transport_with_sharded_tensor(self) -> None: + _test_pg_transport_with_sharded_tensor(self, self.device) + class PgTransportCUDA(MultiProcContinousTest): world_size = 2 @@ -160,6 +253,11 @@ def test_pg_transport(self) -> None: def test_pg_transport_with_mixed_content(self) -> None: _test_pg_transport_with_mixed_content(self, self.device) + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + def test_pg_transport_with_sharded_tensor(self) -> None: + _test_pg_transport_with_sharded_tensor(self, self.device) + class TestCastTensor(TestCase): def test_cast_tensor_different_dtypes(self): @@ -509,8 +607,5 @@ def test_send_checkpoint_with_cpu_tensors(self): self.assertGreaterEqual(self.mock_work.wait.call_count, 4) -# import fbvscode -# fbvscode.attach_debugger() - if __name__ == "__main__": run_tests() diff --git a/torch/distributed/checkpoint/_pg_transport.py b/torch/distributed/checkpoint/_pg_transport.py index cab908b5a8510..f4c53829b23b9 100644 --- a/torch/distributed/checkpoint/_pg_transport.py +++ b/torch/distributed/checkpoint/_pg_transport.py @@ -9,6 +9,12 @@ import torch from torch.distributed import ProcessGroup, Work +from torch.distributed._shard.sharded_tensor import ( + Shard as ShardedTensorShard, + ShardedTensor, + ShardMetadata, +) +from torch.distributed._shard.sharded_tensor.metadata import ShardedTensorMetadata from torch.distributed.tensor import _DTensorSpec, DTensor from torch.utils._pytree import ( KeyPath, @@ -53,6 +59,22 @@ class _DTensorMeta: spec: _DTensorSpec +@dataclass +class _ShardedTensorMeta: + """ + This is the metadata for a ShardedTensor that is used to transfer checkpoints. + It contains the metadata for all local shards and the global tensor metadata. + + This must be pickleable so that it can be sent over the wire. + """ + + local_shards_meta: list[_TensorMeta] + local_shards_shard_metadata: list[ + ShardMetadata + ] # Original shard metadata for each local shard + sharded_tensor_metadata: ShardedTensorMetadata + + @dataclass class _StateDictMeta: """ @@ -72,7 +94,9 @@ class _StateDictMeta: treespec: TreeSpec paths: list[KeyPath] - non_tensor_leaves: list[Union[object, _TensorMeta, _DTensorMeta]] + non_tensor_leaves: list[ + Union[object, _TensorMeta, _DTensorMeta, _ShardedTensorMeta] + ] @contextmanager @@ -104,7 +128,9 @@ def _prepare_state_dict( leaves, treespec = tree_flatten_with_path(state_dict) paths: list[KeyPath] = [] - non_tensor_leaves: list[Union[object, _TensorMeta, _DTensorMeta]] = [] + non_tensor_leaves: list[ + Union[object, _TensorMeta, _DTensorMeta, _ShardedTensorMeta] + ] = [] tensors: list[torch.Tensor] = [] for key_path, v in leaves: paths.append(key_path) @@ -120,6 +146,26 @@ def _prepare_state_dict( spec=v._spec, ) ) + elif isinstance(v, ShardedTensor): + # Handle ShardedTensor by extracting all local shards + local_shards = v.local_shards() + + # Prepare metadata for all local shards + local_shards_meta = [] + local_shards_shard_metadata = [] + for shard in local_shards: + tensor, tensor_meta = _prepare_tensor(shard.tensor) + tensors.append(tensor) + local_shards_meta.append(tensor_meta) + local_shards_shard_metadata.append(shard.metadata) + + non_tensor_leaves.append( + _ShardedTensorMeta( + local_shards_meta=local_shards_meta, + local_shards_shard_metadata=local_shards_shard_metadata, + sharded_tensor_metadata=v.metadata(), # Complete metadata + ) + ) elif isinstance(v, torch.Tensor): tensor, tensor_meta = _prepare_tensor(v) tensors.append(tensor) @@ -242,7 +288,6 @@ def recv_checkpoint(self, src_rank: int) -> object: Returns: The reconstructed state dictionary with model parameters """ - state_dict = self._state_dict() if self._state_dict else {} state_dict_leaves, _ = tree_flatten_with_path(state_dict) @@ -301,6 +346,37 @@ def recv(path: KeyPath, v: _TensorMeta) -> torch.Tensor: elif isinstance(v, _DTensorMeta): tensor = recv(path, v.local) values.append(DTensor(tensor, v.spec, requires_grad=False)) + elif isinstance(v, _ShardedTensorMeta): + # Receive all local shards that were sent to us + local_shards = [] + current_rank = self._pg.rank() + + # Receive tensors for each local shard that was sent + for j, shard_meta in enumerate(v.local_shards_meta): + tensor = recv(path, shard_meta) + + # Use the original shard metadata that was stored during preparation + # but update the placement to reflect the current rank/device + original_shard_metadata = v.local_shards_shard_metadata[j] + updated_shard_metadata = ShardMetadata( + shard_offsets=original_shard_metadata.shard_offsets, + shard_sizes=original_shard_metadata.shard_sizes, + placement=f"rank:{current_rank}/{tensor.device.type}", + ) + + local_shard = ShardedTensorShard( + tensor=tensor, metadata=updated_shard_metadata + ) + local_shards.append(local_shard) + + # Use complete metadata to reconstruct ShardedTensor + sharded_tensor = ( + ShardedTensor._init_from_local_shards_and_global_metadata( + local_shards=local_shards, + sharded_tensor_metadata=v.sharded_tensor_metadata, + ) + ) + values.append(sharded_tensor) else: values.append(v) From 9e0473b56621162bd85e94943a516be4727e5651 Mon Sep 17 00:00:00 2001 From: zero000064 <123607335+zero000064@users.noreply.github.com> Date: Mon, 21 Jul 2025 21:11:06 +0000 Subject: [PATCH 339/457] removed zero dim cpu logic from fake_tensor.py (#147501) Fixes #144748 In #144748, the inconsistency between the eager mode and the inductor mode is reported as a bug. The root cause is fake_tenosr.py's find-common-device method, https://github.com/pytorch/pytorch/blob/0b0da81021e061c021e515bc35d7dc0dbbb05941/torch/_subclasses/fake_tensor.py#L833, takes zero dim cpu tensor into account but the device check in adaption.h doesn't. This fix is to add a list for some ops to bypass zero-dim-cpu-tensor check to align with the eager mode. Pull Request resolved: https://github.com/pytorch/pytorch/pull/147501 Approved by: https://github.com/ezyang --- test/test_fake_tensor.py | 16 ++++++++++++++++ torch/_subclasses/fake_tensor.py | 13 +++++++++++-- 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py index e8c28cadbf829..7aa530ae3296b 100644 --- a/test/test_fake_tensor.py +++ b/test/test_fake_tensor.py @@ -211,6 +211,22 @@ def test_zero_dim(self): self.assertEqual(out.device, y.device) self.assertTrue(isinstance(out, FakeTensor)) + @unittest.skipIf(not RUN_CUDA, "requires cuda") + def test_op_with_zero_dim_bypassed(self): + if torch._functorch.config.fake_tensor_propagate_real_tensors: + return + shape_env = ShapeEnv() + mode = FakeTensorMode(shape_env=shape_env) + x = torch.tensor(1.0, device="cuda") + y = torch.tensor(2.0) + fake_x = mode.from_tensor(x) + fake_y = mode.from_tensor(y) + + with self.assertRaisesRegex( + RuntimeError, "Unhandled FakeTensor Device Propagation for.*" + ) as exc: + torch.nextafter(fake_x, fake_y) + def test_nan_to_num(self): with FakeTensorMode(): for dtype in [torch.float16, torch.float32]: diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index c17de15f46eac..77cf89e9186b6 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -889,6 +889,11 @@ def _find_common_device( aten._foreach_copy.default, ) + # list of ops not using zero dim cpu tensor logic to align with the eager mode. + bypass_zero_dim_cpu_tensor_check_ops = ordered_set( + aten.nextafter.default, + ) + def check_cpu_device(device: torch.device) -> bool: return device.type == "cpu" @@ -912,13 +917,17 @@ def merge_devices(t: object) -> None: is_cpu_zero_dim = t_is_cpu_zero_dim return + is_bypass_zero_dim_cpu_tensor_check_op = ( + func in bypass_zero_dim_cpu_tensor_check_ops + ) + # mismatching devices ! # if current tensor is cpu 0 dim, defer to existing device - if t_is_cpu_zero_dim: + if t_is_cpu_zero_dim and not is_bypass_zero_dim_cpu_tensor_check_op: return # current device is from cpu 0 dim tensor, overwrite - if is_cpu_zero_dim: + if is_cpu_zero_dim and not is_bypass_zero_dim_cpu_tensor_check_op: common_device = t.device is_cpu_zero_dim = t_is_cpu_zero_dim return From a991e285ae35159680b0ad4be24669906a6fa256 Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Mon, 21 Jul 2025 09:23:02 -0700 Subject: [PATCH 340/457] [AOTI] Add more default options to compile_standalone (#158560) Summary: When compiling for standalone, make embed_kernel_binary and emit_multi_arch_kernel default to True, and add a default name for model_name_for_generated_files to make the generated cpp project easier to understand. Also improved the weights object file naming to be more readable. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158560 Approved by: https://github.com/yushangdi --- test/inductor/test_aot_inductor.py | 10 ++- test/inductor/test_aot_inductor_package.py | 36 +++++++++ torch/_inductor/codecache.py | 21 ++--- torch/_inductor/codegen/cpp_wrapper_cpu.py | 18 ++--- torch/_inductor/codegen/triton.py | 5 ++ torch/_inductor/config.py | 8 +- torch/_inductor/cpp_builder.py | 93 +++++++++++++++++----- torch/_inductor/utils.py | 40 ++++++---- 8 files changed, 170 insertions(+), 61 deletions(-) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 9521f1defa0bd..49226013d81d2 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -6646,11 +6646,19 @@ def test_compile_standalone_sets_package_cpp(self): result = maybe_aoti_standalone_config({"aot_inductor.compile_standalone": True}) self.assertEqual(result["aot_inductor.package_cpp_only"], True) self.assertEqual(result["aot_inductor.compile_standalone"], True) + self.assertEqual(result["aot_inductor.embed_kernel_binary"], True) + self.assertEqual(result["aot_inductor.emit_multi_arch_kernel"], True) + self.assertEqual( + result["aot_inductor.model_name_for_generated_files"], "aoti_model" + ) - def test_compile_standalone_package_cpp_already_true(self): + def test_compile_standalone_explicit_set(self): patches = { "aot_inductor.compile_standalone": True, "aot_inductor.package_cpp_only": True, + "aot_inductor.embed_kernel_binary": True, + "aot_inductor.emit_multi_arch_kernel": True, + "aot_inductor.model_name_for_generated_files": "aoti_model", } result = maybe_aoti_standalone_config(patches) self.assertEqual(result, patches) diff --git a/test/inductor/test_aot_inductor_package.py b/test/inductor/test_aot_inductor_package.py index 51343b6b1883e..2809f5533bd9c 100644 --- a/test/inductor/test_aot_inductor_package.py +++ b/test/inductor/test_aot_inductor_package.py @@ -15,6 +15,7 @@ from parameterized import parameterized_class import torch +import torch._inductor.config from torch._inductor.codecache import get_kernel_bin_format from torch._inductor.package import load_package, package_aoti from torch._inductor.test_case import TestCase @@ -363,6 +364,7 @@ def forward(self, x, y): ) @unittest.skipIf(IS_FBCODE, "cmake won't work in fbcode") @skipIfXpu # build system may be different + @torch._inductor.config.patch("test_configs.use_libtorch", True) def test_compile_after_package_static(self): # compile_standalone will set package_cpp_only=True self.check_package_cpp_only() @@ -419,12 +421,46 @@ def forward(self, x, y): with self.assertRaisesRegex(Exception, "Invalid AOTI model name"): self.cmake_compile(model, example_inputs, options, "") + @unittest.skipIf(IS_FBCODE, "cmake won't work in fbcode") + @skipIfXpu # build system may be different + @torch._inductor.config.patch("test_configs.use_libtorch", True) + def test_compile_standalone_cos(self): + # compile_standalone will set package_cpp_only=True + self.check_package_cpp_only() + + class Model(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x): + return torch.cos(x) + + with torch.no_grad(): + example_inputs = (torch.randn(8, 32, device=self.device),) + model = Model().to(device=self.device) + + # Test compilation when model name is passed in + options = { + "aot_inductor.compile_standalone": True, + "aot_inductor.model_name_for_generated_files": "cos", + } + with ( + tempfile.TemporaryDirectory() as tmp_dir, + ): + build_path, _ = self.cmake_compile( + model, example_inputs, options, tmp_dir + ) + # Check if the .a file was build successfully + a_path = build_path / "libcos.a" + self.assertTrue(a_path.exists()) + @unittest.skipIf( _get_torch_cuda_version() < (12, 6), "Test is only supported on CUDA 12.6+" ) @unittest.skipIf(IS_FBCODE, "cmake won't work in fbcode") @skipIfRocm # doesn't support multi-arch binary @skipIfXpu # doesn't support multi-arch binary + @torch._inductor.config.patch("test_configs.use_libtorch", True) def test_compile_with_exporter(self): self.check_package_cpp_only() diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index c8b23aded15c2..dd5a591f421f3 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -1674,12 +1674,6 @@ def compile( wrapper_code = "\n".join((wrapper_code, kernel_code)) kernel_code = "" - from .utils import aoti_model_name_from_config - - model_class_name = "" - if config.aot_inductor.compile_standalone: - model_class_name = aoti_model_name_from_config() - wrapper_key, wrapper_path = write( wrapper_code, "wrapper.cpp", @@ -1712,6 +1706,8 @@ def compile( "model.h", ) ) as f: + # model_name_for_generated_files is guaranteed to be non-empty when compile_standalone + model_class_name = config.aot_inductor.model_name_for_generated_files class_name = f"AOTInductorModel{model_class_name}" header_code = f.read() @@ -1726,7 +1722,7 @@ def compile( header_code, "h", specified_dir=specified_output_path, - key=f"{model_class_name}", + key=model_class_name, ) # Log the AOTInductor wrapper and kernel code, if needed. @@ -1840,7 +1836,7 @@ def format_consts_to_asm( consts_asm += f"\t.space {len(consts) - 8}\n" consts_asm += f".globl\t{symbol_prefix}_binary_constants_bin_end\n" consts_asm += f"{symbol_prefix}_binary_constants_bin_end:\n" - return consts_asm, "S" + return consts_asm, "weights.S" # Use c++ to convert consts to object file can support more compilers, such as msvc and icx. def format_consts_to_cpp( @@ -1865,7 +1861,7 @@ def format_consts_to_cpp( const_cpp += "\t\n" const_cpp += "};\t\n" const_cpp += f"alignas({align_bytes}) extern unsigned char * {symbol_prefix}_binary_constants_bin_end;\t\n" - return const_cpp, "cpp" + return const_cpp, "weights.cpp" if use_asm_build: consts_code, code_ext = format_consts_to_asm( @@ -1880,6 +1876,7 @@ def format_consts_to_cpp( consts_code, code_ext, specified_dir=str(specified_sub_dir), + key=config.aot_inductor.model_name_for_generated_files, ) consts_s = Path(consts_s) object_build_options = CppTorchDeviceOptions( @@ -2173,7 +2170,13 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes: asm_files = [] if not _IS_WINDOWS: ld, objcopy = get_ld_and_objcopy(use_relative_path) + kernels = getattr(V.graph.wrapper_code, "_kernel_name_to_body", {}) for kernel_name, value in CudaKernelParamCache.cache.items(): + if kernel_name not in kernels: + # It is possible that CudaKernelParamCache contains more Triton kernels + # than what the current graph uses + continue + if asm_file := value["asm"]: asm_files.append(asm_file) diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 56d6f40dade81..9abdcce44f6c9 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -22,13 +22,7 @@ from torch.utils._sympy.symbol import symbol_is_type, SymT from .. import config, ir -from ..utils import ( - _align, - aoti_model_name_from_config, - DeferredLineBase, - LineContext, - normalize_name, -) +from ..utils import _align, DeferredLineBase, LineContext, normalize_name from ..virtualized import V from .aoti_hipify_utils import maybe_hipify_code_wrapper from .common import get_device_op_overrides, IndentedBuffer, Kernel @@ -64,11 +58,15 @@ def __init__(self): self.device = "cpu" # must be initialized prior to calling super().__init__() self.included_devices: OrderedSet[str] = OrderedSet() - self.model_class_name_suffix = "" - if config.aot_inductor.compile_standalone: - self.model_class_name_suffix = aoti_model_name_from_config() + self.model_class_name_suffix = ( + config.aot_inductor.model_name_for_generated_files + if config.aot_inductor.compile_standalone + else "" + ) self.aoti_model_class_name = f"AOTInductorModel{self.model_class_name_suffix}" + super().__init__() + self.declare = "auto " self.declare_maybe_reference = "decltype(auto) " self.ending = ";" diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index a34665d720f47..4aaff61e77d4f 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -4479,6 +4479,11 @@ def define_kernel(self, src_code, node_schedule, kernel): kernel_name = "_".join( ["triton", kernel_category, fused_name, wrapper.next_kernel_suffix()] ) + if config.aot_inductor.model_name_for_generated_files: + # When AOTI compiles multiple submodules, we need to use the model name to + # distinguish kernel related symbols. + kernel_name = f"{config.aot_inductor.model_name_for_generated_files}_{kernel_name}" + # use the original src_code as the key wrapper.src_to_kernel[src_code] = kernel_name subs_name = kernel_name if config.triton.unique_kernel_names else "triton_" diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 4515e10604a78..2ce07c6293233 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -1450,12 +1450,12 @@ class aot_inductor: precompile_headers: bool = not is_fbcode() # Embed generated kernel binary files into model.so - embed_kernel_binary: bool = False + embed_kernel_binary: Optional[bool] = None # Generate kernel files that support multiple archs # For CUDA, this means generating fatbin files for kernels, and the fatbin files # contains PTX and SASS for the current architecture. - emit_multi_arch_kernel: bool = False + emit_multi_arch_kernel: Optional[bool] = None # If not None, the generated files with use this name in file stem. # If None, we will use a hash to name files. @@ -1842,6 +1842,10 @@ class test_configs: graphsafe_rng_func_ignores_fallback_random = False + # If set to True, AOTI-generated CMakelists.txt will still use libtorch + # for unit testing + use_libtorch = False + if TYPE_CHECKING: from torch.utils._config_typing import * # noqa: F401, F403 diff --git a/torch/_inductor/cpp_builder.py b/torch/_inductor/cpp_builder.py index 47820d3d77250..64140542d9ba0 100644 --- a/torch/_inductor/cpp_builder.py +++ b/torch/_inductor/cpp_builder.py @@ -28,7 +28,6 @@ from torch._inductor import config, exc from torch._inductor.cpu_vec_isa import invalid_vec_isa, VecISA from torch._inductor.runtime.runtime_utils import cache_dir -from torch._inductor.utils import aoti_model_name_from_config from torch.torch_version import TorchVersion @@ -1545,7 +1544,9 @@ def __init__( self._aot_mode: bool = False self._name = name - self._target_name = aoti_model_name_from_config() + self._target_name = ( + config.aot_inductor.model_name_for_generated_files or "aoti_model" + ) # Code start here, initial self internal variables firstly. self._build_option = BuildOption @@ -1771,9 +1772,13 @@ def save_compile_cmd_to_cmake( """ definitions = " ".join(self._build_option.get_definitions()) - target_library_type = ( - "STATIC" if config.aot_inductor.compile_standalone else "SHARED" - ) + if config.aot_inductor.compile_standalone: + if config.test_configs.use_libtorch: + add_target = f"add_library({self._target_name} STATIC)" + else: + add_target = f"add_executable({self._target_name} ${{CMAKE_CURRENT_SOURCE_DIR}}/main.cpp)" + else: + add_target = f"add_library({self._target_name} SHARED)" contents = textwrap.dedent( f""" @@ -1781,22 +1786,54 @@ def save_compile_cmd_to_cmake( project({self._target_name} LANGUAGES CXX) set(CMAKE_CXX_STANDARD 17) - # May need to point CMAKE_PREFIX_PATH to the right torch location - find_package(Torch REQUIRED) + # Set target + {add_target} - # Set a shared library target - add_library({self._target_name} {target_library_type}) + """ + ) - # Add macro definitions - target_compile_definitions({self._target_name} PRIVATE {definitions}) + if ( + not config.aot_inductor.compile_standalone + or config.test_configs.use_libtorch + ): + # When compile_standalone is True, the generated cpp project should + # not use Torch. But for unit testing purpose, we need to use Torch here. + contents += textwrap.dedent( + """ + # May need to point CMAKE_PREFIX_PATH to the right torch location + find_package(Torch REQUIRED) - # Add compile flags - target_compile_options({self._target_name} PRIVATE {self._cflags_args}) - # Backend specific flags - target_compile_options({self._target_name} PRIVATE {self._passthrough_parameters_args} -c) + """ + ) + # flags and macros here are mostly CPU specific. Not emitting them for GPU models + # will make the generated CMake file more portable and won't really hurt performance. + # NOTE: standalone focuses on GPU now. For CPU, some of the flags and macros may + # be still needed. + contents += textwrap.dedent( + f""" + # Add macro definitions + target_compile_definitions({self._target_name} PRIVATE {definitions}) + + # Add compile flags + target_compile_options({self._target_name} PRIVATE {self._cflags_args}) + + # Backend-specific flags + target_compile_options({self._target_name} PRIVATE {self._passthrough_parameters_args} -c) + + """ + ) + else: + # When compile_standalone is True, use TorchStandalone instead of Torch + contents += textwrap.dedent( + """ + find_package(TorchStandalone REQUIRED) + # Set up include directories to find headers at the correct paths + target_include_directories(cos PRIVATE ${TorchStandalone_INCLUDE_DIRS}) + target_include_directories(cos PRIVATE ${TorchStandalone_INCLUDE_DIRS}/standalone) + + """ + ) - """ - ) if device_type == "cuda" and torch.version.hip is None: from torch._inductor.codecache import _nvcc_arch_as_compile_option @@ -1804,7 +1841,11 @@ def save_compile_cmd_to_cmake( contents += textwrap.dedent( f""" enable_language(CUDA) + set(CMAKE_CUDA_STANDARD 17) find_package(CUDAToolkit REQUIRED) + target_include_directories({self._target_name} PRIVATE ${{CUDAToolkit_INCLUDE_DIRS}}) + target_compile_definitions({self._target_name} PRIVATE USE_CUDA) + target_link_libraries({self._target_name} PRIVATE cuda CUDA::cudart_static) find_program(OBJCOPY_EXECUTABLE objcopy) if(NOT OBJCOPY_EXECUTABLE) @@ -1833,7 +1874,7 @@ def save_compile_cmd_to_cmake( add_custom_command( OUTPUT ${{FATBIN_FILE}} COMMAND ${{CUDAToolkit_NVCC_EXECUTABLE}} --fatbin ${{PTX_FILE}} -o ${{FATBIN_FILE}} ${{NVCC_GENCODE_FLAGS}} - -gencode arch=compute_80,code=compute_80 + -gencode arch=compute_{current_arch},code=compute_{current_arch} -gencode arch=compute_{current_arch},code=sm_{current_arch} DEPENDS ${{PTX_FILE}} ) @@ -1882,12 +1923,20 @@ def save_kernel_asm_to_cmake(self, cmake_path: str, asm_files: list[str]) -> Non """ ) f.write(contents) - f.write(f"add_dependencies({self._target_name} ${{KERNEL_TARGETS}})\n") - f.write( - f"target_link_libraries({self._target_name} PRIVATE ${{KERNEL_OBJECT_FILES}})\n" - ) + if asm_files: + f.write(f"add_dependencies({self._target_name} ${{KERNEL_TARGETS}})\n") + f.write( + f"target_link_libraries({self._target_name} PRIVATE ${{KERNEL_OBJECT_FILES}})\n" + ) def save_link_cmd_to_cmake(self, cmake_path: str) -> None: + if ( + config.aot_inductor.compile_standalone + and not config.test_configs.use_libtorch + ): + # When compile_standalone is True, do not link with libtorch + return + lflags = " ".join(self._build_option.get_ldflags()) libs = " ".join(self._build_option.get_libraries()) contents = textwrap.dedent( diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 5f9ce0b814eba..9bec3fd764bf3 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -3309,20 +3309,34 @@ def maybe_aoti_standalone_config(config_patches: dict[str, Any]) -> dict[str, An Returns: dict[str, Any]: The possibly-updated `config_patches` dictionary. """ + + def patch_config( + config_patches: dict[str, Any], config_name: str, config_value: Any + ) -> None: + value = config_patches.get(config_name, getattr(config, config_name)) + if value is None: + config_patches[config_name] = config_value + elif not value: + raise RuntimeError( + f"Invalid config: {config_name}={config_value} when aot_inductor.compile_standalone is True." + ) + compile_standalone = config_patches.get( "aot_inductor.compile_standalone", config.aot_inductor.compile_standalone ) + # Make a copy of the config_patches to avoid modifying the original dictionary, needed for testing + config_patches = config_patches.copy() if compile_standalone: - package_cpp_only = config_patches.get( - "aot_inductor.package_cpp_only", config.aot_inductor.package_cpp_only + # Standlaone AOTInductor means only generate cpp project for building a standalone binary + patch_config(config_patches, "aot_inductor.package_cpp_only", True) + # Standlaone AOTInductor needs to embed the kernel code in the binary + patch_config(config_patches, "aot_inductor.embed_kernel_binary", True) + # Default to use multi-arch kernel codegen + patch_config(config_patches, "aot_inductor.emit_multi_arch_kernel", True) + patch_config( + config_patches, "aot_inductor.model_name_for_generated_files", "aoti_model" ) - if package_cpp_only is None: - config_patches = {**config_patches, "aot_inductor.package_cpp_only": True} - elif not package_cpp_only: - raise RuntimeError( - "compile_standalone=True requires package_cpp_only=True. " - "Please set aot_inductor.package_cpp_only=True in your inductor config." - ) + return config_patches @@ -3351,11 +3365,3 @@ def is_valid_aoti_model_name() -> bool: ) return True - - -def aoti_model_name_from_config() -> str: - from torch._inductor import config - - model_name = config.aot_inductor.model_name_for_generated_files - model_name = "aoti_model" if model_name is None else model_name - return model_name From c774180e59409996fb123d6ff9261c2fc356c2f7 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 21 Jul 2025 21:35:33 +0000 Subject: [PATCH 341/457] Bump requests from 2.32.2 to 2.32.4 in /tools/build/bazel (#158006) Bumps [requests](https://github.com/psf/requests) from 2.32.2 to 2.32.4.
Release notes

Sourced from requests's releases.

v2.32.4

2.32.4 (2025-06-10)

Security

  • CVE-2024-47081 Fixed an issue where a maliciously crafted URL and trusted environment will retrieve credentials for the wrong hostname/machine from a netrc file. (#6965)

Improvements

  • Numerous documentation improvements

Deprecations

  • Added support for pypy 3.11 for Linux and macOS. (#6926)
  • Dropped support for pypy 3.9 following its end of support. (#6926)

v2.32.3

2.32.3 (2024-05-29)

Bugfixes

  • Fixed bug breaking the ability to specify custom SSLContexts in sub-classes of HTTPAdapter. (#6716)
  • Fixed issue where Requests started failing to run on Python versions compiled without the ssl module. (#6724)
Changelog

Sourced from requests's changelog.

2.32.4 (2025-06-10)

Security

  • CVE-2024-47081 Fixed an issue where a maliciously crafted URL and trusted environment will retrieve credentials for the wrong hostname/machine from a netrc file.

Improvements

  • Numerous documentation improvements

Deprecations

  • Added support for pypy 3.11 for Linux and macOS.
  • Dropped support for pypy 3.9 following its end of support.

2.32.3 (2024-05-29)

Bugfixes

  • Fixed bug breaking the ability to specify custom SSLContexts in sub-classes of HTTPAdapter. (#6716)
  • Fixed issue where Requests started failing to run on Python versions compiled without the ssl module. (#6724)
Commits
  • 021dc72 Polish up release tooling for last manual release
  • 821770e Bump version and add release notes for v2.32.4
  • 59f8aa2 Add netrc file search information to authentication documentation (#6876)
  • 5b4b64c Add more tests to prevent regression of CVE 2024 47081
  • 7bc4587 Add new test to check netrc auth leak (#6962)
  • 96ba401 Only use hostname to do netrc lookup instead of netloc
  • 7341690 Merge pull request #6951 from tswast/patch-1
  • 6716d7c remove links
  • a7e1c74 Update docs/conf.py
  • c799b81 docs: fix dead links to kenreitz.org
  • Additional commits viewable in compare view

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=requests&package-manager=pip&previous-version=2.32.2&new-version=2.32.4)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself) You can disable automated security fix PRs for this repo from the [Security Alerts page](https://github.com/pytorch/pytorch/network/alerts).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158006 Approved by: https://github.com/Skylion007 Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tools/build/bazel/requirements.in | 2 +- tools/build/bazel/requirements.txt | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tools/build/bazel/requirements.in b/tools/build/bazel/requirements.in index 37750163da81e..7498f9065f0c2 100644 --- a/tools/build/bazel/requirements.in +++ b/tools/build/bazel/requirements.in @@ -1,6 +1,6 @@ PyYAML==6.0.1 numpy==1.26.4 -requests==2.32.2 +requests==2.32.4 setuptools==78.1.1 sympy==1.12 typing_extensions==4.11.0 diff --git a/tools/build/bazel/requirements.txt b/tools/build/bazel/requirements.txt index a15924660167d..dab9792ceae31 100644 --- a/tools/build/bazel/requirements.txt +++ b/tools/build/bazel/requirements.txt @@ -203,9 +203,9 @@ pyyaml==6.0.1 \ --hash=sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d \ --hash=sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f # via -r requirements.in -requests==2.32.2 \ - --hash=sha256:dd951ff5ecf3e3b3aa26b40703ba77495dab41da839ae72ef3c8e5d8e2433289 \ - --hash=sha256:fc06670dd0ed212426dfeb94fc1b983d917c4f9847c863f313c9dfaaffb7c23c +requests==2.32.4 \ + --hash=sha256:27babd3cda2a6d50b30443204ee89830707d396671944c998b5975b031ac2b2c \ + --hash=sha256:27d0316682c8a29834d3264820024b62a36942083d52caf2f14c0591336d3422 # via -r requirements.in sympy==1.12 \ --hash=sha256:c3588cd4295d0c0f603d0f2ae780587e64e2efeedb3521e46b9bb1d08d184fa5 \ From 216ba6e5f235bbfa0b025303ad4aa5ee473c5a8b Mon Sep 17 00:00:00 2001 From: zeshengzong Date: Mon, 21 Jul 2025 21:44:44 +0000 Subject: [PATCH 342/457] Fix `MaskedTensor` to device ignored mask (#151205) Fixes #147140 ## Changes - Add `to` implementation in `MaskedTensor` to support move `mask` to target device ## Test Result ```python In [1]: import torch ...: from torch.masked import as_masked_tensor ...: data = torch.tensor([1,2,3]) ...: mask = torch.tensor([True,False,True]) ...: mt = as_masked_tensor(data, mask).to('cuda') ...: mt.get_data().device, mt.get_mask().device /home/zong/code/pytorch/torch/masked/maskedtensor/core.py:247: UserWarning: The PyTorch API of MaskedTensors is in prototype stage and will change in the near future. Please open a Github issue for features requests and see our documentation on the torch.masked module for further information about the project. return MaskedTensor(data, mask) /home/zong/code/pytorch/torch/masked/maskedtensor/_ops_refs.py:354: UserWarning: The PyTorch API of MaskedTensors is in prototype stage and will change in the near future. Please open a Github issue for features requests and see our documentation on the torch.masked module for further information about the project. return MaskedTensor(new_data, _maybe_get_mask(args[0])) Out[1]: (device(type='cuda', index=0), device(type='cuda', index=0)) In [2]: mt.sum(dim=0) /home/zong/code/pytorch/torch/masked/maskedtensor/core.py:247: UserWarning: The PyTorch API of MaskedTensors is in prototype stage and will change in the near future. Please open a Github issue for features requests and see our documentation on the torch.masked module for further information about the project. return MaskedTensor(data, mask) Out[2]: MaskedTensor(4, True) ``` ```bash pytest test/test_maskedtensor.py -vv ``` ![image](https://github.com/user-attachments/assets/640b809c-b4f0-4aca-a09e-04049017a745) Pull Request resolved: https://github.com/pytorch/pytorch/pull/151205 Approved by: https://github.com/ezyang --- test/test_maskedtensor.py | 26 ++++++++++++++++++++++++++ torch/masked/maskedtensor/_ops_refs.py | 5 ++++- 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/test/test_maskedtensor.py b/test/test_maskedtensor.py index db1ffbc38c1f2..03c05c7ea6da4 100644 --- a/test/test_maskedtensor.py +++ b/test/test_maskedtensor.py @@ -236,6 +236,32 @@ def test_to_sparse(self, device): _compare_mt_t(sparse_mt, data) _compare_mt_t(mt.grad, data.grad) + def test_to_device(self, device): + for sample in _generate_sample_data(device=device): + data = sample.input + mask = sample.kwargs["mask"] + mt = masked_tensor(data, mask, requires_grad=True) + + new_device = torch.device("cuda") if device != "cuda" and torch.cuda.is_available() else torch.device("cpu") + mt_device = mt.to(new_device) + + self.assertEqual(mt_device.device.type, new_device.type) + self.assertEqual(mt_device.get_mask().device.type, new_device.type) + self.assertEqual(mt_device.get_data().device.type, new_device.type) + + def test_to_dtype(self, device): + for sample in _generate_sample_data(device=device): + data = sample.input + mask = sample.kwargs["mask"] + mt = masked_tensor(data, mask, requires_grad=True) + + new_dtype = torch.float64 if data.dtype == torch.float32 else torch.float32 + mt_dtype = mt.to(new_dtype) + + self.assertEqual(mt_dtype.dtype, new_dtype) + self.assertEqual(mt_dtype.get_mask().dtype, torch.bool) + self.assertEqual(mt_dtype.get_data().dtype, new_dtype) + def test_to_dense(self, device): samples = _generate_sample_data( device=device, diff --git a/torch/masked/maskedtensor/_ops_refs.py b/torch/masked/maskedtensor/_ops_refs.py index 719df7eac464f..8135f149a1bfc 100644 --- a/torch/masked/maskedtensor/_ops_refs.py +++ b/torch/masked/maskedtensor/_ops_refs.py @@ -351,7 +351,10 @@ def _apply_fn_on_data(func, *args, **kwargs): @register_dispatch_func([torch.ops.aten._to_copy]) def _to_copy(func, *args, **kwargs): new_data = func(_get_data(args[0]), *args[1:], **kwargs) - return MaskedTensor(new_data, _maybe_get_mask(args[0])) + cloned_kwargs = kwargs.copy() + cloned_kwargs["dtype"] = torch.bool + new_mask = func(_maybe_get_mask(args[0]), *args[1:], **cloned_kwargs) + return MaskedTensor(new_data, new_mask) @register_dispatch_func([torch.ops.aten._softmax]) From 0e46f542861832153ae37d04da23e9fe8593a312 Mon Sep 17 00:00:00 2001 From: Jeff Daily Date: Mon, 21 Jul 2025 22:09:36 +0000 Subject: [PATCH 343/457] [ROCm][CI] update HIP patch for 6.4.1 (#158651) patch is intended to fix hipGraph capture for some miopen kernels Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/158651 Approved by: https://github.com/jeffdaily Co-authored-by: Jeff Daily --- .ci/docker/common/install_rocm.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.ci/docker/common/install_rocm.sh b/.ci/docker/common/install_rocm.sh index d2d56ecec91df..02406ab71cdea 100644 --- a/.ci/docker/common/install_rocm.sh +++ b/.ci/docker/common/install_rocm.sh @@ -87,7 +87,7 @@ EOF if [[ $(ver $ROCM_VERSION) -ge $(ver 6.4) ]] && [[ $(ver $ROCM_VERSION) -lt $(ver 7.0) ]]; then if [[ $(ver $ROCM_VERSION) -eq $(ver 6.4.1) ]]; then HIP_BRANCH=release/rocm-rel-6.4 - CLR_HASH=ca18eb3f77fa09292fcda62bc60c3e565d752ada # branch release/rocm-rel-6.4.1-statco-hotfix + CLR_HASH=606bc820b4b1f315d135da02a1f0b176ca50a92c # branch release/rocm-rel-6.4.1-statco-hotfix elif [[ $(ver $ROCM_VERSION) -eq $(ver 6.4) ]]; then HIP_BRANCH=release/rocm-rel-6.4 CLR_HASH=600f5b0d2baed94d5121e2174a9de0851b040b0c # branch release/rocm-rel-6.4-statco-hotfix From 9498d95b9c07bb140b7548aae8b19cb2f6b96fce Mon Sep 17 00:00:00 2001 From: Lucas Kabela Date: Mon, 21 Jul 2025 22:12:54 +0000 Subject: [PATCH 344/457] [Dynamo][BetterEngineering] Type trace_rules.py (#158679) As part of better engineering week, we would like to improve out type support to improve dev experience in dynamo This PR adds strict typing support to a core file, `trace_rules.py` Running ``` mypy torch/_dynamo/trace_rules.py --linecount-report /tmp/coverage_log ``` | -------- | Lines Unannotated | Lines Total | % lines covered | Funcs Unannotated | Funcs Total | % funcs covered | | -------- | ------- | -------- | ------- | ------- | ------- | ------- | | Main | 2564 | 3997 | 64.15% | 34 | 53 | 64.15% | | This PR | 4022 | 4022 | 100.00% | 53 | 53 | 100.00% | | Delta | +1458 | +25 | +35.85% | +19 | 0 | +35.85% | Pull Request resolved: https://github.com/pytorch/pytorch/pull/158679 Approved by: https://github.com/williamwen42 --- torch/_dynamo/trace_rules.py | 163 ++++++++++++++++++++--------------- torch/_dynamo/utils.py | 2 +- 2 files changed, 95 insertions(+), 70 deletions(-) diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index d4c98a0f6b151..0eadce79c05fc 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -1,5 +1,3 @@ -# mypy: allow-untyped-defs - """ Tracing rules and policies for TorchDynamo compilation decisions. @@ -37,7 +35,6 @@ import sys import traceback import types -import typing import unittest from collections import defaultdict from pathlib import Path @@ -73,6 +70,7 @@ UserFunctionVariable, UserMethodVariable, ) +from .variables.base import VariableTracker np: Optional[types.ModuleType] = None @@ -82,10 +80,6 @@ pass -if typing.TYPE_CHECKING: - from .variables.base import VariableTracker - - """ A note on skip/inline rules: @@ -153,7 +147,14 @@ """ -manual_torch_name_rule_map: dict[str, Any] = { +manual_torch_name_rule_map: dict[ + str, + Union[ + type[TorchInGraphFunctionVariable], + type[SkipFunctionVariable], + type[UserFunctionVariable], + ], +] = { "torch.onnx.is_in_onnx_export": TorchInGraphFunctionVariable, "torch.onnx.operators.shape_as_tensor": TorchInGraphFunctionVariable, "torch.overrides.is_tensor_like": TorchInGraphFunctionVariable, @@ -2986,7 +2987,10 @@ def get_torch_obj_rule_map() -> dict[Any, type["VariableTracker"]]: if ".py#" not in k: obj = load_object(k) else: - obj = _module_dir(torch) + k[len("torch/") :] + torch_dir = _module_dir(torch) + if torch_dir is None: + continue + obj = torch_dir + k[len("torch/") :] if obj is not None: if is_lru_cache_wrapped_function(obj): obj = obj.__wrapped__ @@ -2999,7 +3003,7 @@ def get_torch_obj_rule_map() -> dict[Any, type["VariableTracker"]]: return d -def _load_obj_from_str(fully_qualified_name): +def _load_obj_from_str(fully_qualified_name: str) -> Any: module, obj_name = fully_qualified_name.rsplit(".", maxsplit=1) return getattr(importlib.import_module(module), obj_name) @@ -3009,7 +3013,7 @@ def _load_obj_from_str(fully_qualified_name): """ -def load_object(name): +def load_object(name: str) -> Any: try: x = name.split("#") if len(x) == 2: @@ -3030,7 +3034,7 @@ def load_object(name): @functools.cache -def get_tensor_method(): +def get_tensor_method() -> frozenset[Any]: disallowed_tensor_methods = {"__new__", "_make_wrapper_subclass", "_make_subclass"} s = set() for name in dir(torch.Tensor): @@ -3059,7 +3063,7 @@ def get_tensor_method(): """ -def is_aten_op_or_tensor_method(obj): +def is_aten_op_or_tensor_method(obj: Any) -> bool: return obj in get_tensor_method() or isinstance( obj, (torch._ops.OpOverloadPacket, torch._ops.OpOverload), @@ -3095,16 +3099,16 @@ def __call__(self) -> set[int]: self.function_ids = value return self.function_ids - def get_name(self, idx: int, default: str): + def get_name(self, idx: int, default: str) -> str: self() # lazy init assert self.function_names is not None return self.function_names.get(idx, default) - def add(self, idx: int): + def add(self, idx: int) -> None: function_ids = self() # lazy init function_ids.add(idx) - def remove(self, idx: int): + def remove(self, idx: int) -> None: function_ids = self() if idx in function_ids: function_ids.remove(idx) @@ -3172,7 +3176,7 @@ def _numpy_function_ids() -> dict[int, str]: "sample", } - def is_supported(k, v, mod): + def is_supported(k: str, v: Any, mod: Any) -> bool: if not callable(v): return False if not getattr(v, "__module__", None): @@ -3231,53 +3235,53 @@ def _maybe_init_lazy_module(obj: object) -> None: fn() -def is_callable_allowed(obj) -> bool: +def is_callable_allowed(obj: Any) -> bool: _maybe_init_lazy_module(obj) return id(obj) in _allowed_callable_ids -def is_nonstrict_trace_callable(obj) -> bool: +def is_nonstrict_trace_callable(obj: Any) -> bool: _maybe_init_lazy_module(obj) return id(obj) in _nonstrict_trace_callable_ids -def is_callable_disallowed(obj) -> bool: +def is_callable_disallowed(obj: Any) -> bool: _maybe_init_lazy_module(obj) return id(obj) in _disallowed_callable_ids -def is_forbidden(obj) -> bool: +def is_forbidden(obj: Any) -> bool: _maybe_init_lazy_module(obj) return inspect.getattr_static(obj, "_dynamo_forbidden", False) -def is_builtin_callable(obj) -> bool: +def is_builtin_callable(obj: Any) -> bool: # See also torch/_dynamo/polyfills/loader.py, which removes items in _builtin_function_ids return id(obj) in _builtin_function_ids -def is_builtin_constant(obj) -> bool: +def is_builtin_constant(obj: Any) -> bool: return id(obj) in _builtin_constant_ids -def is_polyfilled_callable(obj) -> bool: +def is_polyfilled_callable(obj: Any) -> bool: # See also @torch._dynamo.decorators.substitute_in_graph(...), which adds items in _polyfilled_function_ids return id(obj) in _polyfilled_function_ids -def is_numpy(obj) -> bool: +def is_numpy(obj: Any) -> bool: if np is None: return False return isinstance(obj, (np.ndarray, np.generic)) or id(obj) in _numpy_function_ids -def is_numpy_dtype(obj) -> bool: +def is_numpy_dtype(obj: Any) -> bool: if np is None: return False return isinstance(obj, np.dtype) -def is_numpy_type_info(obj) -> bool: +def is_numpy_type_info(obj: Any) -> bool: if np is None: return False return isinstance(obj, (np.finfo, np.iinfo)) @@ -3315,7 +3319,7 @@ def is_numpy_type_info(obj) -> bool: ) -def _as_posix_path(path): +def _as_posix_path(path: str) -> str: posix_path = Path(os.path.normpath(path)).as_posix() # os.path.normpath and pathlib.Path remove trailing slash, so we need to add it back if path.endswith((os.path.sep, "/")): @@ -3323,13 +3327,13 @@ def _as_posix_path(path): return posix_path -def _strip_init_py(s): +def _strip_init_py(s: str) -> str: suffix = "__init__.py" s = s.removesuffix(suffix) return _as_posix_path(s) -def _module_dir(m: types.ModuleType): +def _module_dir(m: types.ModuleType) -> Optional[str]: # Protect against a module not exporting __file__ - this can happen for # frozen modules, for example. file = getattr(m, "__file__", None) @@ -3551,27 +3555,36 @@ def _module_dir(m: types.ModuleType): @functools.cache -def get_legacy_mod_inlinelist(): +def get_legacy_mod_inlinelist() -> set[str]: + torch_dir = _module_dir(torch) + if torch_dir is None: + return set() inlinelist = { - _as_posix_path(_module_dir(torch) + m[len("torch.") :].replace(".", "/")) + _as_posix_path(torch_dir + m[len("torch.") :].replace(".", "/")) for m in LEGACY_MOD_INLINELIST } return inlinelist @functools.cache -def get_mod_inlinelist(): +def get_mod_inlinelist() -> set[str]: + torch_dir = _module_dir(torch) + if torch_dir is None: + return set() inlinelist = { - _as_posix_path(_module_dir(torch) + m[len("torch.") :].replace(".", "/")) + _as_posix_path(torch_dir + m[len("torch.") :].replace(".", "/")) for m in MOD_INLINELIST } return inlinelist @functools.cache -def get_mod_skiplist(): +def get_mod_skiplist() -> set[str]: + torch_dir = _module_dir(torch) + if torch_dir is None: + return set() skiplist = { - _as_posix_path(_module_dir(torch) + m[len("torch.") :].replace(".", "/")) + _as_posix_path(torch_dir + m[len("torch.") :].replace(".", "/")) for m in MOD_SKIPLIST } return skiplist @@ -3628,14 +3641,14 @@ def get_mod_skiplist(): FORCE_SKIP_FILES = {f"{_module_dir(torch)}optim/lr_scheduler.py"} -def _recompile_re(): +def _recompile_re() -> None: global SKIP_DIRS_RE SKIP_DIRS_RE = re.compile( rf"^[^\s<]*({'|'.join(re.escape(_as_posix_path(d)) for d in SKIP_DIRS)})" ) -def add(import_name: str): +def add(import_name: str) -> None: if isinstance(import_name, types.ModuleType): return add(import_name.__name__) assert isinstance(import_name, str) @@ -3657,7 +3670,7 @@ class SkipResult: reason: Optional[str] -def check_file(filename, is_inlined_call=False): +def check_file(filename: Optional[str], is_inlined_call: bool = False) -> SkipResult: """Should skip this file?""" if filename is None: return SkipResult(True, "filename is None") @@ -3695,8 +3708,10 @@ def check_file(filename, is_inlined_call=False): ): return SkipResult(True, "FBCODE_SKIP_TORCHREC_DIRS") + unittest_dir = _module_dir(unittest) if ( - filename.startswith(_module_dir(unittest)) + unittest_dir is not None + and filename.startswith(unittest_dir) and not torch._dynamo.config.enable_trace_unittest ): return SkipResult(True, "unittest") @@ -3751,7 +3766,7 @@ def f3(x, y): """ -def check_verbose(obj, is_inlined_call=False): +def check_verbose(obj: Any, is_inlined_call: bool = False) -> SkipResult: if isinstance( obj, ( @@ -3770,18 +3785,23 @@ def check_verbose(obj, is_inlined_call=False): elif isinstance(obj, types.CodeType): fi = FunctionInfo(None, obj.co_name, obj.co_filename, obj) elif isinstance(obj, (types.FunctionType, types.MethodType)): + filename = getfile(obj) + assert filename is not None fi = FunctionInfo( obj, obj.__name__, - getfile(obj), + filename, obj.__code__, # type: ignore[union-attr] # FIXME Add MethodType.__code__ to typeshed ) else: - fi = FunctionInfo(obj, None, getfile(obj), None) + filename = getfile(obj) + assert filename is not None + fi = FunctionInfo(obj, None, filename, None) # Consulte the central trace rules defined in torch._dynamo.trace_rules. reasons: set[str] = set() rule = lookup_inner(fi.py_obj, fi.name, fi.filename, is_inlined_call, reasons) + assert rule is not None if issubclass( rule, ( @@ -3807,7 +3827,7 @@ def check_verbose(obj, is_inlined_call=False): ) -def check(obj, is_inlined_call=False): +def check(obj: Any, is_inlined_call: bool = False) -> bool: return check_verbose(obj, is_inlined_call).skipped @@ -3818,21 +3838,23 @@ def check(obj, is_inlined_call=False): _recompile_re() -def is_torch_inline_allowed(filename): +def is_torch_inline_allowed(filename: str) -> bool: return any(filename.startswith(d) for d in get_mod_inlinelist()) @functools.cache -def dynamo_dir(): +def dynamo_dir() -> Optional[str]: import torch._dynamo return _module_dir(torch._dynamo) -def is_torch(filename): - if filename.startswith(dynamo_dir()): +def is_torch(filename: str) -> bool: + dynamo_path = dynamo_dir() + if dynamo_path is not None and filename.startswith(dynamo_path): return False - return filename.startswith(_module_dir(torch)) + torch_path = _module_dir(torch) + return torch_path is not None and filename.startswith(torch_path) """ @@ -3840,7 +3862,7 @@ def is_torch(filename): """ -def lookup_callable(obj): +def lookup_callable(obj: Callable[..., Any]) -> Optional[type[VariableTracker]]: if not hashable(obj): return None # Custom allow/disallow in graph takes precedence over the general lookup. @@ -3861,18 +3883,18 @@ def lookup_callable(obj): """ -def lookup(obj): +def lookup(obj: Any) -> Optional[type[VariableTracker]]: return lookup_inner(obj) # also takes config.dont_skip_tracing into account def lookup_inner( - obj, - name=None, - filename=None, - is_direct_call=True, + obj: Any, + name: Optional[str] = None, + filename: Optional[str] = None, + is_direct_call: bool = True, reasons: Union[None, set[str]] = None, -): +) -> Optional[type[VariableTracker]]: result = _lookup_inner( obj, name=name, @@ -3887,12 +3909,15 @@ def lookup_inner( if config.dont_skip_tracing and result is SkipFunctionVariable: if filename is None: filename = getfile(obj) + assert filename is not None filename = _as_posix_path(filename) - dynamo_path = _as_posix_path(_module_dir(torch)) + "_dynamo" - if filename.startswith(dynamo_path) and not filename.endswith( - "test_dont_skip_tracing_functions.py" - ): - return SkipFunctionVariable + torch_dir = _module_dir(torch) + if torch_dir is not None: + dynamo_path = _as_posix_path(torch_dir) + "_dynamo" + if filename.startswith(dynamo_path) and not filename.endswith( + "test_dont_skip_tracing_functions.py" + ): + return SkipFunctionVariable if reasons is not None: reasons.add( "Attempted skip but we are ignoring skips due to torch._dynamo.config.dont_skip_tracing" @@ -3902,12 +3927,12 @@ def lookup_inner( def _lookup_inner( - obj, - name=None, - filename=None, - is_direct_call=True, - reasons: Union[None, set[str]] = None, -): + obj: Any, + name: Optional[str] = None, + filename: Optional[str] = None, + is_direct_call: bool = True, + reasons: Optional[set[str]] = None, +) -> Optional[type[VariableTracker]]: # Step 1: lookup obj's tracing rule in `torch_name_rule_map`. # The rules defined in `torch_name_rule_map` mainly includes two parts: # - Manually defined rules for any functions. @@ -3981,7 +4006,7 @@ def _lookup_inner( filename = getfile(obj) skip_result = check_file(filename, is_direct_call) - if reasons is not None: + if reasons is not None and skip_result.reason is not None: reasons.add(skip_result.reason) if skip_result.skipped: return SkipFunctionVariable @@ -3989,7 +4014,7 @@ def _lookup_inner( return UserFunctionVariable -def clear_lru_cache(): +def clear_lru_cache() -> None: torch._dynamo.trace_rules.get_torch_obj_rule_map.cache_clear() torch._dynamo.trace_rules.get_tensor_method.cache_clear() torch._dynamo.trace_rules.get_legacy_mod_inlinelist.cache_clear() diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 35f0522453a89..f850e3ecb7c31 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -2236,7 +2236,7 @@ def torchscript(model, example_inputs, verbose=False): return None -def getfile(obj): +def getfile(obj: Any) -> Optional[str]: try: return inspect.getfile(obj) except (TypeError, OSError): From 97d7dc197f9eec99a3e2a163a4fa78f97e6b75a8 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 21 Jul 2025 22:13:24 +0000 Subject: [PATCH 345/457] Revert "[AOTI] Convert C-struct zip handling to RAII container (#158687)" This reverts commit 8ed5e1844c77d952bcea89ca7d0225d876fec4e8. Reverted https://github.com/pytorch/pytorch/pull/158687 on behalf of https://github.com/ZainRizvi due to Sorry but I had to revert this PR in order to revert https://github.com/pytorch/pytorch/pull/158671 ([comment](https://github.com/pytorch/pytorch/pull/158687#issuecomment-3099515618)) --- test/cpp/aoti_inference/test.cpp | 2 - .../aoti_package/model_package_loader.cpp | 112 +++++++----------- 2 files changed, 45 insertions(+), 69 deletions(-) diff --git a/test/cpp/aoti_inference/test.cpp b/test/cpp/aoti_inference/test.cpp index bff3827f8e8ac..59d575b2cc2bb 100644 --- a/test/cpp/aoti_inference/test.cpp +++ b/test/cpp/aoti_inference/test.cpp @@ -144,8 +144,6 @@ void test_aoti_package_loader_multi_gpu( const std::string& device, bool use_runtime_constant_folding) { torch::NoGradGuard no_grad; - // Ensure that this test will reset the default CUDA device on exit. - torch::DeviceGuard device_guard(c10::Device("cuda")); std::string data_path = (std::filesystem::path(STRINGIZE(CMAKE_CURRENT_BINARY_DIR)) / "data.pt") diff --git a/torch/csrc/inductor/aoti_package/model_package_loader.cpp b/torch/csrc/inductor/aoti_package/model_package_loader.cpp index bc7ee87e10233..66568025718af 100644 --- a/torch/csrc/inductor/aoti_package/model_package_loader.cpp +++ b/torch/csrc/inductor/aoti_package/model_package_loader.cpp @@ -405,69 +405,6 @@ void AOTIModelPackageLoader::load_metadata(const std::string& cpp_filename) { } } -class RAIIMinizArchive { - public: - RAIIMinizArchive(const std::string& zip_path) { - mz_zip_zero_struct(&_zip_archive); - if (!mz_zip_reader_init_file(&_zip_archive, zip_path.c_str(), 0)) { - throw std::runtime_error(fmt::format( - "Failed to initialize zip archive: {}", - mz_zip_get_error_string(mz_zip_get_last_error(&_zip_archive)))); - } - } - RAIIMinizArchive(const RAIIMinizArchive&) = delete; - RAIIMinizArchive& operator=(const RAIIMinizArchive&) = delete; - RAIIMinizArchive(RAIIMinizArchive&&) noexcept = delete; - RAIIMinizArchive& operator=(RAIIMinizArchive&&) noexcept = delete; - ~RAIIMinizArchive() { - // Unconditionally close the file. We can't handle any errors here without - // terminating the program. - mz_zip_reader_end(&_zip_archive); - } - - std::vector get_filenames() { - const unsigned num_zip_files{mz_zip_reader_get_num_files(&_zip_archive)}; - std::vector zip_filenames{}; - zip_filenames.reserve(num_zip_files); - - for (unsigned i{0}; i < num_zip_files; ++i) { - // filename_buf_size == 0 returns the filename length, including null - // terminator - const auto zip_filename_len{ - mz_zip_reader_get_filename(&_zip_archive, i, nullptr, 0)}; - if (!zip_filename_len) { - throw std::runtime_error( - fmt::format("Failed to read zip filename length at index {}", i)); - } - // std::string implicitly appends a character for the null terminator - std::string zip_filename(zip_filename_len - 1, '\0'); - if (!mz_zip_reader_get_filename( - &_zip_archive, i, zip_filename.data(), zip_filename_len)) { - throw std::runtime_error( - fmt::format("Failed to read zip filename at index {}", i)); - } - zip_filenames.emplace_back(zip_filename); - } - - return zip_filenames; - } - - void extract_file( - const std::string& zip_filename, - const std::string& dest_filename) { - if (!mz_zip_reader_extract_file_to_file( - &_zip_archive, zip_filename.c_str(), dest_filename.c_str(), 0)) { - throw std::runtime_error(fmt::format( - "Failed to extract zip file {} to destination file {}", - zip_filename, - dest_filename)); - } - } - - private: - mz_zip_archive _zip_archive{}; -}; - AOTIModelPackageLoader::AOTIModelPackageLoader( const std::string& model_package_path, const std::string& model_name, @@ -487,8 +424,34 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( } // Extract all files within the zipfile to a temporary directory - RAIIMinizArchive zip_archive{model_package_path}; - auto found_filenames{zip_archive.get_filenames()}; + mz_zip_archive zip_archive; + memset(&zip_archive, 0, sizeof(zip_archive)); + + if (!mz_zip_reader_init_file(&zip_archive, model_package_path.c_str(), 0)) { + throw std::runtime_error( + std::string("Failed to initialize zip archive: ") + + mz_zip_get_error_string(mz_zip_get_last_error(&zip_archive))); + } + + std::vector found_filenames; + for (uint32_t i = 0; i < zip_archive.m_total_files; i++) { + uint32_t zip_filename_len = + mz_zip_reader_get_filename(&zip_archive, i, nullptr, 0); + if (zip_filename_len == 0) { + throw std::runtime_error("Failed to read filename"); + } + // zip_filename_len returned by mz_zip_reader_get_filename includes the null + // terminator, so we need to subtract 1 here. + std::string zip_filename_str(zip_filename_len - 1, '\0'); + // zip_filename_str can't be normalize_path_separator, because it should be + // as index for mz_zip_reader_extract_file_to_file. + if (!mz_zip_reader_get_filename( + &zip_archive, i, zip_filename_str.data(), zip_filename_len)) { + throw std::runtime_error("Failed to read filename"); + } + found_filenames.push_back(zip_filename_str); + } + if (found_filenames.empty()) { throw std::runtime_error("No files found in zip archive."); } @@ -523,7 +486,7 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( // zip_filename_str can't be normalize_path_separator, because it should be // as index for mz_zip_reader_extract_file_to_file. - for (const auto& zip_filename_str : found_filenames) { + for (auto zip_filename_str : found_filenames) { auto cur_filename = normalize_path_separator(zip_filename_str); // Only compile files in the specified model directory if (c10::starts_with(cur_filename, model_directory) || @@ -566,7 +529,14 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( } // Extracts file to the temp directory - zip_archive.extract_file(zip_filename_str, output_file_path); + mz_bool b_extract = mz_zip_reader_extract_file_to_file( + &zip_archive, zip_filename_str.c_str(), output_file_path.c_str(), 0); + if (b_extract == MZ_FALSE) { + throw std::runtime_error(fmt::format( + "Failed to extract file {} to {}", + zip_filename_str, + output_file_path)); + } // Save the file for bookkeeping size_t extension_idx = output_file_path.find_last_of('.'); @@ -583,6 +553,14 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( } } + // Close the zip archive as we have extracted all files to the temp + // directory + if (!mz_zip_reader_end(&zip_archive)) { + throw std::runtime_error( + std::string("Failed to close zip archive: {}") + + mz_zip_get_error_string(mz_zip_get_last_error(&zip_archive))); + } + if (cpp_filename.empty() && so_filename.empty()) { std::string found_filenames_str; for (const std::string& filename : found_filenames) { From e8af168ee09243dd2179ae1cc5c9e8330e2f5614 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 21 Jul 2025 22:16:03 +0000 Subject: [PATCH 346/457] Revert "[AOTI] normalize path and process model files. (#158705)" This reverts commit ff0da08f4bc5ee135b495926cd58a36a1c0e1a5b. Reverted https://github.com/pytorch/pytorch/pull/158705 on behalf of https://github.com/ZainRizvi due to Sorry but I had to revert this PR in order to revert https://github.com/pytorch/pytorch/pull/158671 ([comment](https://github.com/pytorch/pytorch/pull/158705#issuecomment-3099532516)) --- .../aoti_package/model_package_loader.cpp | 40 ++++++++----------- 1 file changed, 17 insertions(+), 23 deletions(-) diff --git a/torch/csrc/inductor/aoti_package/model_package_loader.cpp b/torch/csrc/inductor/aoti_package/model_package_loader.cpp index 66568025718af..8e3a2d95fb9ec 100644 --- a/torch/csrc/inductor/aoti_package/model_package_loader.cpp +++ b/torch/csrc/inductor/aoti_package/model_package_loader.cpp @@ -478,31 +478,27 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( std::string so_filename; std::string cpp_filename; std::vector obj_filenames; - std::string model_directory = normalize_path_separator( - file_prefix + "data" + k_separator + "aotinductor" + k_separator + - model_name); - std::string const_directory = normalize_path_separator( - file_prefix + "data" + k_separator + "constants"); - - // zip_filename_str can't be normalize_path_separator, because it should be - // as index for mz_zip_reader_extract_file_to_file. - for (auto zip_filename_str : found_filenames) { - auto cur_filename = normalize_path_separator(zip_filename_str); + std::string model_directory = file_prefix + "data" + k_separator + + "aotinductor" + k_separator + model_name; + std::string const_directory = + file_prefix + "data" + k_separator + "constants"; + + for (const std::string& filename_str : found_filenames) { // Only compile files in the specified model directory - if (c10::starts_with(cur_filename, model_directory) || - c10::starts_with(cur_filename, const_directory)) { + if (c10::starts_with(filename_str, model_directory) || + c10::starts_with(filename_str, const_directory)) { std::string output_path_str = temp_dir_; - if (c10::starts_with(cur_filename, model_directory)) { + if (c10::starts_with(filename_str, model_directory)) { output_path_str += k_separator; - output_path_str += cur_filename; - } else { // startsWith(zip_filename_str, const_directory) + output_path_str += filename_str; + } else { // startsWith(filename_str, const_directory) // Extract constants to the same directory as the rest of the files // to be consistent with internal implementation - size_t lastSlash = cur_filename.find_last_of(k_separator); - std::string filename = cur_filename; + size_t lastSlash = filename_str.find_last_of(k_separator); + std::string filename = filename_str; if (lastSlash != std::string::npos) { - filename = cur_filename.substr(lastSlash + 1); + filename = filename_str.substr(lastSlash + 1); } output_path_str.append(k_separator) .append(model_directory) @@ -511,7 +507,7 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( } std::string output_file_path = normalize_path_separator(output_path_str); - LOG(INFO) << "Extract file: " << zip_filename_str << " to " + LOG(INFO) << "Extract file: " << filename_str << " to " << output_file_path; // Create the parent directory if it doesn't exist @@ -530,12 +526,10 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( // Extracts file to the temp directory mz_bool b_extract = mz_zip_reader_extract_file_to_file( - &zip_archive, zip_filename_str.c_str(), output_file_path.c_str(), 0); + &zip_archive, filename_str.c_str(), output_file_path.c_str(), 0); if (b_extract == MZ_FALSE) { throw std::runtime_error(fmt::format( - "Failed to extract file {} to {}", - zip_filename_str, - output_file_path)); + "Failed to extract file {} to {}", filename_str, output_file_path)); } // Save the file for bookkeeping From 5a56e6a72b8d805a665e67a23b0ca2decf4d8cc2 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 21 Jul 2025 22:18:19 +0000 Subject: [PATCH 347/457] Revert "[AOTI] fix extract file failed on Windows. (#158702)" This reverts commit 7cc1a9546c135f8e7635e0d38aa2bba797f8907d. Reverted https://github.com/pytorch/pytorch/pull/158702 on behalf of https://github.com/ZainRizvi due to Sorry but I had to revert this PR in order to revert https://github.com/pytorch/pytorch/pull/158671 ([comment](https://github.com/pytorch/pytorch/pull/158702#issuecomment-3099556215)) --- .../aoti_package/model_package_loader.cpp | 41 +++++++++---------- 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/torch/csrc/inductor/aoti_package/model_package_loader.cpp b/torch/csrc/inductor/aoti_package/model_package_loader.cpp index 8e3a2d95fb9ec..127969c0318ff 100644 --- a/torch/csrc/inductor/aoti_package/model_package_loader.cpp +++ b/torch/csrc/inductor/aoti_package/model_package_loader.cpp @@ -435,21 +435,19 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( std::vector found_filenames; for (uint32_t i = 0; i < zip_archive.m_total_files; i++) { - uint32_t zip_filename_len = + uint32_t filename_len = mz_zip_reader_get_filename(&zip_archive, i, nullptr, 0); - if (zip_filename_len == 0) { + if (filename_len == 0) { throw std::runtime_error("Failed to read filename"); } - // zip_filename_len returned by mz_zip_reader_get_filename includes the null - // terminator, so we need to subtract 1 here. - std::string zip_filename_str(zip_filename_len - 1, '\0'); - // zip_filename_str can't be normalize_path_separator, because it should be - // as index for mz_zip_reader_extract_file_to_file. + // filename_len returned by mz_zip_reader_get_filename includes the null + // terminator, so we need to subtract 1 here + std::string filename_str(filename_len - 1, '\0'); if (!mz_zip_reader_get_filename( - &zip_archive, i, zip_filename_str.data(), zip_filename_len)) { + &zip_archive, i, filename_str.data(), filename_len)) { throw std::runtime_error("Failed to read filename"); } - found_filenames.push_back(zip_filename_str); + found_filenames.push_back(normalize_path_separator(filename_str)); } if (found_filenames.empty()) { @@ -506,17 +504,18 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( .append(filename); } - std::string output_file_path = normalize_path_separator(output_path_str); + output_path_str = normalize_path_separator(output_path_str); + LOG(INFO) << "Extract file: " << filename_str << " to " - << output_file_path; + << output_path_str; // Create the parent directory if it doesn't exist - size_t parent_path_idx = output_file_path.find_last_of(k_separator); + size_t parent_path_idx = output_path_str.find_last_of(k_separator); if (parent_path_idx == std::string::npos) { throw std::runtime_error( - "Failed to find parent path in " + output_file_path); + "Failed to find parent path in " + output_path_str); } - std::string parent_path = output_file_path.substr(0, parent_path_idx); + std::string parent_path = output_path_str.substr(0, parent_path_idx); if (!recursive_mkdir(parent_path)) { throw std::runtime_error(fmt::format( "Failed to create directory {}: {}", @@ -526,22 +525,22 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( // Extracts file to the temp directory mz_bool b_extract = mz_zip_reader_extract_file_to_file( - &zip_archive, filename_str.c_str(), output_file_path.c_str(), 0); + &zip_archive, filename_str.c_str(), output_path_str.c_str(), 0); if (b_extract == MZ_FALSE) { throw std::runtime_error(fmt::format( - "Failed to extract file {} to {}", filename_str, output_file_path)); + "Failed to extract file {} to {}", filename_str, output_path_str)); } // Save the file for bookkeeping - size_t extension_idx = output_file_path.find_last_of('.'); + size_t extension_idx = output_path_str.find_last_of('.'); if (extension_idx != std::string::npos) { - std::string filename_extension = output_file_path.substr(extension_idx); + std::string filename_extension = output_path_str.substr(extension_idx); if (filename_extension == ".cpp") { - cpp_filename = output_file_path; + cpp_filename = output_path_str; } else if (filename_extension == object_file_ext()) { - obj_filenames.push_back(output_file_path); + obj_filenames.push_back(output_path_str); } else if (filename_extension == extension_file_ext()) { - so_filename = output_file_path; + so_filename = output_path_str; } } } From 734826d88e54642f574ea6c0f5e66cf6da6a8157 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 21 Jul 2025 22:20:46 +0000 Subject: [PATCH 348/457] Revert "[AOTI] windows package load dev (#158671)" This reverts commit d42c40976727fed4c9908d4194f26917d0a3da66. Reverted https://github.com/pytorch/pytorch/pull/158671 on behalf of https://github.com/ZainRizvi due to Sorry but this is breaking internally. @angelayi can you please help them validate the fixes internally? You can follow the instructions here: https://fburl.com/fixing-ghfirst-reverts ([comment](https://github.com/pytorch/pytorch/pull/158671#issuecomment-3099570374)) --- .../inductor/aoti_package/model_package_loader.cpp | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/torch/csrc/inductor/aoti_package/model_package_loader.cpp b/torch/csrc/inductor/aoti_package/model_package_loader.cpp index 127969c0318ff..629dc8cb2ae80 100644 --- a/torch/csrc/inductor/aoti_package/model_package_loader.cpp +++ b/torch/csrc/inductor/aoti_package/model_package_loader.cpp @@ -471,7 +471,7 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( << found_filenames[1]; } - temp_dir_ = normalize_path_separator(create_temp_dir()); + temp_dir_ = create_temp_dir(); std::string so_filename; std::string cpp_filename; @@ -504,8 +504,6 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( .append(filename); } - output_path_str = normalize_path_separator(output_path_str); - LOG(INFO) << "Extract file: " << filename_str << " to " << output_path_str; @@ -524,12 +522,8 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( } // Extracts file to the temp directory - mz_bool b_extract = mz_zip_reader_extract_file_to_file( + mz_zip_reader_extract_file_to_file( &zip_archive, filename_str.c_str(), output_path_str.c_str(), 0); - if (b_extract == MZ_FALSE) { - throw std::runtime_error(fmt::format( - "Failed to extract file {} to {}", filename_str, output_path_str)); - } // Save the file for bookkeeping size_t extension_idx = output_path_str.find_last_of('.'); From dd0adc9386226fdbfb1ddaf0c1e74de54dfbc83e Mon Sep 17 00:00:00 2001 From: codingwithsurya Date: Sun, 20 Jul 2025 16:39:36 -0700 Subject: [PATCH 349/457] [SymmMem] Add NVSHMEM broadcast support into Triton (#158514) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds broadcast collective operation for distributing data from root PE to all other PEs in NVSHMEM Triton kernels. Tests: `python test/distributed/test_nvshmem_triton.py -k test_triton_broadcast`
Quick debug print for sanity check ```markdown ============================================================ [Rank 0] Starting broadcast test with world_size=2 ============================================================ [Rank 0] Configuration: - nelems: 4 - dtype: torch.int64, element_size: 8 bytes - nelems_bytes: 32 ============================================================ [Rank 1] Starting broadcast test with world_size=2 ============================================================ [Rank 1] Configuration: - nelems: 4 - dtype: torch.int64, element_size: 8 bytes - nelems_bytes: 32 [Rank 1] Non-root source data: [-1, -1, -1, -1] [Rank 0] Root source data: [100, 101, 102, 103] [Rank 1] Initial destination: [-999, -999, -999, -999] [Rank 0] Initial destination: [-999, -999, -999, -999] [Rank 0] Executing broadcast operation... [Rank 1] Executing broadcast operation... [Rank 0] Broadcast operation completed /data/users/suryasub/pytorch/torch/distributed/distributed_c10d.py:4809: UserWarning: No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user. warnings.warn( # warn only once [Rank 1] Broadcast operation completed /data/users/suryasub/pytorch/torch/distributed/distributed_c10d.py:4809: UserWarning: No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user. warnings.warn( # warn only once [Rank 1] Results after broadcast: [Rank 0] Results after broadcast: [Rank 1] Destination buffer: [100, 101, 102, 103] [Rank 1] Expected: [100, 101, 102, 103] [Rank 0] Destination buffer: [100, 101, 102, 103] [Rank 0] Expected: [100, 101, 102, 103] [Rank 1] Match: ✓ [Rank 0] Match: ✓ [Rank 1] ============================================================ [Rank 1] Broadcast test PASSED ✓ [Rank 1] Summary: Root PE 0 broadcasted [100, 101, 102, 103] to all PEs [Rank 1] ============================================================ [Rank 0] ============================================================ [Rank 0] Broadcast test PASSED ✓ [Rank 0] Summary: Root PE 0 broadcasted [100, 101, 102, 103] to all PEs [Rank 0] ============================================================ ```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158514 Approved by: https://github.com/fduwjj, https://github.com/mandroid6 ghstack dependencies: #158511, #158512, #158513 --- test/distributed/test_nvshmem_triton.py | 58 +++++++++++++++++++ .../_symmetric_memory/_nvshmem_triton.py | 20 +++++++ 2 files changed, 78 insertions(+) diff --git a/test/distributed/test_nvshmem_triton.py b/test/distributed/test_nvshmem_triton.py index 992c0895714ba..c4565a96496ce 100644 --- a/test/distributed/test_nvshmem_triton.py +++ b/test/distributed/test_nvshmem_triton.py @@ -220,6 +220,17 @@ def alltoall_kernel( nvshmem.alltoall(team_handle, dest_ptr, src_ptr, nelems) +@triton.jit +def broadcast_kernel( + team_handle, + dest_ptr, + src_ptr, + nelems, + pe_root, +): + nvshmem.broadcast(team_handle, dest_ptr, src_ptr, nelems, pe_root) + + @instantiate_parametrized_tests @requires_nvshmem() class NVSHMEMTritonTest(MultiProcContinousTest): @@ -889,6 +900,53 @@ def test_triton_alltoall(self) -> None: actual = dst[i * nelems_per_pe : (i + 1) * nelems_per_pe] torch.testing.assert_close(actual, torch.full_like(actual, expected)) + @skipIfRocm + @requires_triton() + def test_triton_broadcast(self) -> None: + torch.manual_seed(42 + self.rank) + self._init_device() + nvshmem_lib = nvshmem.enable_triton() + group_name = dist.group.WORLD.group_name + symm_mem.enable_symm_mem_for_group(group_name) + rank = self.rank + # Configuration + nelems = 4 # number of elements + dtype = torch.int64 + # Source buffer - only root will have meaningful data + pe_root = 0 # PE 0 will be the root + src = symm_mem.empty(nelems, dtype=dtype, device=self.device) + if rank == pe_root: + # Root fills with specific pattern + for i in range(nelems): + src[i] = 100 + i + else: + # Non-root PEs have dummy data + src.fill_(-1) + # Destination buffer + dst = symm_mem.empty(nelems, dtype=dtype, device=self.device).fill_(-999) + src_hdl = symm_mem.rendezvous(src, group=group_name) + dst_hdl = symm_mem.rendezvous(dst, group=group_name) + # Synchronize before broadcast + dist.barrier() + # Execute broadcast + team_handle = 0 # NVSHMEM_TEAM_WORLD + broadcast_kernel[(1,)]( + team_handle, + dst_hdl.buffer_ptrs[rank], + src_hdl.buffer_ptrs[rank], + nelems, + pe_root, + extern_libs=nvshmem_lib, + launch_cooperative_grid=True, + ) + # Synchronize after broadcast + dist.barrier() + # Verify results - all ranks should have the root's data + expected = [100 + i for i in range(nelems)] + torch.testing.assert_close( + dst, torch.tensor(expected, device=self.device, dtype=dtype) + ) + if __name__ == "__main__": run_tests() diff --git a/torch/distributed/_symmetric_memory/_nvshmem_triton.py b/torch/distributed/_symmetric_memory/_nvshmem_triton.py index 08124483a9fe6..dda1885a8e167 100644 --- a/torch/distributed/_symmetric_memory/_nvshmem_triton.py +++ b/torch/distributed/_symmetric_memory/_nvshmem_triton.py @@ -260,3 +260,23 @@ def alltoall(team, dest, source, nelems, _builder=None): # type: ignore[no-unty is_pure=False, _builder=_builder, ) + + @core.extern + def broadcast(team, dest, source, nelems, pe_root, _builder=None): # type: ignore[no-untyped-def] + """Broadcasts data from a root PE to all other PEs in a team""" + return core.extern_elementwise( + "", + "", + [team, dest, source, nelems, pe_root], + { + ( + core.dtype("int64"), # team handle + core.dtype("int64"), # dest ptr + core.dtype("int64"), # source ptr + core.dtype("int64"), # nelems + core.dtype("int64"), # pe_root + ): ("nvshmem_longlong_broadcast", core.dtype("int32")) + }, + is_pure=False, + _builder=_builder, + ) From 4366610f5a18ffe72e947fab9adb5ee072d74b91 Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Mon, 21 Jul 2025 22:23:40 +0000 Subject: [PATCH 350/457] [c10d] block_current_stream: correctness fixes (#158757) This fixes a number of issues that were present in https://github.com/pytorch/pytorch/pull/156883 as pointed out by @ngimel Test plan: Expanded tests to cover use after free behavior + non-default stream ``` pytest test/distributed/test_c10d_pypg.py -v -k block_current_stream ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/158757 Approved by: https://github.com/ngimel --- test/distributed/test_c10d_pypg.py | 96 +++++++++++++------ .../csrc/distributed/c10d/cuda/StreamBlock.cu | 56 +++++++++-- .../distributed/c10d/cuda/StreamBlock.cuh | 1 + 3 files changed, 117 insertions(+), 36 deletions(-) diff --git a/test/distributed/test_c10d_pypg.py b/test/distributed/test_c10d_pypg.py index 6ccb81b116ca1..65faf2075daa6 100644 --- a/test/distributed/test_c10d_pypg.py +++ b/test/distributed/test_c10d_pypg.py @@ -181,6 +181,19 @@ def use_wrapper(self): return True +class BlockWork(dist._Work): + """ + Dummy work that is used to test blocking the current stream. + """ + + def __init__(self): + super().__init__() + self.future_ = torch.futures.Future() + + def get_future(self): + return self.future_ + + class TestPyProcessGroup(TestCase): def test_attr_overrides(self): pg = DummyAttrProcessGroup(0, 1) @@ -202,34 +215,61 @@ def test_abort_shutdown(self) -> None: @unittest.skipIf(not TEST_CUDA, "no cuda/xpu") def test_block_current_stream(self) -> None: - class BlockWork(dist._Work): - def __init__(self): - super().__init__() - self.future_ = torch.futures.Future() - - def get_future(self): - return self.future_ - - # nothing in queue so instantly resolves - event1 = torch.cuda.Event() - event1.record() - time.sleep(0.1) - self.assertTrue(event1.query()) - - work = BlockWork() - work.block_current_stream() - - # stream is blocked so doesn't resolve - event = torch.cuda.Event() - event.record() - time.sleep(0.1) - self.assertFalse(event.query()) - - # resolve the work - work.get_future().set_result(None) - - torch.cuda.current_stream().synchronize() - self.assertTrue(event.query()) + torch.cuda.synchronize() + + stream = torch.cuda.Stream() + with stream: + # nothing in queue so instantly resolves + event1 = torch.cuda.Event() + event1.record() + time.sleep(0.1) + self.assertTrue(event1.query()) + + work = BlockWork() + work.block_current_stream() + + # stream is blocked so doesn't resolve + event = torch.cuda.Event() + event.record() + time.sleep(0.1) + self.assertFalse(event.query()) + + # resolve the work + work.get_future().set_result(None) + + stream.synchronize() + self.assertTrue(event.query()) + + @unittest.skipIf(not TEST_CUDA, "no cuda/xpu") + def test_block_current_stream_use_after_free(self) -> None: + """ + This tests that the CPU control tensor is not freed before the CUDA kernel executes. + """ + torch.cuda.synchronize() + stream = torch.cuda.Stream() + with stream: + a = BlockWork() + a.block_current_stream() + + b = BlockWork() + b.block_current_stream() + + # unblock b first though a is still blocking + b.get_future().set_result(None) + # delete b + del b + + # a is still blocking so this doesn't resolve + event = torch.cuda.Event() + event.record() + time.sleep(0.1) + self.assertFalse(event.query()) + + # unblock a + a.get_future().set_result(None) + + stream.synchronize() + self.assertTrue(event.query()) if __name__ == "__main__": diff --git a/torch/csrc/distributed/c10d/cuda/StreamBlock.cu b/torch/csrc/distributed/c10d/cuda/StreamBlock.cu index 58533ece6af8b..db4a118a25e59 100644 --- a/torch/csrc/distributed/c10d/cuda/StreamBlock.cu +++ b/torch/csrc/distributed/c10d/cuda/StreamBlock.cu @@ -1,5 +1,7 @@ +#include #include #include +#include #include #include #include @@ -8,7 +10,7 @@ #include #include #else -#include +#include #endif namespace c10d::cuda::detail { @@ -21,19 +23,49 @@ __device__ void nanosleep(int64_t ns) { #endif } +__device__ int32_t load_cpu_int32(int32_t* ptr) { +#if defined(USE_ROCM) + // WARNING: this may not be safe + return atomicAdd_system(ptr, 0); +#else + int32_t current_value = 0; + + // Bypass L1 cache to see updates at L2 and above. + // This could use .cv to bypass L2 cache but that's significantly more + // expensive and the CPU write will clear the L2 cache. + // https://docs.nvidia.com/cuda/parallel-thread-execution/#cache-operators + asm volatile("ld.cg.s32 %0, [%1];" + : "=r"(current_value) // Output operand + : "l"(ptr) // Input operand + ); + return current_value; +#endif +} + +__device__ void store_cpu_int32(int32_t* ptr, int32_t val) { +#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)) + // WARNING: this value may be cached without .release + *ptr = val; +#else + // Releases memory so it can be seen by other threads on the system. + // https://docs.nvidia.com/cuda/parallel-thread-execution/#release-acquire-patterns + asm volatile("st.release.sys.s32 [%0], %1;" ::"l"(ptr), "r"(val)); +#endif +} + __global__ // set launch bounds to limit to 1 thread per block, 1 block per MP __launch_bounds__(1, 1) void kernel_barrier(int32_t* value, size_t timeout_ms) { - value[1] = StreamBlockStatus::RUNNING; + store_cpu_int32(&value[1], StreamBlockStatus::RUNNING); size_t start = c10d::symmetric_memory::global_timer_ns(); size_t timeout_ns = timeout_ms * 1e6; // Convert milliseconds to nanoseconds while (true) { // Atomically read the value - int current_value = atomicAdd(&value[0], 0); + int32_t current_value = load_cpu_int32(value); // Check if the value is equal to the expected value if (current_value == 1) { - value[1] = StreamBlockStatus::ABORTED; + store_cpu_int32(&value[1], StreamBlockStatus::ABORTED); return; } @@ -41,7 +73,7 @@ __launch_bounds__(1, 1) void kernel_barrier(int32_t* value, size_t timeout_ms) { // Check if timeout has been reached size_t now = c10d::symmetric_memory::global_timer_ns(); if ((now - start) > timeout_ns) { - value[1] = StreamBlockStatus::TIMED_OUT; + store_cpu_int32(&value[1], StreamBlockStatus::TIMED_OUT); return; } } @@ -55,13 +87,21 @@ StreamBlock::StreamBlock(std::chrono::milliseconds timeout) : comm_{ // We need to pin the memory since we access the CPU memory directly form // the GPU. - at::empty({2}, at::TensorOptions().dtype(at::kInt)).pin_memory() + at::zeros({2}, at::TensorOptions().dtype(at::kInt)).pin_memory() }, timeout_{timeout} { + auto stream = at::cuda::getCurrentCUDAStream(); + auto* ptr = comm_.mutable_data_ptr(); + auto* ctx = comm_.storage().data_ptr().get_context(); + // grid size 1, block size 1, 0 bytes of shared memory - kernel_barrier<<<1, 1, 0>>>( - comm_.mutable_data_ptr(), timeout_.count()); + kernel_barrier<<<1, 1, 0, stream>>>(ptr, timeout_.count()); C10_CUDA_KERNEL_LAUNCH_CHECK(); + + // This object may be deallocated before the CUDA kernel completes. We need to + // register the CPU tensor so it's only freed after the kernel completes + // execution. + at::getHostAllocator(at::kCUDA)->record_event(ptr, ctx, stream.unwrap()); } C10_REGISTER_CLASS(StreamBlockRegistry, CUDA, StreamBlock) diff --git a/torch/csrc/distributed/c10d/cuda/StreamBlock.cuh b/torch/csrc/distributed/c10d/cuda/StreamBlock.cuh index f94f272d7eef6..9ca52b4c5e885 100644 --- a/torch/csrc/distributed/c10d/cuda/StreamBlock.cuh +++ b/torch/csrc/distributed/c10d/cuda/StreamBlock.cuh @@ -13,6 +13,7 @@ class StreamBlock : public ::c10d::cuda::StreamBlock { StreamBlock(std::chrono::milliseconds timeout); void abort() override { + std::atomic_thread_fence(std::memory_order_seq_cst); comm_[0] = 1; } From cab28330f8c49cdb66d6a299755dc09c87c14a9d Mon Sep 17 00:00:00 2001 From: Huy Do Date: Mon, 21 Jul 2025 10:21:48 -0700 Subject: [PATCH 351/457] Setup TorchBench in Docker (#158613) This reduces the time spending to setup TorchBench in A100/H100 by another half an hour ### Testing * H100 benchmark https://github.com/pytorch/pytorch/actions/runs/16396172453. Once this done, I will review the results on [HUD](https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Fri%2C%2011%20Jul%202025%2023%3A01%3A24%20GMT&stopTime=Fri%2C%2018%20Jul%202025%2023%3A01%3A24%20GMT&granularity=hour&mode=inference&dtype=bfloat16&deviceName=cuda%20(h100)&lBranch=gh/huydhn/6/head&lCommit=14a38c719b29a19f518239b5edb084838ac5d2fb&rBranch=main&rCommit=0a99b026d6bd0f67dc2c0a20fe3228ddc4144854) to confirm that all models are there * A100 benchmark https://github.com/pytorch/pytorch/actions/runs/16396173932 Signed-off-by: Huy Do Pull Request resolved: https://github.com/pytorch/pytorch/pull/158613 Approved by: https://github.com/janeyx99 --- .ci/docker/build.sh | 2 +- .../docker}/ci_commit_pins/torchbench.txt | 0 .../common/install_inductor_benchmark_deps.sh | 28 +++++++++++++++++-- .ci/docker/requirements-ci.txt | 1 - .ci/docker/ubuntu-rocm/Dockerfile | 3 +- .ci/docker/ubuntu/Dockerfile | 3 +- .ci/pytorch/common_utils.sh | 24 ---------------- .ci/pytorch/test.sh | 22 +++++---------- 8 files changed, 38 insertions(+), 45 deletions(-) rename {.github => .ci/docker}/ci_commit_pins/torchbench.txt (100%) diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh index d6cba6659db7a..d8de423682004 100755 --- a/.ci/docker/build.sh +++ b/.ci/docker/build.sh @@ -276,7 +276,7 @@ case "$tag" in NINJA_VERSION=1.9.0 TRITON=yes ;; - pytorch-linux-jammy-py3.9-gcc11-inductor-benchmarks) + pytorch-linux-jammy-py3.9-gcc11-inductor-benchmarks) ANACONDA_PYTHON_VERSION=3.9 GCC_VERSION=11 VISION=yes diff --git a/.github/ci_commit_pins/torchbench.txt b/.ci/docker/ci_commit_pins/torchbench.txt similarity index 100% rename from .github/ci_commit_pins/torchbench.txt rename to .ci/docker/ci_commit_pins/torchbench.txt diff --git a/.ci/docker/common/install_inductor_benchmark_deps.sh b/.ci/docker/common/install_inductor_benchmark_deps.sh index 7312dce170db2..2e0780f889e17 100644 --- a/.ci/docker/common/install_inductor_benchmark_deps.sh +++ b/.ci/docker/common/install_inductor_benchmark_deps.sh @@ -15,11 +15,35 @@ function install_timm() { commit=$(get_pinned_commit timm) pip_install "git+https://github.com/huggingface/pytorch-image-models@${commit}" - # Clean up - conda_run pip uninstall -y torch torchvision triton +} + +function install_torchbench() { + local commit + commit=$(get_pinned_commit torchbench) + git clone https://github.com/pytorch/benchmark torchbench + pushd torchbench + git checkout "$commit" + + python install.py --continue_on_fail + + # TODO (huydhn): transformers-4.44.2 added by https://github.com/pytorch/benchmark/pull/2488 + # is regressing speedup metric. This needs to be investigated further + pip install transformers==4.38.1 + + echo "Print all dependencies after TorchBench is installed" + python -mpip freeze + popd } # Pango is needed for weasyprint which is needed for doctr conda_install pango + +# Stable packages are ok here, just to satisfy TorchBench check +pip_install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128 + +install_torchbench install_huggingface install_timm + +# Clean up +conda_run pip uninstall -y torch torchvision torchaudio triton diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt index a7486c40b121d..926f8aef35f27 100644 --- a/.ci/docker/requirements-ci.txt +++ b/.ci/docker/requirements-ci.txt @@ -361,7 +361,6 @@ pwlf==2.2.1 #Pinned versions: 2.2.1 #test that import: test_sac_estimator.py - # To build PyTorch itself packaging>=24.2 pyyaml diff --git a/.ci/docker/ubuntu-rocm/Dockerfile b/.ci/docker/ubuntu-rocm/Dockerfile index 2528da07c69e3..8f2cc6eef9581 100644 --- a/.ci/docker/ubuntu-rocm/Dockerfile +++ b/.ci/docker/ubuntu-rocm/Dockerfile @@ -98,8 +98,9 @@ COPY ./common/install_inductor_benchmark_deps.sh install_inductor_benchmark_deps COPY ./common/common_utils.sh common_utils.sh COPY ci_commit_pins/huggingface.txt huggingface.txt COPY ci_commit_pins/timm.txt timm.txt +COPY ci_commit_pins/torchbench.txt torchbench.txt RUN if [ -n "${INDUCTOR_BENCHMARKS}" ]; then bash ./install_inductor_benchmark_deps.sh; fi -RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface.txt +RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface.txt torchbench.txt # (optional) Install non-default Ninja version ARG NINJA_VERSION diff --git a/.ci/docker/ubuntu/Dockerfile b/.ci/docker/ubuntu/Dockerfile index 27c466dd8d41d..077910cef9f35 100644 --- a/.ci/docker/ubuntu/Dockerfile +++ b/.ci/docker/ubuntu/Dockerfile @@ -98,8 +98,9 @@ COPY ./common/install_inductor_benchmark_deps.sh install_inductor_benchmark_deps COPY ./common/common_utils.sh common_utils.sh COPY ci_commit_pins/huggingface.txt huggingface.txt COPY ci_commit_pins/timm.txt timm.txt +COPY ci_commit_pins/torchbench.txt torchbench.txt RUN if [ -n "${INDUCTOR_BENCHMARKS}" ]; then bash ./install_inductor_benchmark_deps.sh; fi -RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface.txt +RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface.txt torchbench.txt ARG TRITON ARG TRITON_CPU diff --git a/.ci/pytorch/common_utils.sh b/.ci/pytorch/common_utils.sh index 9075fe5fb56f8..046f0e1597e65 100644 --- a/.ci/pytorch/common_utils.sh +++ b/.ci/pytorch/common_utils.sh @@ -258,30 +258,6 @@ function clone_pytorch_xla() { fi } -function checkout_install_torchbench() { - local commit - commit=$(get_pinned_commit torchbench) - git clone https://github.com/pytorch/benchmark torchbench - pushd torchbench - git checkout "$commit" - - if [ "$1" ]; then - python install.py --continue_on_fail models "$@" - else - # Occasionally the installation may fail on one model but it is ok to continue - # to install and test other models - python install.py --continue_on_fail - fi - - # TODO (huydhn): transformers-4.44.2 added by https://github.com/pytorch/benchmark/pull/2488 - # is regressing speedup metric. This needs to be investigated further - pip install transformers==4.38.1 - - echo "Print all dependencies after TorchBench is installed" - python -mpip freeze - popd -} - function install_torchao() { local commit commit=$(get_pinned_commit torchao) diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 2e7cc84138cee..e40d36abc2cc2 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -1668,13 +1668,11 @@ elif [[ "${TEST_CONFIG}" == *timm* ]]; then elif [[ "${TEST_CONFIG}" == cachebench ]]; then install_torchaudio install_torchvision - checkout_install_torchbench nanogpt BERT_pytorch resnet50 hf_T5 llama moco - PYTHONPATH=$(pwd)/torchbench test_cachebench + PYTHONPATH=/torchbench test_cachebench elif [[ "${TEST_CONFIG}" == verify_cachebench ]]; then install_torchaudio install_torchvision - checkout_install_torchbench nanogpt - PYTHONPATH=$(pwd)/torchbench test_verify_cachebench + PYTHONPATH=/torchbench test_verify_cachebench elif [[ "${TEST_CONFIG}" == *torchbench* ]]; then install_torchaudio install_torchvision @@ -1683,28 +1681,22 @@ elif [[ "${TEST_CONFIG}" == *torchbench* ]]; then # https://github.com/opencv/opencv-python/issues/885 pip_install opencv-python==4.8.0.74 if [[ "${TEST_CONFIG}" == *inductor_torchbench_smoketest_perf* ]]; then - checkout_install_torchbench hf_Bert hf_Albert timm_vision_transformer - PYTHONPATH=$(pwd)/torchbench test_inductor_torchbench_smoketest_perf + PYTHONPATH=/torchbench test_inductor_torchbench_smoketest_perf elif [[ "${TEST_CONFIG}" == *inductor_torchbench_cpu_smoketest_perf* ]]; then - checkout_install_torchbench timm_vision_transformer phlippe_densenet basic_gnn_edgecnn \ - llama_v2_7b_16h resnet50 timm_efficientnet mobilenet_v3_large timm_resnest \ - functorch_maml_omniglot yolov3 mobilenet_v2 resnext50_32x4d densenet121 mnasnet1_0 - PYTHONPATH=$(pwd)/torchbench test_inductor_torchbench_cpu_smoketest_perf + PYTHONPATH=/torchbench test_inductor_torchbench_cpu_smoketest_perf elif [[ "${TEST_CONFIG}" == *torchbench_gcp_smoketest* ]]; then - checkout_install_torchbench - TORCHBENCHPATH=$(pwd)/torchbench test_torchbench_gcp_smoketest + TORCHBENCHPATH=/torchbench test_torchbench_gcp_smoketest else - checkout_install_torchbench # Do this after checkout_install_torchbench to ensure we clobber any # nightlies that torchbench may pull in if [[ "${TEST_CONFIG}" != *cpu* ]]; then install_torchrec_and_fbgemm fi - PYTHONPATH=$(pwd)/torchbench test_dynamo_benchmark torchbench "$id" + PYTHONPATH=/torchbench test_dynamo_benchmark torchbench "$id" fi elif [[ "${TEST_CONFIG}" == *inductor_cpp_wrapper* ]]; then install_torchvision - PYTHONPATH=$(pwd)/torchbench test_inductor_cpp_wrapper_shard "$SHARD_NUMBER" + PYTHONPATH=/torchbench test_inductor_cpp_wrapper_shard "$SHARD_NUMBER" if [[ "$SHARD_NUMBER" -eq "1" ]]; then test_inductor_aoti fi From b3c868d603e8f7b6661c93cd3d50c9a7b213ad6c Mon Sep 17 00:00:00 2001 From: Yang Wang Date: Mon, 21 Jul 2025 22:41:02 +0000 Subject: [PATCH 352/457] [vllm]Add vllm.txt for pinned commit (#158754) It seems the nightly.yml won't auto-generate txt file when it does not existed, so added the file with latest merged commit from vllm: [vllm commit](https://github.com/vllm-project/vllm/commits/main) Error: https://github.com/pytorch/pytorch/actions/runs/16405915719/job/46351847504 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158754 Approved by: https://github.com/huydhn --- .github/ci_commit_pins/vllm.txt | 1 + .github/merge_rules.yaml | 2 +- .github/workflows/nightly.yml | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) create mode 100644 .github/ci_commit_pins/vllm.txt diff --git a/.github/ci_commit_pins/vllm.txt b/.github/ci_commit_pins/vllm.txt new file mode 100644 index 0000000000000..22adf465e471c --- /dev/null +++ b/.github/ci_commit_pins/vllm.txt @@ -0,0 +1 @@ +29d1ffc5b4c763ef76aff9e3f617fa60dd292418 diff --git a/.github/merge_rules.yaml b/.github/merge_rules.yaml index 00b7cb618401a..f87980ed8df33 100644 --- a/.github/merge_rules.yaml +++ b/.github/merge_rules.yaml @@ -76,8 +76,8 @@ - .github/ci_commit_pins/audio.txt - .github/ci_commit_pins/vision.txt - .github/ci_commit_pins/torchdynamo.txt + - .github/ci_commit_pins/vllm.txt - .ci/docker/ci_commit_pins/triton.txt - - .ci/docker/ci_commit_pins/vllm.txt approved_by: - pytorchbot mandatory_checks_name: diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index 238b897d3da63..7bb1ff9296ab8 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -86,7 +86,7 @@ jobs: - repo-name: vllm repo-owner: vllm-project branch: main - pin-folder: .ci/docker/ci_commit_pins + pin-folder: .github/ci_commit_pins # Allow this to be triggered on either a schedule or on workflow_dispatch to allow for easier testing if: github.repository_owner == 'pytorch' && (github.event_name == 'schedule' || github.event_name == 'workflow_dispatch') steps: From feaa02f9addfc6764843c8b48f8c403de593737c Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 21 Jul 2025 22:46:53 +0000 Subject: [PATCH 353/457] Revert "[build] pin `setuptools>=77` to enable PEP 639 (#158104)" This reverts commit a78fb63dbdf98a1db219095293de1a11005e0390. Reverted https://github.com/pytorch/pytorch/pull/158104 on behalf of https://github.com/malfet due to It still breaks inductor-perf-nightly, see https://github.com/pytorch/pytorch/actions/runs/16425364208/job/46417088208, I'm going to dismiss all previous reviews ([comment](https://github.com/pytorch/pytorch/pull/158104#issuecomment-3099706457)) --- .ci/aarch64_linux/aarch64_ci_setup.sh | 2 +- .ci/docker/manywheel/Dockerfile_2_28 | 2 +- .ci/docker/manywheel/Dockerfile_s390x | 5 +++-- .ci/docker/requirements-ci.txt | 11 +++++------ .ci/pytorch/build.sh | 3 --- .ci/pytorch/test.sh | 2 +- .ci/pytorch/win-test-helpers/build_pytorch.bat | 5 ----- .ci/pytorch/win-test.sh | 2 +- .ci/pytorch/windows/internal/install_python.bat | 2 +- .ci/pytorch/windows/setup_build.bat | 5 +---- .ci/wheel/build_wheel.sh | 14 +++++++------- .github/requirements-gha-cache.txt | 2 +- .github/requirements/pip-requirements-macOS.txt | 8 ++++---- .github/scripts/lintrunner.sh | 2 +- .github/scripts/windows/build_triton.bat | 2 +- .github/workflows/_mac-test.yml | 5 ----- pyproject.toml | 11 ++++++++--- requirements-build.txt | 4 ++-- test/dynamo/test_exc.py | 16 ++++++++-------- 19 files changed, 46 insertions(+), 57 deletions(-) diff --git a/.ci/aarch64_linux/aarch64_ci_setup.sh b/.ci/aarch64_linux/aarch64_ci_setup.sh index b18d27f2793fc..8ffba65d7fedd 100755 --- a/.ci/aarch64_linux/aarch64_ci_setup.sh +++ b/.ci/aarch64_linux/aarch64_ci_setup.sh @@ -12,7 +12,7 @@ fi SCRIPTPATH="$( cd "$(dirname "$0")" ; pwd -P )" source $SCRIPTPATH/../manywheel/set_desired_python.sh -pip install -q numpy==${NUMPY_VERSION} pyyaml==6.0.2 scons==4.7.0 ninja==1.11.1.4 patchelf==0.17.2 +pip install -q numpy==${NUMPY_VERSION} pyyaml==6.0.2 scons==4.7.0 ninja==1.11.1 patchelf==0.17.2 for tool in python python3 pip pip3 ninja scons patchelf; do ln -sf ${DESIRED_PYTHON_BIN_DIR}/${tool} /usr/local/bin; diff --git a/.ci/docker/manywheel/Dockerfile_2_28 b/.ci/docker/manywheel/Dockerfile_2_28 index 7f279a1c1a735..b150423e99544 100644 --- a/.ci/docker/manywheel/Dockerfile_2_28 +++ b/.ci/docker/manywheel/Dockerfile_2_28 @@ -128,7 +128,7 @@ ENV PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH ENV LD_LIBRARY_PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/lib64:/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/lib:$LD_LIBRARY_PATH # Install setuptools and wheel for python 3.12/3.13 RUN for cpython_version in "cp312-cp312" "cp313-cp313" "cp313-cp313t"; do \ - /opt/python/${cpython_version}/bin/python -m pip install "setuptools>=77.0.0" "packaging>=24.2" wheel; \ + /opt/python/${cpython_version}/bin/python -m pip install setuptools wheel; \ done; diff --git a/.ci/docker/manywheel/Dockerfile_s390x b/.ci/docker/manywheel/Dockerfile_s390x index 335488b88f122..46ec7f77ae8ba 100644 --- a/.ci/docker/manywheel/Dockerfile_s390x +++ b/.ci/docker/manywheel/Dockerfile_s390x @@ -124,9 +124,10 @@ RUN python3 -mpip install cmake==3.28.0 # install newest flatbuffers version first: # for some reason old version is getting pulled in otherwise. # packaging package is required for onnxruntime wheel build. -RUN pip3 install 'setuptools>=77.0' 'packaging>=24.2' && \ - pip3 install flatbuffers cython 'pkgconfig>=1.5.5' 'numpy<2.3.0' && \ +RUN pip3 install flatbuffers && \ + pip3 install cython 'pkgconfig>=1.5.5' 'setuptools>=77' 'numpy<2.3.0' && \ pip3 install --no-build-isolation h5py==3.11.0 && \ + pip3 install packaging && \ git clone https://github.com/microsoft/onnxruntime && \ cd onnxruntime && git checkout v1.21.0 && \ git submodule update --init --recursive && \ diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt index 926f8aef35f27..944b1fb35b36e 100644 --- a/.ci/docker/requirements-ci.txt +++ b/.ci/docker/requirements-ci.txt @@ -50,7 +50,7 @@ flatbuffers==24.12.23 hypothesis==5.35.1 # Pin hypothesis to avoid flakiness: https://github.com/pytorch/pytorch/issues/31136 #Description: advanced library for generating parametrized tests -#Pinned versions: 5.35.1 +#Pinned versions: 3.44.6, 4.53.2 #test that import: test_xnnpack_integration.py, test_pruning_op.py, test_nn.py junitparser==2.1.1 @@ -104,10 +104,10 @@ networkx==2.8.8 #Pinned versions: 2.8.8 #test that import: functorch -ninja==1.11.1.4 +ninja==1.11.1.3 #Description: build system. Used in some tests. Used in build to generate build #time tracing information -#Pinned versions: 1.11.1.4 +#Pinned versions: 1.11.1.3 #test that import: run_test.py, test_cpp_extensions_aot.py,test_determination.py numba==0.49.0 ; python_version < "3.9" @@ -307,7 +307,7 @@ pytest-cpp==2.3.0 #Pinned versions: 2.3.0 #test that import: -z3-solver==4.15.1.0 +z3-solver==4.12.6.0 #Description: The Z3 Theorem Prover Project #Pinned versions: #test that import: @@ -362,10 +362,9 @@ pwlf==2.2.1 #test that import: test_sac_estimator.py # To build PyTorch itself -packaging>=24.2 pyyaml pyzstd -setuptools>=77.0.0 +setuptools>=70.1.0 six scons==4.5.2 ; platform_machine == "aarch64" diff --git a/.ci/pytorch/build.sh b/.ci/pytorch/build.sh index f2b8998a6f6cd..58454bcb108a7 100755 --- a/.ci/pytorch/build.sh +++ b/.ci/pytorch/build.sh @@ -269,9 +269,6 @@ if [[ "$BUILD_ENVIRONMENT" == *-bazel-* ]]; then tools/bazel build --config=no-tty "${BAZEL_MEM_LIMIT}" "${BAZEL_CPU_LIMIT}" //... fi else - # install build-system requirements before running setup.py commands - python -m pip install -r requirements-build.txt - # check that setup.py would fail with bad arguments echo "The next three invocations are expected to fail with invalid command error messages." ( ! get_exit_code python setup.py bad_argument ) diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index e40d36abc2cc2..4f28297b5bce8 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -201,7 +201,7 @@ fi if [[ "$BUILD_ENVIRONMENT" != *-bazel-* ]] ; then # JIT C++ extensions require ninja. - pip_install "ninja==1.11.1.4" + pip_install "ninja==1.10.2" # ninja is installed in $HOME/.local/bin, e.g., /var/lib/jenkins/.local/bin for CI user jenkins # but this script should be runnable by any user, including root export PATH="$HOME/.local/bin:$PATH" diff --git a/.ci/pytorch/win-test-helpers/build_pytorch.bat b/.ci/pytorch/win-test-helpers/build_pytorch.bat index 74c9183f2abb0..7ceb425ce2d1a 100644 --- a/.ci/pytorch/win-test-helpers/build_pytorch.bat +++ b/.ci/pytorch/win-test-helpers/build_pytorch.bat @@ -126,11 +126,6 @@ if "%USE_CUDA%"=="1" ( set CMAKE_CUDA_COMPILER_LAUNCHER=%TMP_DIR%/bin/randomtemp.exe;%TMP_DIR%\bin\sccache.exe ) -:: Install build-system requirements before running setup.py commands -python -m pip install -r requirements-build.txt -if errorlevel 1 goto fail -if not errorlevel 0 goto fail - :: Print all existing environment variable for debugging set diff --git a/.ci/pytorch/win-test.sh b/.ci/pytorch/win-test.sh index be7f3e4bb35cc..b61dd06ef562c 100755 --- a/.ci/pytorch/win-test.sh +++ b/.ci/pytorch/win-test.sh @@ -41,7 +41,7 @@ fi python -m pip install pytest-rerunfailures==10.3 pytest-cpp==2.3.0 tensorboard==2.13.0 protobuf==5.29.4 pytest-subtests==0.13.1 # Install Z3 optional dependency for Windows builds. -python -m pip install z3-solver==4.15.1.0 +python -m pip install z3-solver==4.12.2.0 # Install tlparse for test\dynamo\test_structured_trace.py UTs. python -m pip install tlparse==0.3.30 diff --git a/.ci/pytorch/windows/internal/install_python.bat b/.ci/pytorch/windows/internal/install_python.bat index 65405a875b6b8..73622bd736edd 100644 --- a/.ci/pytorch/windows/internal/install_python.bat +++ b/.ci/pytorch/windows/internal/install_python.bat @@ -18,5 +18,5 @@ start /wait "" python-amd64.exe /quiet InstallAllUsers=1 PrependPath=0 Include_t if errorlevel 1 exit /b 1 set "PATH=%CD%\Python\Scripts;%CD%\Python;%PATH%" -%PYTHON_EXEC% -m pip install --upgrade pip "setuptools>=77.0.0" "packaging>=24.2" wheel +%PYTHON_EXEC% -m pip install --upgrade pip setuptools packaging wheel if errorlevel 1 exit /b 1 diff --git a/.ci/pytorch/windows/setup_build.bat b/.ci/pytorch/windows/setup_build.bat index df925b4ba90bc..9b492eef664d7 100644 --- a/.ci/pytorch/windows/setup_build.bat +++ b/.ci/pytorch/windows/setup_build.bat @@ -7,9 +7,6 @@ call "internal\install_python.bat" %PYTHON_EXEC% --version set "PATH=%CD%\Python\Lib\site-packages\cmake\data\bin;%CD%\Python\Scripts;%CD%\Python;%PATH%" - -%PYTHON_EXEC% -m pip install "setuptools>=77.0.0" "packaging>=24.2" - if "%DESIRED_PYTHON%" == "3.13t" %PYTHON_EXEC% -m pip install numpy==2.2.1 cmake if "%DESIRED_PYTHON%" == "3.13" %PYTHON_EXEC% -m pip install numpy==2.1.2 cmake if "%DESIRED_PYTHON%" == "3.12" %PYTHON_EXEC% -m pip install numpy==2.0.2 cmake @@ -19,7 +16,7 @@ if "%DESIRED_PYTHON%" == "3.9" %PYTHON_EXEC% -m pip install numpy==2.0.2 cmake %PYTHON_EXEC% -m pip install pyyaml %PYTHON_EXEC% -m pip install mkl-include mkl-static -%PYTHON_EXEC% -m pip install boto3 ninja typing-extensions +%PYTHON_EXEC% -m pip install boto3 ninja typing_extensions setuptools==72.1.0 where cmake.exe diff --git a/.ci/wheel/build_wheel.sh b/.ci/wheel/build_wheel.sh index dc44f8ccc2922..878d6595c84c0 100755 --- a/.ci/wheel/build_wheel.sh +++ b/.ci/wheel/build_wheel.sh @@ -127,7 +127,7 @@ export INSTALL_TEST=0 # dont install test binaries into site-packages export MACOSX_DEPLOYMENT_TARGET=10.15 export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"} -SETUPTOOLS_PINNED_VERSION="==77.0.0" +SETUPTOOLS_PINNED_VERSION="==70.1.0" PYYAML_PINNED_VERSION="=5.3" EXTRA_CONDA_INSTALL_FLAGS="" CONDA_ENV_CREATE_FLAGS="" @@ -135,7 +135,7 @@ RENAME_WHEEL=true case $desired_python in 3.13t) echo "Using 3.13 deps" - SETUPTOOLS_PINNED_VERSION=">=77.0.0" + SETUPTOOLS_PINNED_VERSION=">=70.1.0" PYYAML_PINNED_VERSION=">=6.0.1" NUMPY_PINNED_VERSION="=2.1.0" CONDA_ENV_CREATE_FLAGS="python-freethreading" @@ -145,31 +145,31 @@ case $desired_python in ;; 3.13) echo "Using 3.13 deps" - SETUPTOOLS_PINNED_VERSION=">=77.0.0" + SETUPTOOLS_PINNED_VERSION=">=70.1.0" PYYAML_PINNED_VERSION=">=6.0.1" NUMPY_PINNED_VERSION="=2.1.0" ;; 3.12) echo "Using 3.12 deps" - SETUPTOOLS_PINNED_VERSION=">=77.0.0" + SETUPTOOLS_PINNED_VERSION=">=70.1.0" PYYAML_PINNED_VERSION=">=6.0.1" NUMPY_PINNED_VERSION="=2.0.2" ;; 3.11) echo "Using 3.11 deps" - SETUPTOOLS_PINNED_VERSION=">=77.0.0" + SETUPTOOLS_PINNED_VERSION=">=70.1.0" PYYAML_PINNED_VERSION=">=5.3" NUMPY_PINNED_VERSION="=2.0.2" ;; 3.10) echo "Using 3.10 deps" - SETUPTOOLS_PINNED_VERSION=">=77.0.0" + SETUPTOOLS_PINNED_VERSION=">=70.1.0" PYYAML_PINNED_VERSION=">=5.3" NUMPY_PINNED_VERSION="=2.0.2" ;; 3.9) echo "Using 3.9 deps" - SETUPTOOLS_PINNED_VERSION=">=77.0.0" + SETUPTOOLS_PINNED_VERSION=">=70.1.0" PYYAML_PINNED_VERSION=">=5.3" NUMPY_PINNED_VERSION="=2.0.2" ;; diff --git a/.github/requirements-gha-cache.txt b/.github/requirements-gha-cache.txt index 381bccbee847d..5c691e4bf9b31 100644 --- a/.github/requirements-gha-cache.txt +++ b/.github/requirements-gha-cache.txt @@ -8,7 +8,7 @@ boto3==1.35.42 jinja2==3.1.6 lintrunner==0.10.7 -ninja==1.11.1.4 +ninja==1.10.0.post1 nvidia-ml-py==11.525.84 pyyaml==6.0 requests==2.32.4 diff --git a/.github/requirements/pip-requirements-macOS.txt b/.github/requirements/pip-requirements-macOS.txt index ea005956aefa5..9c72c71523b7d 100644 --- a/.github/requirements/pip-requirements-macOS.txt +++ b/.github/requirements/pip-requirements-macOS.txt @@ -7,12 +7,12 @@ hypothesis==6.56.4 librosa>=0.6.2 mpmath==1.3.0 networkx==2.8.7 -ninja==1.11.1.4 +ninja==1.10.2.4 numba==0.59.0 numpy==1.26.4 opt-einsum>=3.3 optree==0.13.0 -packaging==25.0 +packaging==23.1 parameterized==0.8.1 pillow==10.3.0 protobuf==5.29.4 @@ -26,11 +26,11 @@ pytest-xdist==3.3.1 pytest==7.3.2 pyyaml==6.0.2 scipy==1.12.0 -setuptools==80.9.0 +setuptools==72.1.0 sympy==1.13.3 tlparse==0.3.30 tensorboard==2.13.0 typing-extensions==4.12.2 unittest-xml-reporting<=3.2.0,>=2.0.0 xdoctest==1.1.0 -z3-solver==4.15.1.0 +z3-solver==4.12.2.0 diff --git a/.github/scripts/lintrunner.sh b/.github/scripts/lintrunner.sh index 1411ff0397b53..ef4741444f942 100755 --- a/.github/scripts/lintrunner.sh +++ b/.github/scripts/lintrunner.sh @@ -2,7 +2,7 @@ set -ex # Use uv to speed up lintrunner init -python3 -m pip install -U uv setuptools +python3 -m pip install uv==0.1.45 setuptools CACHE_DIRECTORY="/tmp/.lintbin" # Try to recover the cached binaries diff --git a/.github/scripts/windows/build_triton.bat b/.github/scripts/windows/build_triton.bat index da2e86b40432a..97cd535a49889 100644 --- a/.github/scripts/windows/build_triton.bat +++ b/.github/scripts/windows/build_triton.bat @@ -10,7 +10,7 @@ if "%PY_VERS%" == "3.13t" ( call conda create -n %PYTHON_PREFIX% -y -c=conda-forge python=%PY_VERS% ) :: Fix cmake version for issue https://github.com/pytorch/pytorch/issues/150480 -call conda run -n %PYTHON_PREFIX% pip install wheel pybind11 certifi cython cmake==3.31.6 setuptools==78.1.1 ninja +call conda run -n %PYTHON_PREFIX% pip install wheel pybind11 certifi cython cmake==3.31.6 setuptools==72.1.0 ninja dir "%VC_INSTALL_PATH%" diff --git a/.github/workflows/_mac-test.yml b/.github/workflows/_mac-test.yml index 8822aaf7df418..063c97e449c75 100644 --- a/.github/workflows/_mac-test.yml +++ b/.github/workflows/_mac-test.yml @@ -80,11 +80,6 @@ jobs: run: | sysctl machdep.cpu.brand_string kern.osproductversion - - name: Install build toolchain - run: | - brew update --quiet - brew install --formula cmake ninja - - name: Clean up leftover processes on MacOS pet runner continue-on-error: true run: | diff --git a/pyproject.toml b/pyproject.toml index 133da9289f5c9..b41ae87621f0f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,12 +2,13 @@ [build-system] requires = [ + # 70.1.0: min version for integrated bdist_wheel command from wheel package # 77.0.0: min version for SPDX expression support for project.license - "setuptools>=77.0.0,<80.0", + "setuptools>=70.1.0,<80.0", "cmake>=3.27", "ninja", "numpy", - "packaging>=24.2", + "packaging", "pyyaml", "requests", "six", # dependency chain: NNPACK -> PeachPy -> six @@ -20,7 +21,11 @@ name = "torch" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" readme = "README.md" requires-python = ">=3.9,<3.14" -license = "BSD-3-Clause" +# TODO: change to `license = "BSD-3-Clause"` and enable PEP 639 after pinning setuptools>=77 +# FIXME: As of 2025.06.20, it is hard to ensure the minimum version of setuptools in our CI environment. +# TOML-table-based license deprecated in setuptools>=77, and the deprecation warning will be changed +# to an error on 2026.02.18. See also: https://github.com/pypa/setuptools/issues/4903 +license = { text = "BSD-3-Clause" } authors = [{ name = "PyTorch Team", email = "packages@pytorch.org" }] keywords = ["pytorch", "machine learning"] classifiers = [ diff --git a/requirements-build.txt b/requirements-build.txt index 12332b0e1af01..be19d987f73db 100644 --- a/requirements-build.txt +++ b/requirements-build.txt @@ -1,9 +1,9 @@ # Build System requirements -setuptools>=77.0.0,<80.0 # setuptools develop deprecated on 80.0 +setuptools>=70.1.0,<80.0 # setuptools develop deprecated on 80.0 cmake>=3.27 ninja numpy -packaging>=24.2 +packaging pyyaml requests six # dependency chain: NNPACK -> PeachPy -> six diff --git a/test/dynamo/test_exc.py b/test/dynamo/test_exc.py index c340a2882d471..acc3fd55f6fb0 100644 --- a/test/dynamo/test_exc.py +++ b/test/dynamo/test_exc.py @@ -251,13 +251,13 @@ def fn(x, shape): Model: ==> L['shape'][0]: 0 - ==> L['shape'][1]: 0 - ==> L['shape'][2]: 0 + ==> L['shape'][1]: 1 + ==> L['shape'][2]: 1 ==> L['x'].size()[0]: 3 ==> L['x'].storage_offset(): 0 ==> L['x'].stride()[0]: 1 - ==> s3: 0 - ==> s52: 0 + ==> s3: 1 + ==> s52: 1 ==> s77: 3 ==> s86: 0 @@ -315,16 +315,16 @@ def fn(x, shape): %split : [num_users=3] = call_method[target=split](args = (%l_x_, (%l_shape_0_, %l_shape_1_, %l_shape_2_)), kwargs = {}) Model: - ==> L['shape'][0]: 0 - ==> L['shape'][1]: 0 + ==> L['shape'][0]: 1 + ==> L['shape'][1]: 1 ==> L['shape'][2]: 0 ==> L['x'].size()[0]: 3 ==> L['x'].storage_offset(): 0 ==> L['x'].stride()[0]: 1 ==> s3: 0 - ==> s52: 0 + ==> s52: 1 ==> s77: 3 - ==> s86: 0 + ==> s86: 1 Assertions: ==> (== 0 L['x'].storage_offset()) From f09a484b8164aaadd57a79354f0ccf47733f365e Mon Sep 17 00:00:00 2001 From: Ketan Ambati Date: Mon, 21 Jul 2025 22:49:23 +0000 Subject: [PATCH 354/457] Remove is_arvr_mode() from xnnpack.buck.bzl (#158682) Summary: **Changes** * Deleted function import from build definition utilities * Removed `load("//tools/build_defs:fbsource_utils.bzl", "is_arvr_mode")` * Replaced is_arvr_mode() function calls with direct references to configuration flags * Changed from `is_arvr_mode()` to `"ovr_config//build_mode:arvr_mode"` * Changed conditional expressions to Buck `select()` statements Test Plan: Check if CI passes Rollback Plan: Differential Revision: D78520947 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158682 Approved by: https://github.com/malfet --- third_party/xnnpack.buck.bzl | 779 ++++++++++++++++++++++------------- 1 file changed, 485 insertions(+), 294 deletions(-) diff --git a/third_party/xnnpack.buck.bzl b/third_party/xnnpack.buck.bzl index 231384bd859ab..0f50efc032591 100644 --- a/third_party/xnnpack.buck.bzl +++ b/third_party/xnnpack.buck.bzl @@ -1,5 +1,4 @@ load("//tools/build_defs:fb_xplat_cxx_library.bzl", "fb_xplat_cxx_library") -load("//tools/build_defs:fbsource_utils.bzl", "is_arvr_mode") load("//tools/build_defs:glob_defs.bzl", "subdir_glob") load("//tools/build_defs:platform_defs.bzl", "ANDROID", "APPLE", "APPLETVOS", "CXX", "IOS", "MACOSX", "WINDOWS") load( @@ -142,9 +141,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_sse", srcs = select({ "DEFAULT": [], - "ovr_config//cpu:x86_32": prod_srcs_for_arch_wrapper("sse"), - "ovr_config//cpu:x86_64": prod_srcs_for_arch_wrapper("sse"), - }) if is_arvr_mode() else [], + "ovr_config//build_mode:arvr_mode": select({ + "DEFAULT": [], + "ovr_config//cpu:x86_32": prod_srcs_for_arch_wrapper("sse"), + "ovr_config//cpu:x86_64": prod_srcs_for_arch_wrapper("sse"), + }), + }), headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -160,12 +162,15 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = ([ - ( - "x86|x86_64|platform009|platform010", - prod_srcs_for_arch_wrapper("sse"), - ), - ] if not is_arvr_mode() else []), + platform_srcs = select({ + "DEFAULT": [ + ( + "x86|x86_64|platform009|platform010", + prod_srcs_for_arch_wrapper("sse"), + ), + ], + "ovr_config//build_mode:arvr_mode": [], + }), fbandroid_link_whole = True, preferred_linkage = "static", preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, @@ -206,9 +211,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_sse2", srcs = select({ "DEFAULT": [], - "ovr_config//cpu:x86_32": prod_srcs_for_arch_wrapper("sse2"), - "ovr_config//cpu:x86_64": prod_srcs_for_arch_wrapper("sse2"), - }) if is_arvr_mode() else [], + "ovr_config//build_mode:arvr_mode": select({ + "DEFAULT": [], + "ovr_config//cpu:x86_32": prod_srcs_for_arch_wrapper("sse2"), + "ovr_config//cpu:x86_64": prod_srcs_for_arch_wrapper("sse2"), + }), + }), headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -224,12 +232,15 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = ([ - ( - "x86|x86_64|platform009|platform010", - prod_srcs_for_arch_wrapper("sse2"), - ), - ] if not is_arvr_mode() else []), + platform_srcs = select({ + "DEFAULT": [ + ( + "x86|x86_64|platform009|platform010", + prod_srcs_for_arch_wrapper("sse2"), + ), + ], + "ovr_config//build_mode:arvr_mode": [], + }), fbandroid_link_whole = True, preferred_linkage = "static", preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, @@ -270,9 +281,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_ssse3", srcs = select({ "DEFAULT": [], - "ovr_config//cpu:x86_32": prod_srcs_for_arch_wrapper("ssse3"), - "ovr_config//cpu:x86_64": prod_srcs_for_arch_wrapper("ssse3"), - }) if is_arvr_mode() else [], + "ovr_config//build_mode:arvr_mode": select({ + "DEFAULT": [], + "ovr_config//cpu:x86_32": prod_srcs_for_arch_wrapper("ssse3"), + "ovr_config//cpu:x86_64": prod_srcs_for_arch_wrapper("ssse3"), + }), + }), headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -288,12 +302,15 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = ([ - ( - "x86|x86_64|platform009|platform010", - prod_srcs_for_arch_wrapper("ssse3"), - ), - ] if not is_arvr_mode() else []), + platform_srcs = select({ + "DEFAULT": [ + ( + "x86|x86_64|platform009|platform010", + prod_srcs_for_arch_wrapper("ssse3"), + ), + ], + "ovr_config//build_mode:arvr_mode": [], + }), fbandroid_link_whole = True, preferred_linkage = "static", preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, @@ -334,9 +351,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_sse41", srcs = select({ "DEFAULT": [], - "ovr_config//cpu:x86_32": prod_srcs_for_arch_wrapper("sse41"), - "ovr_config//cpu:x86_64": prod_srcs_for_arch_wrapper("sse41"), - }) if is_arvr_mode() else [], + "ovr_config//build_mode:arvr_mode": select({ + "DEFAULT": [], + "ovr_config//cpu:x86_32": prod_srcs_for_arch_wrapper("sse41"), + "ovr_config//cpu:x86_64": prod_srcs_for_arch_wrapper("sse41"), + }), + }), headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -352,12 +372,15 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = ([ - ( - "x86|x86_64|platform009|platform010", - prod_srcs_for_arch_wrapper("sse41"), - ), - ] if not is_arvr_mode() else []), + platform_srcs = select({ + "DEFAULT": [ + ( + "x86|x86_64|platform009|platform010", + prod_srcs_for_arch_wrapper("sse41"), + ), + ], + "ovr_config//build_mode:arvr_mode": [], + }), fbandroid_link_whole = True, preferred_linkage = "static", preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, @@ -398,9 +421,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_avx", srcs = select({ "DEFAULT": [], - "ovr_config//cpu:x86_32": prod_srcs_for_arch_wrapper("avx"), - "ovr_config//cpu:x86_64": prod_srcs_for_arch_wrapper("avx"), - }) if is_arvr_mode() else [], + "ovr_config//build_mode:arvr_mode": select({ + "DEFAULT": [], + "ovr_config//cpu:x86_32": prod_srcs_for_arch_wrapper("avx"), + "ovr_config//cpu:x86_64": prod_srcs_for_arch_wrapper("avx"), + }), + }), headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -424,12 +450,15 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = ([ - ( - "x86|x86_64|platform009|platform010", - prod_srcs_for_arch_wrapper("avx"), - ), - ] if not is_arvr_mode() else []), + platform_srcs = select({ + "DEFAULT": [ + ( + "x86|x86_64|platform009|platform010", + prod_srcs_for_arch_wrapper("avx"), + ), + ], + "ovr_config//build_mode:arvr_mode": [], + }), fbandroid_link_whole = True, preferred_linkage = "static", preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, @@ -471,9 +500,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_avx512vnnigfni", srcs = select({ "DEFAULT": [], - "ovr_config//cpu:x86_32": prod_srcs_for_arch_wrapper("avx512vnnigfni"), - "ovr_config//cpu:x86_64": prod_srcs_for_arch_wrapper("avx512vnnigfni"), - }) if is_arvr_mode() else [], + "ovr_config//build_mode:arvr_mode": select({ + "DEFAULT": [], + "ovr_config//cpu:x86_32": prod_srcs_for_arch_wrapper("avx512vnnigfni"), + "ovr_config//cpu:x86_64": prod_srcs_for_arch_wrapper("avx512vnnigfni"), + }), + }), headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -513,12 +545,15 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = ([ - ( - "x86|x86_64|platform009|platform010", - prod_srcs_for_arch_wrapper("avx512vnnigfni"), - ), - ] if not is_arvr_mode() else []), + platform_srcs = select({ + "DEFAULT": [ + ( + "x86|x86_64|platform009|platform010", + prod_srcs_for_arch_wrapper("avx512vnnigfni"), + ), + ], + "ovr_config//build_mode:arvr_mode": [], + }), preferred_linkage = "static", preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], @@ -563,9 +598,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_avx512vnni", srcs = select({ "DEFAULT": [], - "ovr_config//cpu:x86_32": prod_srcs_for_arch_wrapper("avx512vnni"), - "ovr_config//cpu:x86_64": prod_srcs_for_arch_wrapper("avx512vnni"), - }) if is_arvr_mode() else [], + "ovr_config//build_mode:arvr_mode": select({ + "DEFAULT": [], + "ovr_config//cpu:x86_32": prod_srcs_for_arch_wrapper("avx512vnni"), + "ovr_config//cpu:x86_64": prod_srcs_for_arch_wrapper("avx512vnni"), + }), + }), headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -604,12 +642,15 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = ([ - ( - "x86|x86_64|platform009|platform010", - prod_srcs_for_arch_wrapper("avx512vnni"), - ), - ] if not is_arvr_mode() else []), + platform_srcs = select({ + "DEFAULT": [ + ( + "x86|x86_64|platform009|platform010", + prod_srcs_for_arch_wrapper("avx512vnni"), + ), + ], + "ovr_config//build_mode:arvr_mode": [], + }), preferred_linkage = "static", preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, exported_preprocessor_flags = [ @@ -657,7 +698,10 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F fb_xplat_cxx_library( name = "ukernels_avxvnni", - srcs = prod_srcs_for_arch_wrapper("avxvnni") if is_arvr_mode() else [], + srcs = select({ + "DEFAULT": [], + "ovr_config//build_mode:arvr_mode": prod_srcs_for_arch_wrapper("avxvnni"), + }), headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -679,12 +723,15 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = ([ - ( - "x86|x86_64|platform009|platform010", - prod_srcs_for_arch_wrapper("avxvnni"), - ), - ] if not is_arvr_mode() else []), + platform_srcs = select({ + "DEFAULT": [ + ( + "x86|x86_64|platform009|platform010", + prod_srcs_for_arch_wrapper("avxvnni"), + ), + ], + "ovr_config//build_mode:arvr_mode": [], + }), preferred_linkage = "static", preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], @@ -724,9 +771,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_f16c", srcs = select({ "DEFAULT": [], - "ovr_config//cpu:x86_32": prod_srcs_for_arch_wrapper("f16c"), - "ovr_config//cpu:x86_64": prod_srcs_for_arch_wrapper("f16c"), - }) if is_arvr_mode() else [], + "ovr_config//build_mode:arvr_mode": select({ + "DEFAULT": [], + "ovr_config//cpu:x86_32": prod_srcs_for_arch_wrapper("f16c"), + "ovr_config//cpu:x86_64": prod_srcs_for_arch_wrapper("f16c"), + }), + }), headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -750,12 +800,15 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = ([ - ( - "x86|x86_64|platform009|platform010", - prod_srcs_for_arch_wrapper("f16c"), - ), - ] if not is_arvr_mode() else []), + platform_srcs = select({ + "DEFAULT": [ + ( + "x86|x86_64|platform009|platform010", + prod_srcs_for_arch_wrapper("f16c"), + ), + ], + "ovr_config//build_mode:arvr_mode": [], + }), platforms = (APPLE, ANDROID, CXX, WINDOWS), fbandroid_link_whole = True, preferred_linkage = "static", @@ -799,9 +852,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_fma3", srcs = select({ "DEFAULT": [], - "ovr_config//cpu:x86_32": prod_srcs_for_arch_wrapper("fma3"), - "ovr_config//cpu:x86_64": prod_srcs_for_arch_wrapper("fma3"), - }) if is_arvr_mode() else [], + "ovr_config//build_mode:arvr_mode": select({ + "DEFAULT": [], + "ovr_config//cpu:x86_32": prod_srcs_for_arch_wrapper("fma3"), + "ovr_config//cpu:x86_64": prod_srcs_for_arch_wrapper("fma3"), + }), + }), headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -828,12 +884,15 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = ([ - ( - "x86|x86_64|platform009|platform010", - prod_srcs_for_arch_wrapper("fma3"), - ), - ] if not is_arvr_mode() else []), + platform_srcs = select({ + "DEFAULT": [ + ( + "x86|x86_64|platform009|platform010", + prod_srcs_for_arch_wrapper("fma3"), + ), + ], + "ovr_config//build_mode:arvr_mode": [], + }), fbandroid_link_whole = True, preferred_linkage = "static", preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, @@ -889,9 +948,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_avx2", srcs = select({ "DEFAULT": [], - "ovr_config//cpu:x86_32": prod_srcs_for_arch_wrapper("avx2"), - "ovr_config//cpu:x86_64": prod_srcs_for_arch_wrapper("avx2"), - }) if is_arvr_mode() else [], + "ovr_config//build_mode:arvr_mode": select({ + "DEFAULT": [], + "ovr_config//cpu:x86_32": prod_srcs_for_arch_wrapper("avx2"), + "ovr_config//cpu:x86_64": prod_srcs_for_arch_wrapper("avx2"), + }), + }), headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -921,12 +983,15 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = ([ - ( - "x86|x86_64|platform009|platform010", - prod_srcs_for_arch_wrapper("avx2"), - ), - ] if not is_arvr_mode() else []), + platform_srcs = select({ + "DEFAULT": [ + ( + "x86|x86_64|platform009|platform010", + prod_srcs_for_arch_wrapper("avx2"), + ), + ], + "ovr_config//build_mode:arvr_mode": [], + }), fbandroid_link_whole = True, preferred_linkage = "static", preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, @@ -989,9 +1054,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_avx512", srcs = select({ "DEFAULT": [], - "ovr_config//cpu:x86_32": prod_srcs_for_arch_wrapper("avx512f"), - "ovr_config//cpu:x86_64": prod_srcs_for_arch_wrapper("avx512f"), - }) if is_arvr_mode() else [], + "ovr_config//build_mode:arvr_mode": select({ + "DEFAULT": [], + "ovr_config//cpu:x86_32": prod_srcs_for_arch_wrapper("avx512f"), + "ovr_config//cpu:x86_64": prod_srcs_for_arch_wrapper("avx512f"), + }), + }), headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -1015,12 +1083,15 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = ([ - ( - "x86|x86_64|platform009|platform010", - prod_srcs_for_arch_wrapper("avx512f"), - ), - ] if not is_arvr_mode() else []), + platform_srcs = select({ + "DEFAULT": [ + ( + "x86|x86_64|platform009|platform010", + prod_srcs_for_arch_wrapper("avx512f"), + ), + ], + "ovr_config//build_mode:arvr_mode": [], + }), fbandroid_link_whole = True, preferred_linkage = "static", preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, @@ -1034,9 +1105,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_avx512vbmi", srcs = select({ "DEFAULT": [], - "ovr_config//cpu:x86_32": prod_srcs_for_arch_wrapper("avx512vbmi"), - "ovr_config//cpu:x86_64": prod_srcs_for_arch_wrapper("avx512vbmi"), - }) if is_arvr_mode() else [], + "ovr_config//build_mode:arvr_mode": select({ + "DEFAULT": [], + "ovr_config//cpu:x86_32": prod_srcs_for_arch_wrapper("avx512vbmi"), + "ovr_config//cpu:x86_64": prod_srcs_for_arch_wrapper("avx512vbmi"), + }), + }), headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -1075,12 +1149,15 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = ([ - ( - "x86|x86_64|platform009|platform010", - prod_srcs_for_arch_wrapper("avx512vbmi"), - ), - ] if not is_arvr_mode() else []), + platform_srcs = select({ + "DEFAULT": [ + ( + "x86|x86_64|platform009|platform010", + prod_srcs_for_arch_wrapper("avx512vbmi"), + ), + ], + "ovr_config//build_mode:arvr_mode": [], + }), fbandroid_link_whole = True, preferred_linkage = "static", preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, @@ -1136,9 +1213,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_avx512skx", srcs = select({ "DEFAULT": [], - "ovr_config//cpu:x86_32": prod_srcs_for_arch_wrapper("avx512skx"), - "ovr_config//cpu:x86_64": prod_srcs_for_arch_wrapper("avx512skx"), - }) if is_arvr_mode() else [], + "ovr_config//build_mode:arvr_mode": select({ + "DEFAULT": [], + "ovr_config//cpu:x86_32": prod_srcs_for_arch_wrapper("avx512skx"), + "ovr_config//cpu:x86_64": prod_srcs_for_arch_wrapper("avx512skx"), + }), + }), headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -1174,12 +1254,15 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = ([ - ( - "x86|x86_64|platform009|platform010", - prod_srcs_for_arch_wrapper("avx512skx"), - ), - ] if not is_arvr_mode() else []), + platform_srcs = select({ + "DEFAULT": [ + ( + "x86|x86_64|platform009|platform010", + prod_srcs_for_arch_wrapper("avx512skx"), + ), + ], + "ovr_config//build_mode:arvr_mode": [], + }), fbandroid_link_whole = True, preferred_linkage = "static", preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, @@ -1255,8 +1338,11 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_armsimd32", srcs = select({ "DEFAULT": [], - "ovr_config//cpu:arm32": prod_srcs_for_arch_wrapper("armsimd32"), - }) if is_arvr_mode() else [], + "ovr_config//build_mode:arvr_mode": select({ + "DEFAULT": [], + "ovr_config//cpu:arm32": prod_srcs_for_arch_wrapper("armsimd32"), + }), + }), headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -1277,12 +1363,15 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = [ - ( - "(aarch32|arm32|armv7)", - prod_srcs_for_arch_wrapper("armsimd32"), - ), - ] if not is_arvr_mode() else [], + platform_srcs = select({ + "DEFAULT": [ + ( + "(aarch32|arm32|armv7)", + prod_srcs_for_arch_wrapper("armsimd32"), + ), + ], + "ovr_config//build_mode:arvr_mode": [], + }), fbandroid_link_whole = True, preferred_linkage = "static", preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, @@ -1296,9 +1385,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_neon", srcs = select({ "DEFAULT": [], - "ovr_config//cpu:arm32": prod_srcs_for_arch_wrapper("neon"), - "ovr_config//cpu:arm64": prod_srcs_for_arch_wrapper("neon"), - }) if is_arvr_mode() else [], + "ovr_config//build_mode:arvr_mode": select({ + "DEFAULT": [], + "ovr_config//cpu:arm32": prod_srcs_for_arch_wrapper("neon"), + "ovr_config//cpu:arm64": prod_srcs_for_arch_wrapper("neon"), + }), + }), headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -1323,16 +1415,19 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = [ - ( - "(aarch32|arm32|armv7)", - prod_srcs_for_arch_wrapper("neon"), - ), - ( - "(aarch64|arm64)", - prod_srcs_for_arch_wrapper("neon"), - ), - ] if not is_arvr_mode() else [], + platform_srcs = select({ + "DEFAULT": [ + ( + "(aarch32|arm32|armv7)", + prod_srcs_for_arch_wrapper("neon"), + ), + ( + "(aarch64|arm64)", + prod_srcs_for_arch_wrapper("neon"), + ), + ], + "ovr_config//build_mode:arvr_mode": [], + }), fbandroid_link_whole = True, preferred_linkage = "static", preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, @@ -1346,20 +1441,26 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_neon_aarch64", srcs = select({ "DEFAULT": [], - "ovr_config//cpu:arm64": prod_srcs_for_arch_wrapper("neon_aarch64"), - }) if is_arvr_mode() else [], + "ovr_config//build_mode:arvr_mode": select({ + "DEFAULT": [], + "ovr_config//cpu:arm64": prod_srcs_for_arch_wrapper("neon_aarch64"), + }), + }), headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), compiler_flags = [ "-O2", ], - platform_srcs = [ - ( - "(aarch64|arm64)", - prod_srcs_for_arch_wrapper("neon_aarch64"), - ), - ] if not is_arvr_mode() else [], + platform_srcs = select({ + "DEFAULT": [ + ( + "(aarch64|arm64)", + prod_srcs_for_arch_wrapper("neon_aarch64"), + ), + ], + "ovr_config//build_mode:arvr_mode": [], + }), labels = labels, fbandroid_link_whole = True, preferred_linkage = "static", @@ -1374,8 +1475,11 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_neon_fma", srcs = select({ "DEFAULT": [], - "ovr_config//cpu:arm32": prod_srcs_for_arch_wrapper("neonfma"), - }) if is_arvr_mode() else [], + "ovr_config//build_mode:arvr_mode": select({ + "DEFAULT": [], + "ovr_config//cpu:arm32": prod_srcs_for_arch_wrapper("neonfma"), + }), + }), headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -1407,12 +1511,15 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = [ - ( - "(aarch32|arm32|armv7)", - prod_srcs_for_arch_wrapper("neonfma"), - ), - ] if not is_arvr_mode() else [], + platform_srcs = select({ + "DEFAULT": [ + ( + "(aarch32|arm32|armv7)", + prod_srcs_for_arch_wrapper("neonfma"), + ), + ], + "ovr_config//build_mode:arvr_mode": [], + }), fbandroid_link_whole = True, preferred_linkage = "static", preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, @@ -1426,8 +1533,11 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_neonfma_aarch64", srcs = select({ "DEFAULT": [], - "ovr_config//cpu:arm64": prod_srcs_for_arch_wrapper("neonfma") + prod_srcs_for_arch_wrapper("neonfma_aarch64"), - }) if is_arvr_mode() else [], + "ovr_config//build_mode:arvr_mode": select({ + "DEFAULT": [], + "ovr_config//cpu:arm64": prod_srcs_for_arch_wrapper("neonfma") + prod_srcs_for_arch_wrapper("neonfma_aarch64"), + }), + }), headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -1435,12 +1545,15 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-O2", ], labels = labels, - platform_srcs = [ - ( - "(arm64|aarch64)$", - prod_srcs_for_arch_wrapper("neonfma") + prod_srcs_for_arch_wrapper("neonfma_aarch64"), - ), - ] if not is_arvr_mode() else [], + platform_srcs = select({ + "DEFAULT": [ + ( + "(arm64|aarch64)$", + prod_srcs_for_arch_wrapper("neonfma") + prod_srcs_for_arch_wrapper("neonfma_aarch64"), + ), + ], + "ovr_config//build_mode:arvr_mode": [], + }), platforms = (APPLE, ANDROID, CXX, WINDOWS), fbandroid_link_whole = True, preferred_linkage = "static", @@ -1455,9 +1568,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_fp16arith", srcs = select({ "DEFAULT": [], - "ovr_config//cpu:arm32": prod_srcs_for_arch_wrapper("fp16arith"), - "ovr_config//cpu:arm64": prod_srcs_for_arch_wrapper("fp16arith"), - }) if is_arvr_mode() else [], + "ovr_config//build_mode:arvr_mode": select({ + "DEFAULT": [], + "ovr_config//cpu:arm32": prod_srcs_for_arch_wrapper("fp16arith"), + "ovr_config//cpu:arm64": prod_srcs_for_arch_wrapper("fp16arith"), + }), + }), headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -1504,16 +1620,19 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ) ], - platform_srcs = [ - ( - "(aarch32|arm32|armv7)", - prod_srcs_for_arch_wrapper("fp16arith"), - ), - ( - "(aarch64|arm64)", - prod_srcs_for_arch_wrapper("fp16arith") + prod_srcs_for_arch_wrapper("fp16arith_aarch64"), - ), - ] if not is_arvr_mode() else [], + platform_srcs = select({ + "DEFAULT": [ + ( + "(aarch32|arm32|armv7)", + prod_srcs_for_arch_wrapper("fp16arith"), + ), + ( + "(aarch64|arm64)", + prod_srcs_for_arch_wrapper("fp16arith") + prod_srcs_for_arch_wrapper("fp16arith_aarch64"), + ), + ], + "ovr_config//build_mode:arvr_mode": [], + }), fbandroid_link_whole = True, preferred_linkage = "static", preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, @@ -1527,9 +1646,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_neon_fp16", srcs = select({ "DEFAULT": [], - "ovr_config//cpu:arm32": prod_srcs_for_arch_wrapper("neonfp16"), - "ovr_config//cpu:arm64": prod_srcs_for_arch_wrapper("neonfp16"), - }) if is_arvr_mode() else [], + "ovr_config//build_mode:arvr_mode": select({ + "DEFAULT": [], + "ovr_config//cpu:arm32": prod_srcs_for_arch_wrapper("neonfp16"), + "ovr_config//cpu:arm64": prod_srcs_for_arch_wrapper("neonfp16"), + }), + }), headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -1554,16 +1676,19 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = [ - ( - "(aarch32|arm32|armv7)$", - prod_srcs_for_arch_wrapper("neonfp16"), - ), - ( - "(arm64|aarch64)", - prod_srcs_for_arch_wrapper("neonfp16"), - ), - ] if not is_arvr_mode() else [], + platform_srcs = select({ + "DEFAULT": [ + ( + "(aarch32|arm32|armv7)$", + prod_srcs_for_arch_wrapper("neonfp16"), + ), + ( + "(arm64|aarch64)", + prod_srcs_for_arch_wrapper("neonfp16"), + ), + ], + "ovr_config//build_mode:arvr_mode": [], + }), fbandroid_link_whole = True, preferred_linkage = "static", preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, @@ -1577,9 +1702,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_neon_v8", srcs = select({ "DEFAULT": [], - "ovr_config//cpu:arm32": prod_srcs_for_arch_wrapper("neonv8"), - "ovr_config//cpu:arm64": prod_srcs_for_arch_wrapper("neonv8"), - }) if is_arvr_mode() else [], + "ovr_config//build_mode:arvr_mode": select({ + "DEFAULT": [], + "ovr_config//cpu:arm32": prod_srcs_for_arch_wrapper("neonv8"), + "ovr_config//cpu:arm64": prod_srcs_for_arch_wrapper("neonv8"), + }), + }), headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -1618,16 +1746,19 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = [ - ( - "(aarch32|arm32|armv7)$", - prod_srcs_for_arch_wrapper("neonv8"), - ), - ( - "(arm64|aarch64)", - prod_srcs_for_arch_wrapper("neonv8"), - ), - ] if not is_arvr_mode() else [], + platform_srcs = select({ + "DEFAULT": [ + ( + "(aarch32|arm32|armv7)$", + prod_srcs_for_arch_wrapper("neonv8"), + ), + ( + "(arm64|aarch64)", + prod_srcs_for_arch_wrapper("neonv8"), + ), + ], + "ovr_config//build_mode:arvr_mode": [], + }), fbandroid_link_whole = True, preferred_linkage = "static", preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, @@ -1641,8 +1772,11 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_neon_dot", srcs = select({ "DEFAULT": [], - "ovr_config//cpu:arm32": prod_srcs_for_arch_wrapper("neondot"), - }) if is_arvr_mode() else [], + "ovr_config//build_mode:arvr_mode": select({ + "DEFAULT": [], + "ovr_config//cpu:arm32": prod_srcs_for_arch_wrapper("neondot"), + }), + }), headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -1667,12 +1801,15 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = [ - ( - "(aarch32|arm32|armv7)", - prod_srcs_for_arch_wrapper("neondot"), - ), - ] if not is_arvr_mode() else [], + platform_srcs = select({ + "DEFAULT": [ + ( + "(aarch32|arm32|armv7)", + prod_srcs_for_arch_wrapper("neondot"), + ), + ], + "ovr_config//build_mode:arvr_mode": [], + }), fbandroid_link_whole = True, preferred_linkage = "static", preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, @@ -1686,8 +1823,11 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_neon_dot_aarch64", srcs = select({ "DEFAULT": [], - "ovr_config//cpu:arm64": prod_srcs_for_arch_wrapper("neondot") + prod_srcs_for_arch_wrapper("neondot_aarch64"), - }) if is_arvr_mode() else [], + "ovr_config//build_mode:arvr_mode": select({ + "DEFAULT": [], + "ovr_config//cpu:arm64": prod_srcs_for_arch_wrapper("neondot") + prod_srcs_for_arch_wrapper("neondot_aarch64"), + }), + }), headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -1706,12 +1846,15 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = [ - ( - "(aarch64|arm64)", - prod_srcs_for_arch_wrapper("neondot") + prod_srcs_for_arch_wrapper("neondot_aarch64"), - ), - ] if not is_arvr_mode() else [], + platform_srcs = select({ + "DEFAULT": [ + ( + "(aarch64|arm64)", + prod_srcs_for_arch_wrapper("neondot") + prod_srcs_for_arch_wrapper("neondot_aarch64"), + ), + ], + "ovr_config//build_mode:arvr_mode": [], + }), fbandroid_link_whole = True, preferred_linkage = "static", preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, @@ -1725,8 +1868,11 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_neon_dot_fp16arith", srcs = select({ "DEFAULT": [], - "ovr_config//cpu:arm32": prod_srcs_for_arch_wrapper("neondotfp16arith"), - }) if is_arvr_mode() else [], + "ovr_config//build_mode:arvr_mode": select({ + "DEFAULT": [], + "ovr_config//cpu:arm32": prod_srcs_for_arch_wrapper("neondotfp16arith"), + }), + }), headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -1750,12 +1896,15 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = [ - ( - "(aarch32|arm32|armv7)", - prod_srcs_for_arch_wrapper("neondotfp16arith"), - ), - ] if not is_arvr_mode() else [], + platform_srcs = select({ + "DEFAULT": [ + ( + "(aarch32|arm32|armv7)", + prod_srcs_for_arch_wrapper("neondotfp16arith"), + ), + ], + "ovr_config//build_mode:arvr_mode": [], + }), labels = labels, fbandroid_link_whole = True, preferred_linkage = "static", @@ -1770,8 +1919,11 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_neon_dot_fp16arith_aarch64", srcs = select({ "DEFAULT": [], - "ovr_config//cpu:arm64": prod_srcs_for_arch_wrapper("neondotfp16arith") + prod_srcs_for_arch_wrapper("neondotfp16arith_aarch64"), - }) if is_arvr_mode() else [], + "ovr_config//build_mode:arvr_mode": select({ + "DEFAULT": [], + "ovr_config//cpu:arm64": prod_srcs_for_arch_wrapper("neondotfp16arith") + prod_srcs_for_arch_wrapper("neondotfp16arith_aarch64"), + }), + }), headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -1791,12 +1943,15 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = [ - ( - "(aarch64|arm64)", - prod_srcs_for_arch_wrapper("neondotfp16arith") + prod_srcs_for_arch_wrapper("neondotfp16arith_aarch64"), - ), - ] if not is_arvr_mode() else [], + platform_srcs = select({ + "DEFAULT": [ + ( + "(aarch64|arm64)", + prod_srcs_for_arch_wrapper("neondotfp16arith") + prod_srcs_for_arch_wrapper("neondotfp16arith_aarch64"), + ), + ], + "ovr_config//build_mode:arvr_mode": [], + }), labels = labels, fbandroid_link_whole = True, preferred_linkage = "static", @@ -1811,8 +1966,11 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_neon_fp16arith", srcs = select({ "DEFAULT": [], - "ovr_config//cpu:arm32": prod_srcs_for_arch_wrapper("neonfp16arith"), - }) if is_arvr_mode() else [], + "ovr_config//build_mode:arvr_mode": select({ + "DEFAULT": [], + "ovr_config//cpu:arm32": prod_srcs_for_arch_wrapper("neonfp16arith"), + }), + }), headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -1837,12 +1995,15 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = [ - ( - "(aarch32|arm32|armv7)", - prod_srcs_for_arch_wrapper("neonfp16arith"), - ), - ] if not is_arvr_mode() else [], + platform_srcs = select({ + "DEFAULT": [ + ( + "(aarch32|arm32|armv7)", + prod_srcs_for_arch_wrapper("neonfp16arith"), + ), + ], + "ovr_config//build_mode:arvr_mode": [], + }), fbandroid_link_whole = True, preferred_linkage = "static", preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, @@ -1856,8 +2017,11 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_neon_fp16arith_aarch64", srcs = select({ "DEFAULT": [], - "ovr_config//cpu:arm64": prod_srcs_for_arch_wrapper("neonfp16arith") + prod_srcs_for_arch_wrapper("neonfp16arith_aarch64"), - }) if is_arvr_mode() else [], + "ovr_config//build_mode:arvr_mode": select({ + "DEFAULT": [], + "ovr_config//cpu:arm64": prod_srcs_for_arch_wrapper("neonfp16arith") + prod_srcs_for_arch_wrapper("neonfp16arith_aarch64"), + }), + }), headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -1876,12 +2040,15 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = [ - ( - "(aarch64|arm64)", - prod_srcs_for_arch_wrapper("neonfp16arith") + prod_srcs_for_arch_wrapper("neonfp16arith_aarch64"), - ), - ] if not is_arvr_mode() else [], + platform_srcs = select({ + "DEFAULT": [ + ( + "(aarch64|arm64)", + prod_srcs_for_arch_wrapper("neonfp16arith") + prod_srcs_for_arch_wrapper("neonfp16arith_aarch64"), + ), + ], + "ovr_config//build_mode:arvr_mode": [], + }), fbandroid_link_whole = True, preferred_linkage = "static", preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, @@ -1895,9 +2062,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_neonfma_i8mm", srcs = select({ "DEFAULT": [], - "ovr_config//cpu:arm32": prod_srcs_for_arch_wrapper("neonfma_i8mm"), - "ovr_config//cpu:arm64": prod_srcs_for_arch_wrapper("neonfma_i8mm"), - }) if is_arvr_mode() else [], + "ovr_config//build_mode:arvr_mode": select({ + "DEFAULT": [], + "ovr_config//cpu:arm32": prod_srcs_for_arch_wrapper("neonfma_i8mm"), + "ovr_config//cpu:arm64": prod_srcs_for_arch_wrapper("neonfma_i8mm"), + }), + }), headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -1931,16 +2101,19 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = [ - ( - "(aarch32|arm32|armv7)$", - prod_srcs_for_arch_wrapper("neonfma_i8mm"), - ), - ( - "(arm64|aarch64)", - prod_srcs_for_arch_wrapper("neonfma_i8mm"), - ), - ] if not is_arvr_mode() else [], + platform_srcs = select({ + "DEFAULT": [ + ( + "(aarch32|arm32|armv7)$", + prod_srcs_for_arch_wrapper("neonfma_i8mm"), + ), + ( + "(arm64|aarch64)", + prod_srcs_for_arch_wrapper("neonfma_i8mm"), + ), + ], + "ovr_config//build_mode:arvr_mode": [], + }), platforms = (APPLE, ANDROID, CXX, WINDOWS), fbandroid_link_whole = True, preferred_linkage = "static", @@ -1955,8 +2128,11 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_neoni8mm", srcs = select({ "DEFAULT": [], - "ovr_config//cpu:arm64": prod_srcs_for_arch_wrapper("neoni8mm"), - }) if is_arvr_mode() else [], + "ovr_config//build_mode:arvr_mode": select({ + "DEFAULT": [], + "ovr_config//cpu:arm64": prod_srcs_for_arch_wrapper("neoni8mm"), + }), + }), headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -1977,12 +2153,15 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = [ - ( - "(arm64|aarch64)", - prod_srcs_for_arch_wrapper("neoni8mm"), - ), - ] if not is_arvr_mode() else [], + platform_srcs = select({ + "DEFAULT": [ + ( + "(arm64|aarch64)", + prod_srcs_for_arch_wrapper("neoni8mm"), + ), + ], + "ovr_config//build_mode:arvr_mode": [], + }), platforms = (APPLE, ANDROID, CXX, WINDOWS), fbandroid_link_whole = True, preferred_linkage = "static", @@ -1997,8 +2176,11 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_asm_aarch32", srcs = select({ "DEFAULT": [], - "ovr_config//cpu:arm32": prod_srcs_for_arch_wrapper("aarch32"), - }) if is_arvr_mode() else [], + "ovr_config//build_mode:arvr_mode": select({ + "DEFAULT": [], + "ovr_config//cpu:arm32": prod_srcs_for_arch_wrapper("aarch32"), + }), + }), headers = subdir_glob([ ("XNNPACK/src", "xnnpack/assembly.h"), ("XNNPACK/src", "**/*.S"), @@ -2026,12 +2208,15 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = [ - ( - "(aarch32|arm32|armv7)", - prod_srcs_for_arch_wrapper("aarch32"), - ), - ] if not is_arvr_mode() else [], + platform_srcs = select({ + "DEFAULT": [ + ( + "(aarch32|arm32|armv7)", + prod_srcs_for_arch_wrapper("aarch32"), + ), + ], + "ovr_config//build_mode:arvr_mode": [], + }), platforms = (APPLE, ANDROID, CXX, WINDOWS), fbandroid_link_whole = True, preferred_linkage = "static", @@ -2046,8 +2231,11 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_asm_aarch64", srcs = select({ "DEFAULT": [], - "ovr_config//cpu:arm64": prod_srcs_for_arch_wrapper("aarch64"), - }) if is_arvr_mode() else [], + "ovr_config//build_mode:arvr_mode": select({ + "DEFAULT": [], + "ovr_config//cpu:arm64": prod_srcs_for_arch_wrapper("aarch64"), + }), + }), headers = subdir_glob([ ("XNNPACK/src", "xnnpack/assembly.h"), ("XNNPACK/src", "**/*.S"), @@ -2071,12 +2259,15 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = [ - ( - "(aarch64|arm64)", - prod_srcs_for_arch_wrapper("aarch64"), - ), - ] if not is_arvr_mode() else [], + platform_srcs = select({ + "DEFAULT": [ + ( + "(aarch64|arm64)", + prod_srcs_for_arch_wrapper("aarch64"), + ), + ], + "ovr_config//build_mode:arvr_mode": [], + }), fbandroid_link_whole = True, preferred_linkage = "static", preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, From 2bb684304d26804ab87103ada05b6ba63e309b59 Mon Sep 17 00:00:00 2001 From: Svetlana Karslioglu Date: Mon, 21 Jul 2025 22:51:03 +0000 Subject: [PATCH 355/457] Fix the typos in the right nav by pulling the latest theme (#158746) This will fix broken links in the right nav. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158746 Approved by: https://github.com/malfet --- .ci/docker/requirements-docs.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.ci/docker/requirements-docs.txt b/.ci/docker/requirements-docs.txt index 1dd883a55a41b..8ff9f07c84a8d 100644 --- a/.ci/docker/requirements-docs.txt +++ b/.ci/docker/requirements-docs.txt @@ -4,8 +4,8 @@ sphinx==5.3.0 -e git+https://github.com/pytorch/pytorch_sphinx_theme.git@pytorch_sphinx_theme2#egg=pytorch_sphinx_theme2 # TODO: sphinxcontrib.katex 0.9.0 adds a local KaTeX server to speed up pre-rendering -# but it doesn't seem to work and hangs around idly. The initial thought it is probably -# something related to Docker setup. We can investigate this later +# but it doesn't seem to work and hangs around idly. The initial thought that it is probably +# something related to Docker setup. We can investigate this later. sphinxcontrib.katex==0.8.6 #Description: This is used to generate PyTorch docs From 1227ed6674100f6efb3f7b0e359c51383397c354 Mon Sep 17 00:00:00 2001 From: Pian Pawakapan Date: Mon, 21 Jul 2025 23:14:19 +0000 Subject: [PATCH 356/457] [dynamic shapes] fix _maybe_evaluate_static axioms bug (#158672) Summary: couldn't get a minimal repro, but xref for size change during dict iteration error: https://fb.workplace.com/groups/1075192433118967/posts/1709439696360901 Test Plan: - Rollback Plan: Differential Revision: D78047846 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158672 Approved by: https://github.com/bobrenjc93 --- torch/fx/experimental/symbolic_shapes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index e38e5f777d669..e2e91624db95b 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -6263,7 +6263,7 @@ def resimplify_floor_div(axioms: dict[sympy.Expr, sympy.Expr]) -> None: return self._resimplify_floor_div_axioms = False new_items = {} - for k, v in axioms.items(): + for k, v in list(axioms.items()): # A FloorDiv in implications could have became CleanDiv at this point, due to new facts # to the shapeEnv. This handles such issue but its not ideal. This is the only expression # simplification that depends on the global state of shape env. From 15a50dcf1c9492354819179da4bc994014537ab9 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 21 Jul 2025 23:14:55 +0000 Subject: [PATCH 357/457] Revert "[BE] Make PyObjectSlot use a global PyInterpreter and remove (#158427)" This reverts commit eb7365072315be2bc4259114e25e269801441748. Reverted https://github.com/pytorch/pytorch/pull/158427 on behalf of https://github.com/ZainRizvi due to Reverting this as part of reverting the stack for https://github.com/pytorch/pytorch/pull/158288 ([comment](https://github.com/pytorch/pytorch/pull/158427#issuecomment-3099815367)) --- build_variables.bzl | 1 - c10/core/impl/PyInterpreter.h | 20 + c10/core/impl/PyInterpreterHooks.cpp | 32 - c10/core/impl/PyInterpreterHooks.h | 39 - c10/core/impl/PyObjectSlot.cpp | 5 + c10/core/impl/PyObjectSlot.h | 20 +- functorch/csrc/dim/dim.cpp | 6136 +++++++++++------------ torch/_dynamo/trace_rules.py | 1 + torch/csrc/Module.cpp | 10 +- torch/csrc/PyInterpreter.cpp | 6 +- torch/csrc/PyInterpreter.h | 2 +- torch/csrc/PyInterpreterHooks.cpp | 20 - torch/csrc/PyInterpreterHooks.h | 15 - torch/csrc/Storage.cpp | 47 +- torch/csrc/Storage.h | 1 + torch/csrc/StorageMethods.cpp | 5 +- torch/csrc/StorageSharing.cpp | 18 +- torch/csrc/autograd/python_variable.cpp | 62 +- torch/csrc/utils/python_dispatch.cpp | 18 +- 19 files changed, 3030 insertions(+), 3428 deletions(-) delete mode 100644 c10/core/impl/PyInterpreterHooks.cpp delete mode 100644 c10/core/impl/PyInterpreterHooks.h delete mode 100644 torch/csrc/PyInterpreterHooks.cpp delete mode 100644 torch/csrc/PyInterpreterHooks.h diff --git a/build_variables.bzl b/build_variables.bzl index d633a29c5b634..f6fba33dc4d4b 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -864,7 +864,6 @@ libtorch_python_core_sources = [ "torch/csrc/QScheme.cpp", "torch/csrc/Module.cpp", "torch/csrc/PyInterpreter.cpp", - "torch/csrc/PyInterpreterHooks.cpp", "torch/csrc/python_dimname.cpp", "torch/csrc/Size.cpp", "torch/csrc/Storage.cpp", diff --git a/c10/core/impl/PyInterpreter.h b/c10/core/impl/PyInterpreter.h index 09d4801f7d83d..43492443c530c 100644 --- a/c10/core/impl/PyInterpreter.h +++ b/c10/core/impl/PyInterpreter.h @@ -240,4 +240,24 @@ struct C10_API PyInterpreter { void disarm() noexcept; }; +// PyInterpreterStatus describes what the state of its interpreter tag +// is, relative to the thread currently holding the GIL. +enum class PyInterpreterStatus { + // We just allocated the Tensor, it hasn't escaped to other threads, + // we know that it definitely hasn't been tagged to be associated + // with an interpreter. + DEFINITELY_UNINITIALIZED, + // We queried the interpreter field and it looked uninitialized. But + // another thread may have raced with us to tag it with some other + // interpreter id. So we will have to do a CEX to make sure we can + // actually nab it. + MAYBE_UNINITIALIZED, + // We queried the interpreter field and it was tagged to belong to us. + // This means we have sole write access (as we hold the GIL for this + // interpreter) + TAGGED_BY_US, + // Someone else tagged this. We can't use this TensorImpl from Python. + TAGGED_BY_OTHER, +}; + } // namespace c10::impl diff --git a/c10/core/impl/PyInterpreterHooks.cpp b/c10/core/impl/PyInterpreterHooks.cpp deleted file mode 100644 index bd5325cf49c20..0000000000000 --- a/c10/core/impl/PyInterpreterHooks.cpp +++ /dev/null @@ -1,32 +0,0 @@ -#include - -namespace c10::impl { - -// Define the registry -C10_DEFINE_REGISTRY( - PyInterpreterHooksRegistry, - PyInterpreterHooksInterface, - PyInterpreterHooksArgs) - -const PyInterpreterHooksInterface& getPyInterpreterHooks() { - auto create_impl = [] { -#if !defined C10_MOBILE - auto hooks = PyInterpreterHooksRegistry()->Create( - "PyInterpreterHooks", PyInterpreterHooksArgs{}); - if (hooks) { - return hooks; - } -#endif - // Return stub implementation that will throw errors when methods are called - return std::make_unique(); - }; - static auto hooks = create_impl(); - return *hooks; -} - -// Main function to get global PyInterpreter -PyInterpreter* getGlobalPyInterpreter() { - return getPyInterpreterHooks().getPyInterpreter(); -} - -} // namespace c10::impl diff --git a/c10/core/impl/PyInterpreterHooks.h b/c10/core/impl/PyInterpreterHooks.h deleted file mode 100644 index 32a17ad9a8a0c..0000000000000 --- a/c10/core/impl/PyInterpreterHooks.h +++ /dev/null @@ -1,39 +0,0 @@ -#pragma once - -#include -#include -#include -#include - -namespace c10::impl { - -// Minimal interface for PyInterpreter hooks -struct C10_API PyInterpreterHooksInterface { - virtual ~PyInterpreterHooksInterface() = default; - - // Get the PyInterpreter instance - // Stub implementation throws error when Python is not available - virtual PyInterpreter* getPyInterpreter() const { - TORCH_CHECK( - false, - "PyTorch was compiled without Python support. " - "Cannot access Python interpreter from C++."); - } -}; - -struct C10_API PyInterpreterHooksArgs{}; - -C10_DECLARE_REGISTRY( - PyInterpreterHooksRegistry, - PyInterpreterHooksInterface, - PyInterpreterHooksArgs); - -#define REGISTER_PYTHON_HOOKS(clsname) \ - C10_REGISTER_CLASS(PyInterpreterHooksRegistry, clsname, clsname) - -// Get the global PyInterpreter hooks instance -C10_API const PyInterpreterHooksInterface& getPyInterpreterHooks(); - -C10_API PyInterpreter* getGlobalPyInterpreter(); - -} // namespace c10::impl diff --git a/c10/core/impl/PyObjectSlot.cpp b/c10/core/impl/PyObjectSlot.cpp index 0f1bfb2110747..62af2eae8e37a 100644 --- a/c10/core/impl/PyObjectSlot.cpp +++ b/c10/core/impl/PyObjectSlot.cpp @@ -34,6 +34,11 @@ PyObject* PyObjectSlot::_unchecked_untagged_pyobj() const { reinterpret_cast(pyobj_) & ~0x1ULL); } +void PyObjectSlot::unchecked_clear_pyobj(PyInterpreter* interpreter) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(interpreter == pyobj_interpreter_.load()); + pyobj_ = nullptr; +} + PyInterpreter& PyObjectSlot::load_pyobj_interpreter() const { auto interpreter = pyobj_interpreter_.load(std::memory_order_acquire); if (interpreter) { diff --git a/c10/core/impl/PyObjectSlot.h b/c10/core/impl/PyObjectSlot.h index 58b2490eba001..af8b9fa4d0ec7 100644 --- a/c10/core/impl/PyObjectSlot.h +++ b/c10/core/impl/PyObjectSlot.h @@ -2,7 +2,6 @@ #include #include -#include #include #include @@ -25,9 +24,11 @@ struct C10_API PyObjectSlot { // // NB: THIS FUNCTION CAN RAISE AN EXCEPTION. Make sure to clean up after // PyObject if necessary! - void init_pyobj(PyObject* pyobj) { - pyobj_interpreter_.store( - getGlobalPyInterpreter(), std::memory_order_relaxed); + void init_pyobj( + PyInterpreter* self_interpreter, + PyObject* pyobj, + PyInterpreterStatus status) { + pyobj_interpreter_.store(self_interpreter, std::memory_order_relaxed); pyobj_ = pyobj; } @@ -52,10 +53,9 @@ struct C10_API PyObjectSlot { // // NB: this lives in header so that we can avoid actually creating the // std::optional - - // @todo alban: I'm not too sure what's going on here, we can probably delete - // it but it's worthwhile making sure - std::optional check_pyobj(bool ignore_hermetic_tls = false) const { + std::optional check_pyobj( + PyInterpreter* self_interpreter, + bool ignore_hermetic_tls = false) const { impl::PyInterpreter* interpreter = pyobj_interpreter_.load(std::memory_order_acquire); if (interpreter == nullptr) { @@ -69,6 +69,10 @@ struct C10_API PyObjectSlot { } } + // Clear the PyObject field for an interpreter, in situations where we + // statically know the tensor is tagged with our interpreter. + void unchecked_clear_pyobj(PyInterpreter* interpreter); + PyInterpreter& load_pyobj_interpreter() const; bool owns_pyobj(); diff --git a/functorch/csrc/dim/dim.cpp b/functorch/csrc/dim/dim.cpp index 8f1e561e2051b..19270d2f9225d 100644 --- a/functorch/csrc/dim/dim.cpp +++ b/functorch/csrc/dim/dim.cpp @@ -6,6 +6,7 @@ #include + // Many APIs have changed/don't exist anymore #if IS_PYTHON_3_12_PLUS @@ -13,25 +14,24 @@ // Re-enable this some day PyObject* Dim_init() { - PyErr_SetString( - PyExc_RuntimeError, "First class dim doesn't work with python 3.12"); - return nullptr; + PyErr_SetString(PyExc_RuntimeError, "First class dim doesn't work with python 3.12"); + return nullptr; } #else +#include "minpybind.h" #include #include -#include -#include #include +#include +#include #include -#include "minpybind.h" -// #include -#include +//#include +#include #include #include -#include +#include #include #include "arena.h" #include "dim.h" @@ -71,3498 +71,3115 @@ PyTypeObject* DimType = nullptr; PyObject* Tensor_getitem(PyObject* self, PyObject* index); int Tensor_setitem(PyObject* self, PyObject* index, PyObject* value); -namespace { +namespace{ void maybeInitializeGlobals() { - // globals that depend on the python dim library, - // which we can't lookup until we finish initializing the _C module - if (_Tensor.ptr()) { - return; - } - auto dim = mpy::import("functorch.dim"); - _Tensor = dim.attr("_Tensor"); - pointwise = dim.attr("pointwise"); - _Tensor_sum = _Tensor.attr("sum"); - DimType = (PyTypeObject*)mpy::import("functorch.dim").attr("Dim").ptr(); + // globals that depend on the python dim library, + // which we can't lookup until we finish initializing the _C module + if (_Tensor.ptr()) { + return; + } + auto dim = mpy::import("functorch.dim"); + _Tensor = dim.attr("_Tensor"); + pointwise = dim.attr("pointwise"); + _Tensor_sum = _Tensor.attr("sum"); + DimType = (PyTypeObject*) mpy::import("functorch.dim").attr("Dim").ptr(); } void replaceMappingIfMatches(mpy::handle tp) { - auto T = (PyTypeObject*)tp.ptr(); - bool recurse = false; - if (T->tp_as_mapping->mp_subscript == THPVariable_getitem) { - T->tp_as_mapping->mp_subscript = Tensor_getitem; - recurse = true; - } - if (T->tp_as_mapping->mp_ass_subscript == THPVariable_setitem) { - T->tp_as_mapping->mp_ass_subscript = Tensor_setitem; - recurse = true; - } - if (recurse) { - auto result = tp.attr("__subclasses__").call(); - mpy::list_view lv(result); - for (auto i : lv.enumerate()) { - replaceMappingIfMatches(lv[i]); - } - } -} - -void initializeGlobals(Arena& A) { - auto torch = mpy::import("torch"); - torch_Tensor = (PyTypeObject*)torch.attr("Tensor").ptr(); - torch_Tensor___mul__ = torch.attr("Tensor").attr("__mul__"); - - torch_Tensor_expand = torch.attr("_C").attr("TensorBase").attr("expand"); - torch_Tensor_split = torch.attr("_C").attr("TensorBase").attr("split"); - torch_Tensor_copy_ = torch.attr("Tensor").attr("copy_"); - auto py_TensorBase = torch.attr("_C").attr("TensorBase"); - auto TensorBase = (PyTypeObject*)py_TensorBase.ptr(); - THPVariable_getitem = TensorBase->tp_as_mapping->mp_subscript; - THPVariable_setitem = TensorBase->tp_as_mapping->mp_ass_subscript; - NamedTuple = mpy::import("typing").attr("NamedTuple"); - no_slice = PySlice_New(NULL, NULL, NULL); + auto T = (PyTypeObject*) tp.ptr(); + bool recurse = false; + if (T->tp_as_mapping->mp_subscript == THPVariable_getitem) { + T->tp_as_mapping->mp_subscript = Tensor_getitem; + recurse = true; + } + if (T->tp_as_mapping->mp_ass_subscript == THPVariable_setitem) { + T->tp_as_mapping->mp_ass_subscript = Tensor_setitem; + recurse = true; + } + if (recurse) { + auto result = tp.attr("__subclasses__").call(); + mpy::list_view lv(result); + for (auto i : lv.enumerate()) { + replaceMappingIfMatches(lv[i]); + } + } +} + +void initializeGlobals(Arena & A) { + auto torch = mpy::import("torch"); + torch_Tensor = (PyTypeObject*) torch.attr("Tensor").ptr(); + torch_Tensor___mul__ = torch.attr("Tensor").attr("__mul__"); + + torch_Tensor_expand = torch.attr("_C").attr("TensorBase").attr("expand"); + torch_Tensor_split = torch.attr("_C").attr("TensorBase").attr("split"); + torch_Tensor_copy_ = torch.attr("Tensor").attr("copy_"); + auto py_TensorBase = torch.attr("_C").attr("TensorBase"); + auto TensorBase = (PyTypeObject*) py_TensorBase.ptr(); + THPVariable_getitem = TensorBase->tp_as_mapping->mp_subscript; + THPVariable_setitem = TensorBase->tp_as_mapping->mp_ass_subscript; + NamedTuple = mpy::import("typing").attr("NamedTuple"); + no_slice = PySlice_New(NULL, NULL, NULL); + } mpy::handle DimensionBindError_; mpy::handle DimensionBindError() { - if (!DimensionBindError_.ptr()) { - DimensionBindError_ = - mpy::import("functorch.dim").attr("DimensionBindError"); - } - return DimensionBindError_; + if(!DimensionBindError_.ptr()) { + DimensionBindError_ = mpy::import("functorch.dim").attr("DimensionBindError"); + } + return DimensionBindError_; } static int64_t n_dims_created = 65; struct Dim : public mpy::base { - int64_t level_; // for stable comparisons in prototype - mpy::object name_; - Dim() : level_(n_dims_created++) {} - void init(mpy::object name, int64_t s = -1) { - name_ = std::move(name); - size_ = s; - } - - static bool check_exact(mpy::handle v) { - return Py_TYPE(v.ptr()) == DimType; - } - - int64_t size() const { - if (size_ == -1) { - mpy::raise_error( - PyExc_ValueError, "dimension %S is unbound", name_.ptr()); - } - return size_; - } - void set_size(int64_t v) { - if (size_ == -1) { - size_ = v; - } else if (size_ != v) { - mpy::raise_error( - DimensionBindError(), - "Dim '%R' previously bound to a dimension of size %lld cannot bind to a dimension of size %lld", - this, - this->size_, - v); - } - } - bool is_bound() const { - return size_ != -1; - } - static mpy::obj create(mpy::object name, int64_t s = -1) { - if (!DimType) { - maybeInitializeGlobals(); - } - auto r = Dim::alloc(DimType); - r->init(std::move(name), s); - return r; - } - static PyTypeObject Type; - const at::Tensor& range() { - if (!range_.defined()) { - range_ = at::arange(size()); - } - return range_; - } - const at::Tensor& batchtensor() { - if (!batchtensor_.defined()) { - batchtensor_ = at::functorch::addBatchDim(range(), 0, level_); - } - return batchtensor_; - } - - private: - int64_t size_{-1}; - at::Tensor range_; - at::Tensor batchtensor_; + int64_t level_; // for stable comparisons in prototype + mpy::object name_; + Dim() + : level_(n_dims_created++) {} + void init(mpy::object name, int64_t s = -1) { + name_ = std::move(name); + size_ = s; + } + + static bool check_exact(mpy::handle v) { + return Py_TYPE(v.ptr()) == DimType; + } + + int64_t size() const { + if (size_ == -1) { + mpy::raise_error(PyExc_ValueError, "dimension %S is unbound", name_.ptr()); + } + return size_; + } + void set_size(int64_t v) { + if (size_ == -1) { + size_ = v; + } else if(size_ != v) { + mpy::raise_error(DimensionBindError(), "Dim '%R' previously bound to a dimension of size %lld cannot bind to a dimension of size %lld", this, this->size_, v); + } + } + bool is_bound() const { + return size_ != -1; + } + static mpy::obj create(mpy::object name, int64_t s = -1) { + if (!DimType) { + maybeInitializeGlobals(); + } + auto r = Dim::alloc(DimType); + r->init(std::move(name), s); + return r; + } + static PyTypeObject Type; + const at::Tensor& range() { + if (!range_.defined()) { + range_ = at::arange(size()); + } + return range_; + } + const at::Tensor& batchtensor() { + if (!batchtensor_.defined()) { + batchtensor_ = at::functorch::addBatchDim(range(), 0, level_); + } + return batchtensor_; + } +private: + int64_t size_{-1}; + at::Tensor range_; + at::Tensor batchtensor_; }; + struct DimEntry { - // union of either a negative number indicating which dimension this is from - // the rhs, or a pointer to a first-class dimension. pointers do not have - // their highest bit set, so checking the number is negative tells us that it - // is not a dim. - bool is_positional() const { - return data_ < 0; - } - bool is_none() const { - return data_ == 0; - } - int64_t position() const { - return data_; - } - mpy::hdl dim() const { - Dim* result; - std::memcpy(&result, &data_, sizeof(Dim*)); - return mpy::hdl(result); - } - - DimEntry() : data_(0) {} - - DimEntry(int64_t pos) : data_(pos) { - AT_ASSERT(pos < 0); - } - DimEntry(mpy::hdl d) { - std::memcpy(&data_, &d, sizeof(int64_t)); - } - bool operator==(const DimEntry& rhs) const { - return data_ == rhs.data_; - } - - private: - int64_t data_; + // union of either a negative number indicating which dimension this is from the rhs, + // or a pointer to a first-class dimension. + // pointers do not have their highest bit set, so checking the number is negative tells us + // that it is not a dim. + bool is_positional() const { + return data_ < 0; + } + bool is_none() const { + return data_ == 0; + } + int64_t position() const { + return data_; + } + mpy::hdl dim() const { + Dim* result; + std::memcpy(&result, &data_, sizeof(Dim*)); + return mpy::hdl(result); + } + + DimEntry() + : data_(0) {} + + DimEntry(int64_t pos) + : data_(pos) { + AT_ASSERT(pos < 0); + } + DimEntry(mpy::hdl d) { + std::memcpy(&data_, &d, sizeof(int64_t)); + } + bool operator==(const DimEntry& rhs) const { + return data_ == rhs.data_; + } +private: + int64_t data_; }; // Dim wrapper methods DimEntry _wrap_dim(mpy::handle d, size_t N, bool keepdim) { - if (Dim::check(d)) { - if (keepdim) { - mpy::raise_error( - PyExc_ValueError, - "cannot preserve first-class dimensions with keepdim=True"); - } - return Dim::unchecked_wrap(d); - } else if (mpy::is_int(d)) { - auto i = mpy::to_int(d); - while (i >= 0) { - i -= N; - } - return i; - } else { - return DimEntry(); - } -} - -int Dim_init(mpy::hdl self, PyObject* args, PyObject* kwds) { - PY_BEGIN - static constexpr const char* kwlist[] = {"name", "size", nullptr}; - mpy::handle name; - mpy::handle size = nullptr; - if (!PyArg_ParseTupleAndKeywords( - args, kwds, "O|O", const_cast(kwlist), &name, &size)) { - return -1; - } - self->init( - mpy::object::borrow(name), - (size.ptr() && !mpy::is_none(size)) ? mpy::to_int(size) : -1); - return 0; - PY_END(-1) + if (Dim::check(d)) { + if (keepdim) { + mpy::raise_error(PyExc_ValueError, "cannot preserve first-class dimensions with keepdim=True"); + } + return Dim::unchecked_wrap(d); + } else if (mpy::is_int(d)) { + auto i = mpy::to_int(d); + while (i >= 0) { + i -= N; + } + return i; + } else { + return DimEntry(); + } +} + + +int Dim_init(mpy::hdl self, PyObject *args, PyObject *kwds) { + PY_BEGIN + static constexpr const char* kwlist[] = {"name", "size", nullptr}; + mpy::handle name; + mpy::handle size = nullptr; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O", const_cast(kwlist), &name, &size)) { + return -1; + } + self->init(mpy::object::borrow(name), (size.ptr() && !mpy::is_none(size)) ? mpy::to_int(size) : -1); + return 0; + PY_END(-1) } PyObject* Dim_repr(Dim* self) { - PY_BEGIN - mpy::object name = (self->name_.ptr()) - ? self->name_ - : mpy::unicode_from_string(""); - return name.release(); - PY_END(nullptr) + PY_BEGIN + mpy::object name = (self->name_.ptr()) ? self->name_ : mpy::unicode_from_string(""); + return name.release(); + PY_END(nullptr) } + PyObject* Dim_getsize(Dim* self, void*) { - PY_BEGIN - return mpy::from_int(self->size()).release(); - PY_END(nullptr) + PY_BEGIN + return mpy::from_int(self->size()).release(); + PY_END(nullptr) } int Dim_setsize(Dim* self, PyObject* size, void*) { - PY_BEGIN - self->set_size(mpy::to_int(size)); - return 0; - PY_END(-1) + PY_BEGIN + self->set_size(mpy::to_int(size)); + return 0; + PY_END(-1) } PyObject* Dim_getis_bound(Dim* self, void*) { - return PyBool_FromLong(self->is_bound()); + return PyBool_FromLong(self->is_bound()); } PyObject* Dim_getlevel(Dim* self, void*) { - return PyLong_FromLong(self->level_); + return PyLong_FromLong(self->level_); } PyObject* Dim_get_levels(Dim* self, void*) { - mpy::tuple t(1); - t.set(0, mpy::object::borrow(self->ptr())); - return t.release(); + mpy::tuple t(1); + t.set(0, mpy::object::borrow(self->ptr())); + return t.release(); } PyObject* Dim_get_has_device(Dim* self, void*) { - Py_RETURN_FALSE; + Py_RETURN_FALSE; } PyObject* Dim_get_tensor(Dim* self, void*) { - return THPVariable_Wrap(self->range()); + return THPVariable_Wrap(self->range()); } PyObject* Dim_get_batchtensor(Dim* self, void*) { - return THPVariable_Wrap(self->batchtensor()); + return THPVariable_Wrap(self->batchtensor()); } + PyGetSetDef Dim_getsetters[] = { - {"size", (getter)Dim_getsize, (setter)Dim_setsize, "Dimension size", NULL}, - {"is_bound", (getter)Dim_getis_bound, NULL, "is_bound", NULL}, - {"_level", (getter)Dim_getlevel, NULL, "_level", NULL}, - {"_levels", (getter)Dim_get_levels, NULL, "_levels", NULL}, - {"_has_device", (getter)Dim_get_has_device, NULL, "_has_device", NULL}, - {"_tensor", (getter)Dim_get_tensor, NULL, "_tensor", NULL}, - {"_batchtensor", (getter)Dim_get_batchtensor, NULL, "_batchtensor", NULL}, - {"ndim", - (getter)[](PyObject* self, void*) - ->PyObject* {return mpy::from_int(1).release(); -} // namespace -, NULL, "ndim", NULL -} -, { - NULL -} /* Sentinel */ -} -; + {"size", (getter) Dim_getsize, (setter) Dim_setsize, + "Dimension size", NULL}, + {"is_bound", (getter) Dim_getis_bound, NULL, "is_bound", NULL}, + {"_level", (getter) Dim_getlevel, NULL, "_level", NULL}, + {"_levels", (getter) Dim_get_levels, NULL, "_levels", NULL}, + {"_has_device", (getter) Dim_get_has_device, NULL, "_has_device", NULL}, + {"_tensor", (getter) Dim_get_tensor, NULL, "_tensor", NULL}, + {"_batchtensor", (getter) Dim_get_batchtensor, NULL, "_batchtensor", NULL}, + {"ndim", (getter) [](PyObject* self, void*) -> PyObject* { return mpy::from_int(1).release(); }, NULL, "ndim", NULL}, + {NULL} /* Sentinel */ +}; } PyTypeObject Dim::Type = { PyVarObject_HEAD_INIT(NULL, 0) - "_C.Dim", /* tp_name */ - sizeof(Dim), /* tp_basicsize */ - 0, /* tp_itemsize */ - Dim::dealloc_stub, /* tp_dealloc */ - 0, /* tp_vectorcall_offset */ - 0, /* tp_getattr */ - 0, /* tp_setattr */ - 0, /* tp_as_async */ - (reprfunc)Dim_repr, /* tp_repr */ - 0, /* tp_as_number */ - 0, /* tp_as_sequence */ - 0, /* tp_as_mapping */ - 0, /* tp_hash */ - 0, /* tp_call */ - 0, /* tp_str */ - 0, /* tp_getattro */ - 0, /* tp_setattro */ - 0, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ - "Dim Object", /* tp_doc */ - 0, /* tp_traverse */ - 0, /* tp_clear */ - 0, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - 0, /* tp_iter */ - 0, /* tp_iternext */ - 0, /* tp_methods */ - 0, /* tp_members */ - Dim_getsetters, /* tp_getset */ - 0, /* tp_base */ - 0, /* tp_dict */ - 0, /* tp_descr_get */ - 0, /* tp_descr_set */ - 0, /* tp_dictoffset */ - (initproc)(void*)static_cast, PyObject*, PyObject*)>( - Dim_init), /* tp_init */ - 0, /* tp_alloc */ - Dim::new_stub, /* tp_new */ + "_C.Dim", /* tp_name */ + sizeof(Dim), /* tp_basicsize */ + 0, /* tp_itemsize */ + Dim::dealloc_stub, /* tp_dealloc */ + 0, /* tp_vectorcall_offset */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_as_async */ + (reprfunc)Dim_repr, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ + "Dim Object", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + 0, /* tp_methods */ + 0, /* tp_members */ + Dim_getsetters, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)(void*)static_cast,PyObject*,PyObject*)>(Dim_init), /* tp_init */ + 0, /* tp_alloc */ + Dim::new_stub, /* tp_new */ }; // class DimList ------------ struct DimList : public mpy::base { - mpy::object name_; - std::vector> dims_; - static PyTypeObject Type; - void init(mpy::object name) { - name_ = std::move(name); - } - void set_dims(std::vector> dims) { - bound_ = true; - dims_ = std::move(dims); - } - bool is_bound() { - return bound_; - } - void bind_len(int64_t size) { - if (bound_) { - int64_t b_size = dims_.size(); - if (b_size != size) { - mpy::raise_error( - DimensionBindError(), - "Dimlist has size %lld but it is being bound to size %d", - b_size, - size); - } - } else { - bound_ = true; - dims_.resize(size); - for (Py_ssize_t i = 0; i < size; ++i) { - dims_[i] = - Dim::create(mpy::unicode_from_format("%S%i", name_.ptr(), (int)i)); - } - } - } - int64_t size() const { - if (!bound_) { - mpy::raise_error(DimensionBindError(), "DimList not bound"); - } - return dims_.size(); - } - void set_bound(bool b) { - bound_ = b; - } - - private: - bool bound_ = false; + mpy::object name_; + std::vector> dims_; + static PyTypeObject Type; + void init(mpy::object name) { + name_ = std::move(name); + } + void set_dims(std::vector> dims) { + bound_ = true; + dims_ = std::move(dims); + } + bool is_bound() { + return bound_; + } + void bind_len(int64_t size) { + if (bound_) { + int64_t b_size = dims_.size(); + if (b_size != size) { + mpy::raise_error(DimensionBindError(), "Dimlist has size %lld but it is being bound to size %d", b_size, size); + } + } else { + bound_ = true; + dims_.resize(size); + for (Py_ssize_t i = 0; i < size; ++i) { + dims_[i] = Dim::create(mpy::unicode_from_format("%S%i", name_.ptr(), (int)i)); + } + } + } + int64_t size() const { + if (!bound_) { + mpy::raise_error(DimensionBindError(), "DimList not bound"); + } + return dims_.size(); + } + void set_bound(bool b) { + bound_ = b; + } +private: + bool bound_ = false; }; -static int DimList_init(DimList* self, PyObject* args, PyObject* kwds); + +static int DimList_init(DimList *self, PyObject *args, PyObject *kwds); static PyObject* DimList_repr(DimList* self) { - PY_BEGIN - if (self->is_bound()) { - size_t size = self->dims_.size(); - mpy::tuple t(size); - for (size_t i = 0; i < size; ++i) { - t.set(i, self->dims_[i]); - } - return mpy::repr(t).release(); - } else if (!mpy::is_none(self->name_)) { - return mpy::unicode_from_format("*%S", self->name_.ptr()).release(); - } else { - return mpy::unicode_from_string("").release(); - } - PY_END(nullptr) -} - -static PyObject* DimList_bind( - DimList* self, - PyObject* const* args, - Py_ssize_t nargs, - PyObject* kwnames) { - PY_BEGIN - mpy::handle sizes; - static const char* const _keywords[] = {"sizes", nullptr}; - static _PyArg_Parser parser = {"O", _keywords, 0}; - if (!_PyArg_ParseStackAndKeywords(args, nargs, kwnames, &parser, &sizes)) { - return nullptr; - } - if (!mpy::is_sequence(sizes)) { - mpy::raise_error(PyExc_ValueError, "expected a sequence"); - } - mpy::sequence_view seq = sizes; - auto size = seq.size(); - self->bind_len(size); - for (Py_ssize_t i = 0; i < size; ++i) { - self->dims_[i]->set_size(mpy::to_int(seq[i])); - } - Py_RETURN_NONE; - PY_END(nullptr) -} - -static PyObject* DimList_bind_len( - DimList* self, - PyObject* const* args, - Py_ssize_t nargs, - PyObject* kwnames) { - PY_BEGIN - int size; - static const char* const _keywords[] = {"N", nullptr}; - static _PyArg_Parser parser = {"i", _keywords, 0}; - if (!_PyArg_ParseStackAndKeywords(args, nargs, kwnames, &parser, &size)) { - return nullptr; - } - self->bind_len(size); - Py_RETURN_NONE; - PY_END(nullptr) + PY_BEGIN + if (self->is_bound()) { + size_t size = self->dims_.size(); + mpy::tuple t(size); + for(size_t i = 0; i < size; ++i) { + t.set(i, self->dims_[i]); + } + return mpy::repr(t).release(); + } else if(!mpy::is_none(self->name_)) { + return mpy::unicode_from_format("*%S", self->name_.ptr()).release(); + } else { + return mpy::unicode_from_string("").release(); + } + PY_END(nullptr) +} + +static PyObject* DimList_bind(DimList *self, + PyObject *const *args, + Py_ssize_t nargs, + PyObject *kwnames) { + PY_BEGIN + mpy::handle sizes; + static const char * const _keywords[] = {"sizes", nullptr}; + static _PyArg_Parser parser = {"O", _keywords, 0}; + if (!_PyArg_ParseStackAndKeywords(args, nargs, kwnames, &parser, &sizes)) { + return nullptr; + } + if (!mpy::is_sequence(sizes)) { + mpy::raise_error(PyExc_ValueError, "expected a sequence"); + } + mpy::sequence_view seq = sizes; + auto size = seq.size(); + self->bind_len(size); + for (Py_ssize_t i = 0; i < size; ++i) { + self->dims_[i]->set_size(mpy::to_int(seq[i])); + } + Py_RETURN_NONE; + PY_END(nullptr) +} + +static PyObject* DimList_bind_len(DimList *self, + PyObject *const *args, + Py_ssize_t nargs, + PyObject *kwnames) { + PY_BEGIN + int size; + static const char * const _keywords[] = {"N", nullptr}; + static _PyArg_Parser parser = {"i", _keywords, 0}; + if (!_PyArg_ParseStackAndKeywords(args, nargs, kwnames, &parser, &size)) { + return nullptr; + } + self->bind_len(size); + Py_RETURN_NONE; + PY_END(nullptr) } static PyMethodDef DimList_methods[] = { - {"bind", (PyCFunction)(void*)DimList_bind, METH_FASTCALL | METH_KEYWORDS}, - {"bind_len", - (PyCFunction)(void*)DimList_bind_len, - METH_FASTCALL | METH_KEYWORDS}, - {NULL, NULL, 0, NULL} /* Sentinel */ + {"bind", (PyCFunction)(void*) DimList_bind, METH_FASTCALL | METH_KEYWORDS}, + {"bind_len", (PyCFunction)(void*) DimList_bind_len, METH_FASTCALL | METH_KEYWORDS}, + {NULL, NULL, 0, NULL} /* Sentinel */ }; + static Py_ssize_t DimList_len(DimList* self) { - PY_BEGIN - return self->size(); - PY_END(-1) -} - -static PyObject* DimList_item(DimList* self, Py_ssize_t idx) { - PY_BEGIN - if (!self->is_bound()) { - mpy::raise_error(DimensionBindError(), "DimList not bound"); - } - if (idx < 0 || (size_t)idx >= self->dims_.size()) { - mpy::raise_error(PyExc_IndexError, "index out of bounds"); - } - mpy::object r = self->dims_[idx]; - return r.release(); - PY_END(nullptr) -} - -PySequenceMethods DimList_seq{ - (lenfunc)DimList_len, // lenfunc sq_length; - 0, // binaryfunc sq_concat; - 0, // ssizeargfunc sq_repeat; - (ssizeargfunc)DimList_item, // ssizeargfunc sq_item; - 0, // void *was_sq_slice; - 0, // ssizeobjargproc sq_ass_item; - 0, // void *was_sq_ass_slice; - 0, // objobjproc sq_contains; - - 0, // binaryfunc sq_inplace_concat; - 0, // ssizeargfunc sq_inplace_repeat; + PY_BEGIN + return self->size(); + PY_END(-1) +} + +static PyObject * DimList_item(DimList* self, Py_ssize_t idx) { + PY_BEGIN + if (!self->is_bound()) { + mpy::raise_error(DimensionBindError(), "DimList not bound"); + } + if (idx < 0 || (size_t) idx >= self->dims_.size()) { + mpy::raise_error(PyExc_IndexError, "index out of bounds"); + } + mpy::object r = self->dims_[idx]; + return r.release(); + PY_END(nullptr) +} + +PySequenceMethods DimList_seq { + (lenfunc) DimList_len, //lenfunc sq_length; + 0, //binaryfunc sq_concat; + 0, //ssizeargfunc sq_repeat; + (ssizeargfunc) DimList_item, //ssizeargfunc sq_item; + 0, //void *was_sq_slice; + 0, //ssizeobjargproc sq_ass_item; + 0, //void *was_sq_ass_slice; + 0, //objobjproc sq_contains; + + 0, //binaryfunc sq_inplace_concat; + 0, //ssizeargfunc sq_inplace_repeat; }; static PyObject* DimList_getis_bound(DimList* self, void*) { - return PyBool_FromLong(self->is_bound()); + return PyBool_FromLong(self->is_bound()); } static PyGetSetDef DimList_getsetters[] = { - {"is_bound", (getter)DimList_getis_bound, NULL, "is_bound", NULL}, - {NULL} /* Sentinel */ + {"is_bound", (getter) DimList_getis_bound, NULL, "is_bound", NULL}, + {NULL} /* Sentinel */ }; + static PyObject* DimList_subscript(DimList* self, mpy::handle idx) { - PY_BEGIN - if (mpy::is_int(idx)) { - return DimList_item(self, mpy::to_int(idx)); - } else if (mpy::is_slice(idx)) { - if (!self->is_bound()) { - mpy::raise_error(DimensionBindError(), "DimList not bound"); - } - mpy::slice_view s(idx, self->dims_.size()); - mpy::tuple r(s.slicelength); - for (Py_ssize_t i = s.start, j = 0; i < s.stop; i += s.step) { - r.set(j++, self->dims_[i]); + PY_BEGIN + if (mpy::is_int(idx)) { + return DimList_item(self, mpy::to_int(idx)); + } else if (mpy::is_slice(idx)) { + if (!self->is_bound()) { + mpy::raise_error(DimensionBindError(), "DimList not bound"); + } + mpy::slice_view s(idx, self->dims_.size()); + mpy::tuple r(s.slicelength); + for (Py_ssize_t i = s.start, j = 0; i < s.stop; i += s.step) { + r.set(j++, self->dims_[i]); + } + return r.release(); + } else { + mpy::raise_error(PyExc_ValueError, "expected an int or a slice"); + return nullptr; } - return r.release(); - } else { - mpy::raise_error(PyExc_ValueError, "expected an int or a slice"); - return nullptr; - } - PY_END(nullptr) + PY_END(nullptr) } PyMappingMethods DimList_mapping = { - 0, // lenfunc mp_length; - (binaryfunc)(void*)DimList_subscript, // binaryfunc mp_subscript; - 0, // objobjargproc mp_ass_subscript; + 0, //lenfunc mp_length; + (binaryfunc)(void*) DimList_subscript, //binaryfunc mp_subscript; + 0, //objobjargproc mp_ass_subscript; }; + + PyTypeObject DimList::Type = { PyVarObject_HEAD_INIT(NULL, 0) - "_C.DimList", /* tp_name */ - sizeof(DimList), /* tp_basicsize */ - 0, /* tp_itemsize */ - DimList::dealloc_stub, /* tp_dealloc */ - 0, /* tp_vectorcall_offset */ - 0, /* tp_getattr */ - 0, /* tp_setattr */ - 0, /* tp_as_async */ - (reprfunc)DimList_repr, /* tp_repr */ - 0, /* tp_as_number */ - &DimList_seq, /* tp_as_sequence */ - &DimList_mapping, /* tp_as_mapping */ - 0, /* tp_hash */ - 0, /* tp_call */ - 0, /* tp_str */ - 0, /* tp_getattro */ - 0, /* tp_setattro */ - 0, /* tp_as_buffer */ - 0, /* tp_flags */ - "DimList Object", /* tp_doc */ - 0, /* tp_traverse */ - 0, /* tp_clear */ - 0, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - 0, /* tp_iter */ - 0, /* tp_iternext */ - DimList_methods, /* tp_methods */ - 0, /* tp_members */ - DimList_getsetters, /* tp_getset */ - 0, /* tp_base */ - 0, /* tp_dict */ - 0, /* tp_descr_get */ - 0, /* tp_descr_set */ - 0, /* tp_dictoffset */ - (initproc)DimList_init, /* tp_init */ - 0, /* tp_alloc */ - DimList::new_stub, /* tp_new */ + "_C.DimList", /* tp_name */ + sizeof(DimList), /* tp_basicsize */ + 0, /* tp_itemsize */ + DimList::dealloc_stub, /* tp_dealloc */ + 0, /* tp_vectorcall_offset */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_as_async */ + (reprfunc)DimList_repr, /* tp_repr */ + 0, /* tp_as_number */ + &DimList_seq, /* tp_as_sequence */ + &DimList_mapping, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + 0, /* tp_flags */ + "DimList Object", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + DimList_methods, /* tp_methods */ + 0, /* tp_members */ + DimList_getsetters, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc) DimList_init, /* tp_init */ + 0, /* tp_alloc */ + DimList::new_stub, /* tp_new */ }; -static int DimList_init(DimList* self, PyObject* args, PyObject* kwds) { - PY_BEGIN - static constexpr const char* kwlist[] = {"len_or_dims", "name", nullptr}; - mpy::handle len_or_dims = nullptr; - PyObject* name = nullptr; - if (!PyArg_ParseTupleAndKeywords( - args, kwds, "|OO", const_cast(kwlist), &len_or_dims, &name)) { - return -1; - } - self->init(mpy::object::borrow(name ? name : Py_None)); - if (len_or_dims.ptr()) { - if (mpy::is_int(len_or_dims)) { - self->bind_len(mpy::to_int(len_or_dims)); - } else if (mpy::is_sequence(len_or_dims)) { - mpy::sequence_view s(len_or_dims); - std::vector> dims; - size_t size = s.size(); - dims.reserve(size); - for (size_t i = 0; i < size; ++i) { - auto r = s[i]; - if (mpy::is_int(r)) { - dims.emplace_back(Dim::create( - mpy::unicode_from_format("%S%i", self->name_.ptr(), (int)i), - mpy::to_int(r))); +static int DimList_init(DimList *self, PyObject *args, PyObject *kwds) { + PY_BEGIN + static constexpr const char* kwlist[] = {"len_or_dims", "name", nullptr}; + mpy::handle len_or_dims = nullptr; + PyObject* name = nullptr; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OO", const_cast(kwlist), &len_or_dims, &name)) { + return -1; + } + self->init(mpy::object::borrow(name ? name : Py_None)); + if (len_or_dims.ptr()) { + if(mpy::is_int(len_or_dims)) { + self->bind_len(mpy::to_int(len_or_dims)); + } else if (mpy::is_sequence(len_or_dims)) { + mpy::sequence_view s(len_or_dims); + std::vector> dims; + size_t size = s.size(); + dims.reserve(size); + for (size_t i = 0; i < size; ++i) { + auto r = s[i]; + if (mpy::is_int(r)) { + dims.emplace_back(Dim::create(mpy::unicode_from_format("%S%i", self->name_.ptr(), (int)i), mpy::to_int(r))); + } else { + dims.emplace_back(Dim::wrap(r)); + } + } + self->set_dims(std::move(dims)); } else { - dims.emplace_back(Dim::wrap(r)); + PyErr_Format(PyExc_ValueError, "expected a length or a sequence of dimensions"); + return -1; } - } - self->set_dims(std::move(dims)); - } else { - PyErr_Format( - PyExc_ValueError, "expected a length or a sequence of dimensions"); - return -1; + return 0; } return 0; - } - return 0; - PY_END(-1); + PY_END(-1); } // Tensor ----------------------------- PyTypeObject* TensorType = nullptr; // the python wrapper type. -mpy::object run_torch_function( - Arena& A, - mpy::handle orig, - mpy::vector_args args, - bool is_pointwise); +mpy::object run_torch_function(Arena &A, mpy::handle orig, mpy::vector_args args, bool is_pointwise); -namespace { +namespace{ at::Tensor _add_batch_dims(Arena& A, at::Tensor t, Slice levels_) { - auto levels = Slice(); - levels.extend(A, levels_); - while (true) { - int64_t min_real_index = -1; - int64_t min_index = -1; - int64_t min_value = INT_MAX; - int64_t i = 0; - int64_t r = 0; - for (auto l : levels) { - if (!l.is_none()) { - if (!l.is_positional() && l.dim()->level_ < min_value) { - min_value = l.dim()->level_; - min_index = i; - min_real_index = r; + auto levels = Slice(); + levels.extend(A, levels_); + while (true) { + int64_t min_real_index = -1; + int64_t min_index = -1; + int64_t min_value = INT_MAX; + int64_t i = 0; + int64_t r = 0; + for (auto l : levels) { + if (!l.is_none()) { + if (!l.is_positional() && l.dim()->level_ < min_value) { + min_value = l.dim()->level_; + min_index = i; + min_real_index = r; + } + ++i; + } + ++r; } - ++i; - } - ++r; - } - if (min_index == -1) { - return t; + if (min_index == -1) { + return t; + } + auto t2 = at::functorch::addBatchDim(std::move(t), min_index, min_value); + t = std::move(t2); + levels[min_real_index] = DimEntry(); } - auto t2 = at::functorch::addBatchDim(std::move(t), min_index, min_value); - t = std::move(t2); - levels[min_real_index] = DimEntry(); - } } + + struct DelayedOperator { - DelayedOperator(mpy::object o, mpy::vector_args a) - : orig(std::move(o)), args(a) { - auto all = a.size(); - // this will outlive the call so - // take ownership of temporaries - // in vector args - auto buf = new mpy::handle[all]; - memcpy(buf, args.args, sizeof(mpy::handle) * all); - args.args = buf; - for (auto i : args.enumerate_all()) { - Py_INCREF(args.args[i].ptr()); - } - Py_XINCREF(args.kwnames.ptr()); - } - ~DelayedOperator() { - for (auto i : args.enumerate_all()) { - Py_DECREF(args[i].ptr()); - } - if (args.has_keywords()) { - Py_XDECREF(args.kwnames.ptr()); - } - delete[] args.args; - } - mpy::object orig; - mpy::vector_args args; + DelayedOperator(mpy::object o, mpy::vector_args a) + : orig(std::move(o)), args(a) { + auto all = a.size(); + // this will outlive the call so + // take ownership of temporaries + // in vector args + auto buf = new mpy::handle[all]; + memcpy(buf, args.args, sizeof(mpy::handle)*all); + args.args = buf; + for (auto i : args.enumerate_all()) { + Py_INCREF(args.args[i].ptr()); + } + Py_XINCREF(args.kwnames.ptr()); + } + ~DelayedOperator() { + for (auto i : args.enumerate_all()) { + Py_DECREF(args[i].ptr()); + } + if (args.has_keywords()) { + Py_XDECREF(args.kwnames.ptr()); + } + delete [] args.args; + } + mpy::object orig; + mpy::vector_args args; }; void free_levels_dims(Slice levels) { - for (auto e : levels) { - if (!e.is_positional()) { - mpy::object::steal(e.dim()); + for(auto e : levels) { + if (!e.is_positional()) { + mpy::object::steal(e.dim()); + } } - } } -} // namespace +} struct Tensor : public mpy::base { - private: - at::Tensor tensor_; - at::Tensor batchtensor_; - OwnedSlice levels_; - bool has_device_; - std::unique_ptr delayed_; - - public: - at::Tensor& tensor(Arena& A) { - if (C10_UNLIKELY(!tensor_.defined())) { - AT_ASSERT(delayed_); - auto t = Tensor::wrap( - run_torch_function(A, delayed_->orig, delayed_->args, true)); - tensor_ = t->tensor(A); - delayed_.reset(); - // don't force creation of batch tensor if it wasn't already provided. - batchtensor_ = t->batchtensor_; - AT_ASSERT(levels() == t->levels()); - } - return tensor_; - } - at::Tensor& batchtensor(Arena& A) { - if (C10_UNLIKELY(!batchtensor_.defined())) { - batchtensor_ = _add_batch_dims(A, tensor(A), levels_.slice()); - } - return batchtensor_; - } - Slice levels() { - return levels_.slice(); - } - bool has_device() { - return has_device_; - } - DelayedOperator* delayed() { - return delayed_.get(); - } - static PyTypeObject Type; - - static bool check_exact(mpy::handle v) { - return Py_TYPE(v.ptr()) == TensorType; - } - - static mpy::obj create() { - if (!TensorType) { - TensorType = - (PyTypeObject*)mpy::import("functorch.dim").attr("Tensor").release(); - } - return Tensor::alloc(TensorType); - } - void capture_levels(Slice levels) { - // grab ownership of the dims inside levels - for (auto l : levels) { - if (!l.is_positional()) { - mpy::object::borrow(l.dim()).release(); - } - } - levels_.set(levels, free_levels_dims); - } - static mpy::object from_positional( - Arena& A, - at::Tensor tensor, - Slice levels, - bool has_device); - static mpy::obj create_delayed( - mpy::object op, - mpy::vector_args args, - Slice levels, - bool has_device); - friend struct EnableAllLayers; +private: + at::Tensor tensor_; + at::Tensor batchtensor_; + OwnedSlice levels_; + bool has_device_; + std::unique_ptr delayed_; +public: + + at::Tensor& tensor(Arena& A) { + if (C10_UNLIKELY(!tensor_.defined())) { + AT_ASSERT(delayed_); + auto t = Tensor::wrap(run_torch_function(A, delayed_->orig, delayed_->args, true)); + tensor_ = t->tensor(A); + delayed_.reset(); + // don't force creation of batch tensor if it wasn't already provided. + batchtensor_ = t->batchtensor_; + AT_ASSERT(levels() == t->levels()); + } + return tensor_; + } + at::Tensor& batchtensor(Arena& A) { + if (C10_UNLIKELY(!batchtensor_.defined())) { + batchtensor_ = _add_batch_dims(A, tensor(A), levels_.slice()); + } + return batchtensor_; + } + Slice levels() { + return levels_.slice(); + } + bool has_device() { + return has_device_; + } + DelayedOperator* delayed() { + return delayed_.get(); + } + static PyTypeObject Type; + + static bool check_exact(mpy::handle v) { + return Py_TYPE(v.ptr()) == TensorType; + } + + + static mpy::obj create() { + if (!TensorType) { + TensorType = (PyTypeObject*) mpy::import("functorch.dim").attr("Tensor").release(); + } + return Tensor::alloc(TensorType); + } + void capture_levels(Slice levels) { + // grab ownership of the dims inside levels + for (auto l : levels) { + if (!l.is_positional()) { + mpy::object::borrow(l.dim()).release(); + } + } + levels_.set(levels, free_levels_dims); + } + static mpy::object from_positional(Arena & A, at::Tensor tensor, Slice levels, bool has_device); + static mpy::obj create_delayed(mpy::object op, mpy::vector_args args, Slice levels, bool has_device); + friend struct EnableAllLayers; }; -namespace { +namespace{ // version in header does a unnecessary refcount +/- -at::functorch::BatchedTensorImpl* maybeGetBatchedImpl( - const at::Tensor& tensor) { - if (at::functorch::isBatchedTensor(tensor)) { - return static_cast( - tensor.unsafeGetTensorImpl()); - } - return nullptr; +at::functorch::BatchedTensorImpl* maybeGetBatchedImpl(const at::Tensor& tensor) { + if (at::functorch::isBatchedTensor(tensor)) { + return static_cast(tensor.unsafeGetTensorImpl()); + } + return nullptr; } TensorRef unchecked_tensor_from(mpy::handle p) { - auto v = (THPVariable*)p.ptr(); - return TensorRef(*v->cdata); + auto v = (THPVariable*) p.ptr(); + return TensorRef(*v->cdata); } static int64_t ndim_of_levels(Slice levels) { - int64_t r = 0; - for (auto l : levels) { - if (l.is_positional()) { - ++r; + int64_t r = 0; + for (auto l : levels) { + if (l.is_positional()) { + ++r; + } } - } - return r; + return r; } struct TensorInfo { - TensorRef tensor; - Slice levels; - bool has_device; - TensorRef batchedtensor; - int64_t ndim() const { - return ndim_of_levels(levels); - } - operator bool() const { - return tensor; - } - - static TensorInfo create( - Arena& A, - mpy::handle h, - bool ensure_batched = true, - bool ensure_present = true) { - if (Tensor::check_exact(h)) { - auto t = Tensor::unchecked_wrap(h); - return TensorInfo{ - t->tensor(A), - t->levels(), - t->has_device(), - ensure_batched ? t->batchtensor(A) : TensorRef()}; - } else if (Dim::check_exact(h)) { - auto d = Dim::unchecked_wrap(h); - return TensorInfo{ - d->range(), - Slice(A, DimEntry(d)), - false, - ensure_batched ? d->batchtensor() : TensorRef()}; - } else if (THPVariable_Check(h.ptr())) { - TensorRef t = unchecked_tensor_from(h); - Slice levels; - for (auto i : irange(-t->dim(), 0)) { - levels.append(A, i); - } - return TensorInfo{t, levels, true, t}; - } else { - if (ensure_present) { - mpy::raise_error(PyExc_ValueError, "expected a tensor object"); - } - return TensorInfo{}; + TensorRef tensor; + Slice levels; + bool has_device; + TensorRef batchedtensor; + int64_t ndim() const { + return ndim_of_levels(levels); + } + operator bool() const { + return tensor; + } + + static TensorInfo create(Arena& A, mpy::handle h, bool ensure_batched=true, bool ensure_present=true) { + if (Tensor::check_exact(h)) { + auto t = Tensor::unchecked_wrap(h); + return TensorInfo {t->tensor(A), t->levels(), t->has_device(), ensure_batched ? t->batchtensor(A) : TensorRef()}; + } else if (Dim::check_exact(h)) { + auto d = Dim::unchecked_wrap(h); + return TensorInfo {d->range(), Slice(A, DimEntry(d)), false, ensure_batched ? d->batchtensor() : TensorRef()}; + } else if (THPVariable_Check(h.ptr())) { + TensorRef t = unchecked_tensor_from(h); + Slice levels; + for (auto i : irange(-t->dim(), 0)) { + levels.append(A, i); + } + return TensorInfo {t, levels, true, t}; + } else { + if (ensure_present) { + mpy::raise_error(PyExc_ValueError, "expected a tensor object"); + } + return TensorInfo {}; + } } - } + + }; -static PyObject* py_Tensor_from_positional( - PyObject* self, - PyObject* const* args, - Py_ssize_t nargs, - PyObject* kwnames) { - Arena A; - PY_BEGIN -#define ARGS(_) \ - _(mpy::handle, tensor) _(mpy::handle, py_levels) _(int, has_device) - MPY_PARSE_ARGS_KWNAMES("OOp", ARGS) -#undef ARGS - - if (!THPVariable_Check(tensor.ptr())) { - mpy::raise_error(PyExc_ValueError, "_tensor is not a Tensor?"); - } - - Slice levels; - mpy::sequence_view sq(py_levels); - for (auto i : sq.enumerate()) { - mpy::object v = sq[i]; - if (mpy::is_int(v)) { - auto vi = mpy::to_int(v); - levels.append(A, vi); - } else { - auto dim = Dim::wrap(std::move(v)); - mpy::hdl hdim = dim; - levels.append(A, hdim); - } - } - return Tensor::from_positional( - A, THPVariable_Unpack(tensor.ptr()), levels, has_device != 0) - .release(); - PY_END(nullptr) -} -} // namespace - -mpy::object Tensor::from_positional( - Arena& A, - at::Tensor tensor, - Slice levels, - bool has_device) { - size_t seen_dims = 0; - int last = 0; - // auto sz = tensor.sizes(); - for (auto i : levels.enumerate()) { - auto l = levels[i]; - if (l.is_positional()) { - AT_ASSERT(last == 0 || last + 1 == l.position()); - last = l.position(); - } else { - mpy::object::borrow(l.dim()).release(); - // AT_ASSERT(sz[i] == l.dim()->size()); - ++seen_dims; - } - } - AT_ASSERT(last == 0 || last == -1); - if (!seen_dims) { - return mpy::object::steal(THPVariable_Wrap(tensor)); - } - - mpy::obj self = Tensor::create(); - self->tensor_ = std::move(tensor); - AT_ASSERT(self->tensor_.dim() == levels.size()); - self->levels_.set(levels, free_levels_dims); - self->has_device_ = has_device; - mpy::object r = std::move(self); - return r; -} - -mpy::obj Tensor::create_delayed( - mpy::object op, - mpy::vector_args args, - Slice levels, - bool has_device) { - mpy::obj self = Tensor::create(); - self->capture_levels(levels); - self->has_device_ = has_device; - self->delayed_ = std::make_unique(std::move(op), args); - return self; -} - -namespace { +static PyObject* py_Tensor_from_positional(PyObject *self, + PyObject *const *args, + Py_ssize_t nargs, + PyObject *kwnames) { + Arena A; + PY_BEGIN + #define ARGS(_) _(mpy::handle, tensor) _(mpy::handle, py_levels) _(int, has_device) + MPY_PARSE_ARGS_KWNAMES("OOp", ARGS) + #undef ARGS + + if (!THPVariable_Check(tensor.ptr())) { + mpy::raise_error(PyExc_ValueError, "_tensor is not a Tensor?"); + } + + Slice levels; + mpy::sequence_view sq(py_levels); + for (auto i : sq.enumerate()) { + mpy::object v = sq[i]; + if (mpy::is_int(v)) { + auto vi = mpy::to_int(v); + levels.append(A, vi); + } else { + auto dim = Dim::wrap(std::move(v)); + mpy::hdl hdim = dim; + levels.append(A, hdim); + } + } + return Tensor::from_positional(A, THPVariable_Unpack(tensor.ptr()), levels, has_device != 0).release(); + PY_END(nullptr) +} +} + +mpy::object Tensor::from_positional(Arena & A, at::Tensor tensor, Slice levels, bool has_device) { + size_t seen_dims = 0; + int last = 0; + //auto sz = tensor.sizes(); + for (auto i : levels.enumerate()) { + auto l = levels[i]; + if (l.is_positional()) { + AT_ASSERT(last == 0 || last + 1 == l.position()); + last = l.position(); + } else { + mpy::object::borrow(l.dim()).release(); + //AT_ASSERT(sz[i] == l.dim()->size()); + ++seen_dims; + } + } + AT_ASSERT(last == 0 || last == -1); + if (!seen_dims) { + return mpy::object::steal(THPVariable_Wrap(tensor)); + } + + mpy::obj self = Tensor::create(); + self->tensor_ = std::move(tensor); + AT_ASSERT(self->tensor_.dim() == levels.size()); + self->levels_.set(levels, free_levels_dims); + self->has_device_ = has_device; + mpy::object r = std::move(self); + return r; +} + + +mpy::obj Tensor::create_delayed(mpy::object op, mpy::vector_args args, Slice levels, bool has_device) { + mpy::obj self = Tensor::create(); + self->capture_levels(levels); + self->has_device_ = has_device; + self->delayed_ = std::make_unique(std::move(op), args); + return self; +} + +namespace{ mpy::list slice_to_list(Slice h) { - mpy::list lst(h.size()); - for (auto i : h.enumerate()) { - lst.set(i, mpy::object::borrow(h[i])); - } - return lst; + mpy::list lst(h.size()); + for (auto i : h.enumerate()) { + lst.set(i, mpy::object::borrow(h[i])); + } + return lst; } mpy::tuple slice_to_tuple(Slice h) { - mpy::tuple lst(h.size()); - for (auto i : h.enumerate()) { - lst.set(i, mpy::object::borrow(h[i])); - } - return lst; + mpy::tuple lst(h.size()); + for (auto i : h.enumerate()) { + lst.set(i, mpy::object::borrow(h[i])); + } + return lst; } enum UType { - U_ELEM, - U_TUPLE_LIKE, - U_DICT, + U_ELEM, + U_TUPLE_LIKE, + U_DICT, }; struct Unflatten { - mpy::object operator()(Slice& elements) { - mpy::object r; - switch (type) { - case U_ELEM: { - r = mpy::object::borrow(elements[0]); - elements = elements.slice(1); - } break; - case U_TUPLE_LIKE: { - mpy::tuple tup(children.size()); - for (auto i : children.enumerate()) { - tup.set(i, children[i](elements)); - } - r = obj.call(tup); - } break; - case U_DICT: { - r = mpy::object::checked_steal(PyDict_New()); - mpy::dict_view rv(r); - mpy::dict_view d(obj); - Py_ssize_t pos = 0; - mpy::handle k, v; - for (int i = 0; d.next(&pos, &k, &v); ++i) { - rv.set(k, children[i](elements)); + mpy::object operator()(Slice& elements) { + mpy::object r; + switch (type) { + case U_ELEM: { + r = mpy::object::borrow(elements[0]); + elements = elements.slice(1); + } break; + case U_TUPLE_LIKE: { + mpy::tuple tup(children.size()); + for (auto i : children.enumerate()) { + tup.set(i, children[i](elements)); + } + r = obj.call(tup); + } break; + case U_DICT: { + r = mpy::object::checked_steal(PyDict_New()); + mpy::dict_view rv(r); + mpy::dict_view d(obj); + Py_ssize_t pos = 0; + mpy::handle k, v; + for (int i = 0; d.next(&pos, &k, &v); ++i) { + rv.set(k, children[i](elements)); + } + } break; } - } break; + return r; } - return r; - } - UType type; - mpy::handle obj; - Slice children; + UType type; + mpy::handle obj; + Slice children; }; -Unflatten tree_flatten( - Arena& A, - mpy::handle agg, - Slice& flat_elements) { - Slice c; - UType utype; - mpy::handle obj; - if (mpy::list_view::check(agg)) { - obj = agg.type(); - utype = U_TUPLE_LIKE; - mpy::list_view l(agg); - for (auto i : l.enumerate()) { - c.append(A, tree_flatten(A, l[i], flat_elements)); - } - } else if (mpy::tuple_view::check(agg)) { - obj = agg.type(); - utype = U_TUPLE_LIKE; - // includes named tuples - mpy::tuple_view l(agg); - for (auto i : l.enumerate()) { - c.append(A, tree_flatten(A, l[i], flat_elements)); - } - } else if (mpy::dict_view::check(agg)) { - utype = U_DICT; - mpy::dict_view d(agg); - obj = agg; - Py_ssize_t pos = 0; - mpy::handle k, v; - while (d.next(&pos, &k, &v)) { - c.append(A, tree_flatten(A, v, flat_elements)); +Unflatten tree_flatten(Arena& A, mpy::handle agg, Slice& flat_elements) { + Slice c; + UType utype; + mpy::handle obj; + if (mpy::list_view::check(agg)) { + obj = agg.type(); + utype = U_TUPLE_LIKE; + mpy::list_view l(agg); + for (auto i : l.enumerate()) { + c.append(A, tree_flatten(A, l[i], flat_elements)); + } + } else if (mpy::tuple_view::check(agg)) { + obj = agg.type(); + utype = U_TUPLE_LIKE; + // includes named tuples + mpy::tuple_view l(agg); + for (auto i : l.enumerate()) { + c.append(A, tree_flatten(A, l[i], flat_elements)); + } + } else if (mpy::dict_view::check(agg)) { + utype = U_DICT; + mpy::dict_view d(agg); + obj = agg; + Py_ssize_t pos = 0; + mpy::handle k, v; + while (d.next(&pos, &k, &v)) { + c.append(A, tree_flatten(A, v, flat_elements)); + } + } else { + utype = U_ELEM; + flat_elements.append(A, agg); } - } else { - utype = U_ELEM; - flat_elements.append(A, agg); - } - return Unflatten{utype, obj, c}; + return Unflatten {utype, obj, c}; } struct UnflattenVectorArgs { - mpy::vector_args operator()(Arena& A, Slice& elements) { - if (!had_nested) { - auto args = elements.begin(); - elements = Slice(); - return mpy::vector_args(args, nargs, kwnames); - } - Slice args; - for (auto u : children) { - args.append(A, A.autorelease(u(elements))); - } - return mpy::vector_args(args.begin(), nargs, kwnames); - } - Slice children; - Py_ssize_t nargs; - mpy::handle kwnames; - bool had_nested; + mpy::vector_args operator()(Arena& A, Slice& elements) { + if (!had_nested) { + auto args = elements.begin(); + elements = Slice(); + return mpy::vector_args(args, nargs, kwnames); + } + Slice args; + for (auto u : children) { + args.append(A, A.autorelease(u(elements))); + } + return mpy::vector_args(args.begin(), nargs, kwnames); + } + Slice children; + Py_ssize_t nargs; + mpy::handle kwnames; + bool had_nested; }; -UnflattenVectorArgs tree_flatten( - Arena& A, - mpy::vector_args args, - Slice& flat_elements) { - UnflattenVectorArgs r; - r.kwnames = args.kwnames; - r.nargs = args.nargs; - r.had_nested = false; - auto N = args.size(); - for (auto i : irange(N)) { - auto typ = Py_TYPE(args[i].ptr()); - // fast checks that this thing isn't something that is nested. - bool is_element = !typ->tp_as_sequence || typ == torch_Tensor || - typ == TensorType || typ == DimType; - if (!is_element) { - flat_elements.extend(A, args.args, args.args + i); - for (auto j : irange(i)) { - (void)j; - r.children.append(A, Unflatten{U_ELEM}); - } - for (auto j : irange(i, N)) { - r.children.append(A, tree_flatten(A, args[j], flat_elements)); - if (r.children.back().type != U_ELEM) { - r.had_nested = true; - } - } - return r; - } - } - flat_elements.extend(A, args.args, args.args + N); - return r; +UnflattenVectorArgs tree_flatten(Arena& A, mpy::vector_args args, Slice& flat_elements) { + UnflattenVectorArgs r; + r.kwnames = args.kwnames; + r.nargs = args.nargs; + r.had_nested = false; + auto N = args.size(); + for(auto i : irange(N)) { + auto typ = Py_TYPE(args[i].ptr()); + // fast checks that this thing isn't something that is nested. + bool is_element = !typ->tp_as_sequence || typ == torch_Tensor || typ == TensorType || typ == DimType; + if (!is_element) { + flat_elements.extend(A, args.args, args.args + i); + for (auto j : irange(i)) { + (void)j; + r.children.append(A, Unflatten {U_ELEM}); + } + for (auto j : irange(i, N)) { + r.children.append(A, tree_flatten(A, args[j], flat_elements)); + if (r.children.back().type != U_ELEM) { + r.had_nested = true; + } + } + return r; + } + } + flat_elements.extend(A, args.args, args.args + N); + return r; } + struct UnflattenArena { - Arena A; - Unflatten unflatten; + Arena A; + Unflatten unflatten; }; -PyObject* py_unflatten( - PyObject* self, - PyObject* const* args, - Py_ssize_t nargs, - PyObject* kwnames) { - PY_BEGIN -#define ARGS(_) _(mpy::handle, ns) - MPY_PARSE_ARGS_KWNAMES("O", ARGS) -#undef ARGS - mpy::sequence_view sv(ns); - // because we do not have a autorelase pool yet... - Arena A; - Slice slice; - mpy::handle Tuple = (PyObject*)&PyTuple_Type; - auto inputs = Tuple.call(ns); - mpy::tuple_view tv(inputs); - for (auto i : tv.enumerate()) { - slice.append(A, tv[i]); - } - auto AA = (UnflattenArena*)PyCapsule_GetPointer(self, "arena"); - auto r = AA->unflatten(slice).release(); - AT_ASSERT(r != nullptr); - return r; - PY_END(nullptr) -} - -PyMethodDef py_unflatten_def = { - "unflatten", - (PyCFunction)(void*)py_unflatten, - METH_FASTCALL | METH_KEYWORDS}; - -void free_unflatten_arena(PyObject* pc) { - delete (UnflattenArena*)PyCapsule_GetPointer(pc, "arena"); -} - -PyObject* py_tree_flatten( - PyObject* self, - PyObject* const* args, - Py_ssize_t nargs, - PyObject* kwnames) { - PY_BEGIN -#define ARGS(_) _(mpy::handle, tree) - MPY_PARSE_ARGS_KWNAMES("O", ARGS) -#undef ARGS - auto A = new UnflattenArena; - Slice elements; - A->unflatten = tree_flatten(A->A, tree, elements); - auto cap = mpy::object::checked_steal( - PyCapsule_New(A, "arena", free_unflatten_arena)); - auto unflatten = mpy::object::checked_steal( - PyCFunction_New(&py_unflatten_def, cap.release())); - mpy::tuple r(2); - r.set(0, slice_to_list(elements)); - r.set(1, std::move(unflatten)); - return r.release(); - PY_END(nullptr) -} - -mpy::object tree_map( - Arena& A, - const std::function& fn, - mpy::handle agg) { - Slice elements; - auto unflatten = tree_flatten(A, agg, elements); - for (auto i : elements.enumerate()) { - elements[i] = fn(elements[i]); - } - return unflatten(elements); +PyObject* py_unflatten(PyObject *self, + PyObject *const *args, + Py_ssize_t nargs, + PyObject *kwnames) { + PY_BEGIN + #define ARGS(_) _(mpy::handle, ns) + MPY_PARSE_ARGS_KWNAMES("O", ARGS) + #undef ARGS + mpy::sequence_view sv(ns); + // because we do not have a autorelase pool yet... + Arena A; + Slice slice; + mpy::handle Tuple = (PyObject*) &PyTuple_Type; + auto inputs = Tuple.call(ns); + mpy::tuple_view tv(inputs); + for (auto i : tv.enumerate()) { + slice.append(A, tv[i]); + } + auto AA = (UnflattenArena*) PyCapsule_GetPointer(self, "arena"); + auto r = AA->unflatten(slice).release(); + AT_ASSERT(r != nullptr); + return r; + PY_END(nullptr) +} + +PyMethodDef py_unflatten_def = {"unflatten", (PyCFunction)(void*) py_unflatten, METH_FASTCALL | METH_KEYWORDS}; + +void free_unflatten_arena(PyObject * pc) { + delete (UnflattenArena*) PyCapsule_GetPointer(pc, "arena"); +} + +PyObject* py_tree_flatten(PyObject *self, + PyObject *const *args, + Py_ssize_t nargs, + PyObject *kwnames) { + PY_BEGIN + #define ARGS(_) _(mpy::handle, tree) + MPY_PARSE_ARGS_KWNAMES("O", ARGS) + #undef ARGS + auto A = new UnflattenArena; + Slice elements; + A->unflatten = tree_flatten(A->A, tree, elements); + auto cap = mpy::object::checked_steal(PyCapsule_New(A, "arena", free_unflatten_arena)); + auto unflatten = mpy::object::checked_steal(PyCFunction_New(&py_unflatten_def, cap.release())); + mpy::tuple r(2); + r.set(0, slice_to_list(elements)); + r.set(1, std::move(unflatten)); + return r.release(); + PY_END(nullptr) +} + + + +mpy::object tree_map(Arena& A, const std::function& fn, mpy::handle agg) { + Slice elements; + auto unflatten = tree_flatten(A, agg, elements); + for (auto i : elements.enumerate()) { + elements[i] = fn(elements[i]); + } + return unflatten(elements); } // prereq: isinstance(h, _Tensor) int64_t _Tensor_ndim(mpy::handle h) { - if (Tensor::check(h)) { - int64_t r = 0; - for (auto l : Tensor::unchecked_wrap(h)->levels()) { - if (l.is_positional()) { - ++r; - } + if (Tensor::check(h)) { + int64_t r = 0; + for (auto l : Tensor::unchecked_wrap(h)->levels()) { + if (l.is_positional()) { + ++r; + } + } + return r; } - return r; - } - // Dim or DelayedMulTensor - return 0; + // Dim or DelayedMulTensor + return 0; } mpy::handle handle_from_tensor(Arena& A, TensorRef t) { - // fast case: tensor is live in python - std::optional mb_obj = - t->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj( - /*ignore_hermetic_tls=*/false); - if (mb_obj.has_value() && - !t->unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()) { - return *mb_obj; - } - return A.autorelease(mpy::object::checked_steal(THPVariable_Wrap(*t))); -} -} // namespace + // fast case: tensor is live in python + std::optional mb_obj = + t->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(getPyInterpreter(), /*ignore_hermetic_tls=*/false); + if (mb_obj.has_value() && !t->unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()) { + return *mb_obj; + } + return A.autorelease(mpy::object::checked_steal(THPVariable_Wrap(*t))); +} +} struct EnableAllLayers { - EnableAllLayers(Arena& A, Slice levels) { - std::vector> layers; - layers.reserve(levels.size()); - for (auto l : levels) { - if (!l.is_positional()) { - auto d = l.dim(); - levels_to_dim_.append(A, d); - } - } - std::sort( - levels_to_dim_.begin(), - levels_to_dim_.end(), - [](mpy::hdl lhs, mpy::hdl rhs) { - return lhs->level_ < rhs->level_; - }); - - for (auto i : levels_to_dim_.enumerate()) { - auto batch_size = levels_to_dim_[i]->size(); - auto level = at::functorch::initAndPushDynamicLayer( - at::functorch::TransformType::Vmap, - batch_size, - at::functorch::RandomnessType::Different); - if (i == 0) { - levels_start_ = level; - } - } - } - - ~EnableAllLayers() { - auto to_remove = levels_start_ + levels_to_dim_.size() - 1; - for (auto i : levels_to_dim_.enumerate()) { - AT_ASSERT( - at::functorch::popDynamicLayerAndDeleteMetadata().layerId() == - to_remove - i); - } - } - - mpy::obj from_batched( - Arena& A, - at::Tensor batchedtensor, - bool has_device) { - Slice levels; - for (auto i : irange(-batchedtensor.dim(), 0)) { - levels.append(A, i); + EnableAllLayers(Arena& A, Slice levels) { + std::vector> layers; + layers.reserve(levels.size()); + for (auto l : levels) { + if (!l.is_positional()) { + auto d = l.dim(); + levels_to_dim_.append(A, d); + } + } + std::sort(levels_to_dim_.begin(), levels_to_dim_.end(), [](mpy::hdl lhs, mpy::hdl rhs) { return lhs->level_ < rhs->level_;}); + + for (auto i : levels_to_dim_.enumerate()) { + auto batch_size = levels_to_dim_[i]->size(); + auto level = at::functorch::initAndPushDynamicLayer(at::functorch::TransformType::Vmap, batch_size, at::functorch::RandomnessType::Different); + if (i == 0) { + levels_start_ = level; + } + } } - TensorRef tensor; - at::functorch::BatchedTensorImpl* impl = maybeGetBatchedImpl(batchedtensor); - while (true) { - auto level = impl->level(); - AT_ASSERT( - level >= levels_start_ && - level < levels_start_ + levels_to_dim_.size()); - mpy::hdl dim = levels_to_dim_[level - levels_start_].ptr(); - levels.insert(A, impl->bdim(), dim); - at::functorch::BatchedTensorImpl* nimpl = - maybeGetBatchedImpl(impl->value()); - if (!nimpl) { - tensor = impl->value(); - break; - } - impl = nimpl; + + ~EnableAllLayers() { + auto to_remove = levels_start_ + levels_to_dim_.size() - 1; + for (auto i : levels_to_dim_.enumerate()) { + AT_ASSERT(at::functorch::popDynamicLayerAndDeleteMetadata().layerId() == to_remove - i); + } } - mpy::obj self = Tensor::create(); - // grab ownership of the tensors - self->tensor_ = *tensor; - self->batchtensor_ = std::move(batchedtensor); - self->has_device_ = has_device; - self->capture_levels(levels); - return self; - } - void inplace_update_layers(TensorRef batchtensor, Slice levels) { - // XXX - requires a patch to functorch to att set_level - auto impl = maybeGetBatchedImpl(*batchtensor); - for (auto i : levels_to_dim_.reversed_enumerate()) { - if (!impl) { - break; - } - if (levels.contains(levels_to_dim_[i])) { - impl->_unsafe_set_level(levels_start_ + i); - impl = maybeGetBatchedImpl(impl->value()); - } - } - } - - private: - int64_t levels_start_{}; - Slice> levels_to_dim_; + mpy::obj from_batched(Arena& A, at::Tensor batchedtensor, bool has_device) { + Slice levels; + for (auto i : irange(-batchedtensor.dim(), 0)) { + levels.append(A, i); + } + TensorRef tensor; + at::functorch::BatchedTensorImpl * impl = maybeGetBatchedImpl(batchedtensor); + while(true) { + auto level = impl->level(); + AT_ASSERT(level >= levels_start_ && level < levels_start_ + levels_to_dim_.size()); + mpy::hdl dim = levels_to_dim_[level - levels_start_].ptr(); + levels.insert(A, impl->bdim(), dim); + at::functorch::BatchedTensorImpl * nimpl = maybeGetBatchedImpl(impl->value()); + if (!nimpl) { + tensor = impl->value(); + break; + } + impl = nimpl; + } + + mpy::obj self = Tensor::create(); + // grab ownership of the tensors + self->tensor_ = *tensor; + self->batchtensor_ = std::move(batchedtensor); + self->has_device_ = has_device; + self->capture_levels(levels); + return self; + } + void inplace_update_layers(TensorRef batchtensor, Slice levels) { + // XXX - requires a patch to functorch to att set_level + auto impl = maybeGetBatchedImpl(*batchtensor); + for (auto i : levels_to_dim_.reversed_enumerate()) { + if (!impl) { + break; + } + if (levels.contains(levels_to_dim_[i])) { + impl->_unsafe_set_level(levels_start_ + i); + impl = maybeGetBatchedImpl(impl->value()); + + } + } + } +private: + int64_t levels_start_{}; + Slice> levels_to_dim_; }; -namespace { -TensorRef _match_levels( - Arena& A, - TensorRef v, - Slice from_levels, - Slice to_levels, - bool drop_levels = false) { - if (from_levels == to_levels) { - return v; - } - // drop_levels -> if a dim appears in from_levels but not to_levels, it is - // assumed it has stride 0. - at::IntArrayRef sz = v->sizes(); - at::IntArrayRef sd = v->strides(); - AT_ASSERT(drop_levels || from_levels.size() <= to_levels.size()); - Slice nsz; - Slice nsd; - for (auto l : to_levels) { - auto oidx = from_levels.index(l); - if (!oidx) { - nsz.append(A, l.is_positional() ? 1 : l.dim()->size()); - nsd.append(A, 0); +namespace{ +TensorRef _match_levels(Arena& A, TensorRef v, Slice from_levels, Slice to_levels, bool drop_levels=false) { + if (from_levels == to_levels) { + return v; + } + // drop_levels -> if a dim appears in from_levels but not to_levels, it is assumed it has stride 0. + at::IntArrayRef sz = v->sizes(); + at::IntArrayRef sd = v->strides(); + AT_ASSERT(drop_levels || from_levels.size() <= to_levels.size()); + Slice nsz; + Slice nsd; + for (auto l : to_levels) { + auto oidx = from_levels.index(l); + if (!oidx) { + nsz.append(A, l.is_positional() ? 1 : l.dim()->size()); + nsd.append(A, 0); + } else { + auto idx = *oidx; + nsz.append(A, sz[idx]); + nsd.append(A, sd[idx]); + } + } + return A.autorelease(v->as_strided(at::IntArrayRef(nsz.begin(), nsz.end()), at::IntArrayRef(nsd.begin(), nsd.end()), v->storage_offset())); +} +} +mpy::object run_torch_function(Arena &A, mpy::handle orig, mpy::vector_args args, bool is_pointwise) { + if (!pointwise_optimize) { + is_pointwise = false; + } + // std::cout << "__torch_function__ " << ((is_pointwise) ? "pointwise" : "functorch") << " " << orig << "\n"; + + Slice> all_dims; + Slice flat_args; + auto unflatten_args = tree_flatten(A, args, flat_args); + TensorRef device_holding_tensor; + + Slice infos; + Slice result_levels; + for (auto f : flat_args) { + infos.append(A, TensorInfo::create(A, f, !is_pointwise, false)); + if (infos.back()) { + TensorInfo& info = infos.back(); + AT_ASSERT(is_pointwise || info.batchedtensor); + if (!device_holding_tensor && info.has_device) { + device_holding_tensor = infos.back().tensor; + } + for (auto l : info.levels) { + if (!result_levels.contains(l)) { + result_levels.append(A, l); + } + } + } + } + + if (is_pointwise) { + for (auto i : flat_args.enumerate()) { + if (infos[i]) { + TensorRef tensor = infos[i].tensor; + if (device_holding_tensor && !infos[i].has_device) { + tensor = A.autorelease(tensor->to(device_holding_tensor->device())); + } + auto ml = _match_levels(A, tensor, infos[i].levels, result_levels); + flat_args[i] = handle_from_tensor(A, std::move(ml)); + } + } + + Slice flat_it = flat_args; + mpy::vector_args uargs = unflatten_args(A, flat_it); + + mpy::object result = orig.call_vector(uargs); + + // fast wrap for normal case where operator just returns a tensor. + if (THPVariable_Check(result.ptr())) { + return Tensor::from_positional(A, THPVariable_Unpack(result.ptr()), result_levels, device_holding_tensor); + } + auto wrap = [&](mpy::handle h) { + if (THPVariable_Check(h.ptr())){ + return A.autorelease(Tensor::from_positional(A, THPVariable_Unpack(h.ptr()), result_levels, device_holding_tensor)); + } + return h; + }; + return tree_map(A, wrap, result); } else { - auto idx = *oidx; - nsz.append(A, sz[idx]); - nsd.append(A, sd[idx]); - } - } - return A.autorelease(v->as_strided( - at::IntArrayRef(nsz.begin(), nsz.end()), - at::IntArrayRef(nsd.begin(), nsd.end()), - v->storage_offset())); -} -} // namespace -mpy::object run_torch_function( - Arena& A, - mpy::handle orig, - mpy::vector_args args, - bool is_pointwise) { - if (!pointwise_optimize) { - is_pointwise = false; - } - // std::cout << "__torch_function__ " << ((is_pointwise) ? "pointwise" : - // "functorch") << " " << orig << "\n"; - - Slice> all_dims; - Slice flat_args; - auto unflatten_args = tree_flatten(A, args, flat_args); - TensorRef device_holding_tensor; - - Slice infos; - Slice result_levels; - for (auto f : flat_args) { - infos.append(A, TensorInfo::create(A, f, !is_pointwise, false)); - if (infos.back()) { - TensorInfo& info = infos.back(); - AT_ASSERT(is_pointwise || info.batchedtensor); - if (!device_holding_tensor && info.has_device) { - device_holding_tensor = infos.back().tensor; - } - for (auto l : info.levels) { - if (!result_levels.contains(l)) { - result_levels.append(A, l); - } - } - } - } - - if (is_pointwise) { - for (auto i : flat_args.enumerate()) { - if (infos[i]) { - TensorRef tensor = infos[i].tensor; - if (device_holding_tensor && !infos[i].has_device) { - tensor = A.autorelease(tensor->to(device_holding_tensor->device())); - } - auto ml = _match_levels(A, tensor, infos[i].levels, result_levels); - flat_args[i] = handle_from_tensor(A, std::move(ml)); - } - } - - Slice flat_it = flat_args; - mpy::vector_args uargs = unflatten_args(A, flat_it); - - mpy::object result = orig.call_vector(uargs); - - // fast wrap for normal case where operator just returns a tensor. - if (THPVariable_Check(result.ptr())) { - return Tensor::from_positional( - A, - THPVariable_Unpack(result.ptr()), - result_levels, - device_holding_tensor); + // std::cout << orig << " calling functorch...\n"; + // std::cout << "rl: " << result_levels << "\n"; + EnableAllLayers guard(A, result_levels); + for (auto i : flat_args.enumerate()) { + if (infos[i]) { + TensorRef batched = infos[i].batchedtensor; + if (device_holding_tensor && !infos[i].has_device) { + batched = A.autorelease(batched->to(device_holding_tensor->device())); + } + guard.inplace_update_layers(batched, infos[i].levels); + flat_args[i] = handle_from_tensor(A, batched); + } + } + Slice flat_it = flat_args; + mpy::vector_args uargs = unflatten_args(A, flat_it); + AT_ASSERT(flat_it.size() == 0); + mpy::object result = orig.call_vector(uargs); + auto wrap = [&](mpy::handle h) { + if (THPVariable_Check(h.ptr())) { + return A.autorelease(guard.from_batched(A, THPVariable_Unpack(h.ptr()), device_holding_tensor)); + } + return h; + }; + if (THPVariable_Check(result.ptr())) { + return guard.from_batched(A, THPVariable_Unpack(result.ptr()), device_holding_tensor); + } + return tree_map(A, wrap, result); + } +} + +namespace{ + +mpy::object __torch_function__(Arena &A, mpy::handle orig, mpy::vector_args args, bool is_pointwise) { + if (orig == torch_Tensor___mul__) { + AT_ASSERT(args.nargs == 2 && !args.has_keywords()); + auto lhs = args[0]; + auto rhs = args[1]; + if (mpy::isinstance(lhs, _Tensor) && mpy::isinstance(rhs, _Tensor) && _Tensor_ndim(lhs) == 0 && _Tensor_ndim(rhs) == 0) { + bool has_device = false; + Slice levels; + for (auto i : args.enumerate_positional()) { + auto t = TensorInfo::create(A, args[i], false); + // something like a mask * rhs, which matrix multiplies don't correctly promote + if (!t.tensor->is_floating_point()) { + return run_torch_function(A, orig, args, is_pointwise); + } + has_device = has_device || t.has_device; + for (auto l : t.levels) { + if (!levels.contains(l)) { + levels.append(A, l); + } + } + } + // std::cout << "__torch_function__ " << "delay" << " " << orig << "\n"; + return Tensor::create_delayed(mpy::object::borrow(orig), args, levels, has_device); + } } - auto wrap = [&](mpy::handle h) { - if (THPVariable_Check(h.ptr())) { - return A.autorelease(Tensor::from_positional( - A, - THPVariable_Unpack(h.ptr()), - result_levels, - device_holding_tensor)); - } - return h; - }; - return tree_map(A, wrap, result); - } else { - // std::cout << orig << " calling functorch...\n"; - // std::cout << "rl: " << result_levels << "\n"; - EnableAllLayers guard(A, result_levels); - for (auto i : flat_args.enumerate()) { - if (infos[i]) { - TensorRef batched = infos[i].batchedtensor; - if (device_holding_tensor && !infos[i].has_device) { - batched = A.autorelease(batched->to(device_holding_tensor->device())); - } - guard.inplace_update_layers(batched, infos[i].levels); - flat_args[i] = handle_from_tensor(A, batched); - } - } - Slice flat_it = flat_args; - mpy::vector_args uargs = unflatten_args(A, flat_it); - AT_ASSERT(flat_it.size() == 0); - mpy::object result = orig.call_vector(uargs); - auto wrap = [&](mpy::handle h) { - if (THPVariable_Check(h.ptr())) { - return A.autorelease(guard.from_batched( - A, THPVariable_Unpack(h.ptr()), device_holding_tensor)); - } - return h; - }; - if (THPVariable_Check(result.ptr())) { - return guard.from_batched( - A, THPVariable_Unpack(result.ptr()), device_holding_tensor); - } - return tree_map(A, wrap, result); - } -} - -namespace { - -mpy::object __torch_function__( - Arena& A, - mpy::handle orig, - mpy::vector_args args, - bool is_pointwise) { - if (orig == torch_Tensor___mul__) { - AT_ASSERT(args.nargs == 2 && !args.has_keywords()); - auto lhs = args[0]; - auto rhs = args[1]; - if (mpy::isinstance(lhs, _Tensor) && mpy::isinstance(rhs, _Tensor) && - _Tensor_ndim(lhs) == 0 && _Tensor_ndim(rhs) == 0) { - bool has_device = false; - Slice levels; - for (auto i : args.enumerate_positional()) { - auto t = TensorInfo::create(A, args[i], false); - // something like a mask * rhs, which matrix multiplies don't correctly - // promote - if (!t.tensor->is_floating_point()) { - return run_torch_function(A, orig, args, is_pointwise); - } - has_device = has_device || t.has_device; - for (auto l : t.levels) { - if (!levels.contains(l)) { - levels.append(A, l); - } - } - } - // std::cout << "__torch_function__ " << "delay" << " " << orig << "\n"; - return Tensor::create_delayed( - mpy::object::borrow(orig), args, levels, has_device); - } - } - return run_torch_function(A, orig, args, is_pointwise); -} - -mpy::vector_args as_vector_args( - Arena& A, - mpy::handle args, - mpy::handle kwargs) { - auto pos_args = (mpy::handle*)&PyTuple_GET_ITEM(args.ptr(), 0); - auto pos_n = PyTuple_GET_SIZE(args.ptr()); - if (!kwargs.ptr()) { - return mpy::vector_args(pos_args, pos_n, nullptr); - } - Slice all_args; - Slice kwnames; - all_args.extend(A, pos_args, pos_args + pos_n); - mpy::dict_view dv(kwargs); - Py_ssize_t pos = 0; - mpy::handle key, value; - while (dv.next(&pos, &key, &value)) { - all_args.append(A, value); - kwnames.append(A, key); - } - return mpy::vector_args( - all_args.begin(), pos_n, A.autorelease(slice_to_tuple(kwnames))); -} - -PyObject* py___torch_function__( - PyObject* self, - PyObject* const* args, - Py_ssize_t nargs, - PyObject* kwnames) { - Arena A; - PY_BEGIN - maybeInitializeGlobals(); - AT_ASSERT(nargs == 4 || nargs == 5); - auto va = as_vector_args(A, args[3], nargs == 5 ? args[4] : nullptr); - bool is_pointwise = pointwise.contains(args[1]); - return __torch_function__(A, args[1], std::move(va), is_pointwise).release(); - PY_END(nullptr) + return run_torch_function(A, orig, args, is_pointwise); +} + +mpy::vector_args as_vector_args(Arena& A, mpy::handle args, mpy::handle kwargs) { + auto pos_args = (mpy::handle*) &PyTuple_GET_ITEM(args.ptr(), 0); + auto pos_n = PyTuple_GET_SIZE(args.ptr()); + if (!kwargs.ptr()) { + return mpy::vector_args(pos_args, pos_n, nullptr); + } + Slice all_args; + Slice kwnames; + all_args.extend(A, pos_args, pos_args + pos_n); + mpy::dict_view dv(kwargs); + Py_ssize_t pos = 0; + mpy::handle key, value; + while (dv.next(&pos, &key, &value)) { + all_args.append(A, value); + kwnames.append(A, key); + } + return mpy::vector_args(all_args.begin(), pos_n, A.autorelease(slice_to_tuple(kwnames))); +} + +PyObject* py___torch_function__(PyObject *self, + PyObject *const *args, + Py_ssize_t nargs, + PyObject *kwnames) { + Arena A; + PY_BEGIN + maybeInitializeGlobals(); + AT_ASSERT(nargs == 4 || nargs == 5); + auto va = as_vector_args(A, args[3], nargs == 5 ? args[4] : nullptr); + bool is_pointwise = pointwise.contains(args[1]); + return __torch_function__(A, args[1], std::move(va), is_pointwise).release(); + PY_END(nullptr) } mpy::object levels_to_tuple(Slice slice) { - mpy::tuple t(slice.size()); - for (auto i : slice.enumerate()) { - t.set( - i, - slice[i].is_positional() ? mpy::from_int(slice[i].position()) - : mpy::object::borrow(slice[i].dim())); - } - mpy::object r = std::move(t); - return r; + mpy::tuple t(slice.size()); + for (auto i : slice.enumerate()) { + t.set(i, slice[i].is_positional() ? mpy::from_int(slice[i].position()) : mpy::object::borrow(slice[i].dim())); + } + mpy::object r = std::move(t); + return r; } PyObject* Tensor_ndim(Tensor* self, void*) { - Py_ssize_t i = 0; - for (auto l : self->levels()) { - if (l.is_positional()) { - ++i; + Py_ssize_t i = 0; + for (auto l : self->levels()) { + if (l.is_positional()) { + ++i; + } } - } - return mpy::from_int(i).release(); + return mpy::from_int(i).release(); } PyGetSetDef Tensor_getsetters[] = { - {"_has_device", - (getter)[](PyObject* self, void*) - ->PyObject* { - return mpy::from_bool(((Tensor*)self)->has_device()).release(); -} // namespace -, NULL -} -, {"_tensor", (getter)[](PyObject* self, void*)->PyObject* {Arena A; -return THPVariable_Wrap(((Tensor*)self)->tensor(A)); -} -, NULL -} -, {"_batchtensor", (getter)[](PyObject* self, void*)->PyObject* {Arena A; -return THPVariable_Wrap(((Tensor*)self)->batchtensor(A)); -} -, NULL -} -, - {"_levels", - (getter)[](PyObject* self, void*) - ->PyObject* {PY_BEGIN return levels_to_tuple(((Tensor*)self)->levels()) - .release(); -PY_END(nullptr) -} -} -, {"ndim", (getter)Tensor_ndim, NULL, "ndim", NULL}, { - NULL -} /* Sentinel */ -} -; + {"_has_device", (getter) [](PyObject* self, void*) -> PyObject* { return mpy::from_bool(((Tensor*)self)->has_device()).release(); }, NULL}, + {"_tensor", (getter) [](PyObject* self, void*) -> PyObject* { + Arena A; + return THPVariable_Wrap(((Tensor*)self)->tensor(A)); }, NULL}, + {"_batchtensor", (getter) [](PyObject* self, void*) -> PyObject* { + Arena A; + return THPVariable_Wrap(((Tensor*)self)->batchtensor(A)); }, NULL}, + {"_levels", (getter) [](PyObject* self, void*) -> PyObject* { + PY_BEGIN + return levels_to_tuple(((Tensor*)self)->levels()).release(); + PY_END(nullptr) + }}, + {"ndim", (getter) Tensor_ndim, NULL, "ndim", NULL}, + {NULL} /* Sentinel */ +}; PyMethodDef Tensor_methods[] = { - {NULL, NULL, 0, NULL} /* Sentinel */ + {NULL, NULL, 0, NULL} /* Sentinel */ }; } + PyTypeObject Tensor::Type = { PyVarObject_HEAD_INIT(NULL, 0) - "_C.Tensor", /* tp_name */ - sizeof(Tensor), /* tp_basicsize */ - 0, /* tp_itemsize */ - Tensor::dealloc_stub, /* tp_dealloc */ - 0, /* tp_vectorcall_offset */ - 0, /* tp_getattr */ - 0, /* tp_setattr */ - 0, /* tp_as_async */ - 0, /* tp_repr */ - 0, /* tp_as_number */ - 0, /* tp_as_sequence */ - 0, /* tp_as_mapping */ - 0, /* tp_hash */ - 0, /* tp_call */ - 0, /* tp_str */ - 0, /* tp_getattro */ - 0, /* tp_setattro */ - 0, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ - "Tensor Object", /* tp_doc */ - 0, /* tp_traverse */ - 0, /* tp_clear */ - 0, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - 0, /* tp_iter */ - 0, /* tp_iternext */ - Tensor_methods, /* tp_methods */ - 0, /* tp_members */ - Tensor_getsetters, /* tp_getset */ - 0, /* tp_base */ - 0, /* tp_dict */ - 0, /* tp_descr_get */ - 0, /* tp_descr_set */ - 0, /* tp_dictoffset */ - 0, /* tp_init */ - 0, /* tp_alloc */ - Tensor::new_stub, /* tp_new */ + "_C.Tensor", /* tp_name */ + sizeof(Tensor), /* tp_basicsize */ + 0, /* tp_itemsize */ + Tensor::dealloc_stub, /* tp_dealloc */ + 0, /* tp_vectorcall_offset */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_as_async */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE , /* tp_flags */ + "Tensor Object", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + Tensor_methods, /* tp_methods */ + 0, /* tp_members */ + Tensor_getsetters, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + 0, /* tp_init */ + 0, /* tp_alloc */ + Tensor::new_stub, /* tp_new */ }; + // dim() -------------------- static bool relevant_op(_Py_CODEUNIT c) { - switch (c) { - case STORE_NAME: - case STORE_GLOBAL: - case STORE_FAST: - case STORE_DEREF: - return true; - default: - return false; - } + switch(c) { + case STORE_NAME: + case STORE_GLOBAL: + case STORE_FAST: + case STORE_DEREF: + return true; + default: + return false; + } } static mpy::object create_dim(mpy::object name, mpy::handle size) { - auto d = Dim::create(std::move(name)); - if (!mpy::is_none(size)) { - d->set_size(mpy::to_int(size)); - } - return std::move(d); + auto d = Dim::create(std::move(name)); + if (!mpy::is_none(size)) { + d->set_size(mpy::to_int(size)); + } + return std::move(d); } static mpy::object create_dimlist(mpy::object name, mpy::handle size) { - auto d = DimList::create(std::move(name)); - if (!mpy::is_none(size)) { - if (mpy::is_int(size)) { - d->bind_len(mpy::to_int(size)); - } else { - mpy::sequence_view s(size); - d->bind_len(s.size()); - for (auto i : irange(d->size())) { - d->dims_[i]->set_size(mpy::to_int(s[i])); - } + auto d = DimList::create(std::move(name)); + if (!mpy::is_none(size)) { + if (mpy::is_int(size)) { + d->bind_len(mpy::to_int(size)); + } else { + mpy::sequence_view s(size); + d->bind_len(s.size()); + for (auto i : irange(d->size())) { + d->dims_[i]->set_size(mpy::to_int(s[i])); + } + } } - } - return std::move(d); + return std::move(d); } -// Python wrappers that make new reflection primitives available for older -// runtimes + + +// Python wrappers that make new reflection primitives available for older runtimes #if !(IS_PYTHON_3_11_PLUS) #define _PyCode_CODE(CO) ((_Py_CODEUNIT*)PyBytes_AS_STRING((CO)->co_code)) #endif -namespace { +namespace{ struct PyInstDecoder { - PyInstDecoder(PyCodeObject* code_object, int lasti) - : code_object_(code_object), - code_(_PyCode_CODE(code_object)), - offset_(lasti / sizeof(_Py_CODEUNIT)) {} - // On Windows, _PyOpcode_Caches and _PyOpcode_Deopt are private symbols - // See https://github.com/pytorch/pytorch/issues/93854 - void next() { -#if IS_PYTHON_3_11_PLUS - offset_ += _PyOpcode_Caches[opcode()]; -#endif - offset_ += 1; - } - int opcode() { - auto r = _Py_OPCODE(code_[offset_]); -#if IS_PYTHON_3_11_PLUS - r = _PyOpcode_Deopt[r]; -#endif - return r; - } - int oparg() { - return _Py_OPARG(code_[offset_]); - } - - mpy::object name() { - mpy::object names; - switch (opcode()) { - case STORE_NAME: - case STORE_GLOBAL: - names = mpy::object::borrow(code_object_->co_names); - break; - case STORE_FAST: - names = mpy::object::steal(PyCode_GetVarnames(code_object_)); - break; - case STORE_DEREF: - names = mpy::object::steal(PyCode_GetCellvars(code_object_)); - break; - default: - return mpy::object(); - } - return mpy::object::steal(PySequence_GetItem(names.ptr(), oparg())); - } - - private: - PyCodeObject* code_object_; - _Py_CODEUNIT* code_; - int offset_; + PyInstDecoder(PyCodeObject* code_object, int lasti) + : code_object_(code_object), code_(_PyCode_CODE(code_object)), offset_(lasti / sizeof(_Py_CODEUNIT)) {} + // On Windows, _PyOpcode_Caches and _PyOpcode_Deopt are private symbols + // See https://github.com/pytorch/pytorch/issues/93854 + void next() { + #if IS_PYTHON_3_11_PLUS + offset_ += _PyOpcode_Caches[opcode()]; + #endif + offset_ += 1; + } + int opcode() { + auto r = _Py_OPCODE(code_[offset_]); + #if IS_PYTHON_3_11_PLUS + r = _PyOpcode_Deopt[r]; + #endif + return r; + } + int oparg() { + return _Py_OPARG(code_[offset_]); + } + + mpy::object name() { + mpy::object names; + switch(opcode()) { + case STORE_NAME: + case STORE_GLOBAL: + names = mpy::object::borrow(code_object_->co_names); + break; + case STORE_FAST: + names = mpy::object::steal(PyCode_GetVarnames(code_object_)); + break; + case STORE_DEREF: + names = mpy::object::steal(PyCode_GetCellvars(code_object_)); + break; + default: + return mpy::object(); + } + return mpy::object::steal(PySequence_GetItem(names.ptr(), oparg())); + } +private: + PyCodeObject* code_object_; + _Py_CODEUNIT* code_; + int offset_; }; -template -static PyObject* _dims( - PyObject* self, - PyObject* const* args, - Py_ssize_t nargs, - PyObject* kwnames) { - PY_BEGIN - Py_ssize_t specified_ndims = -1; - Py_ssize_t found_ndims = 0; - Py_ssize_t sizes = -1; - mpy::handle n = Py_None; - mpy::handle py_sizes = Py_None; - - if (nargs || kwnames) { - mpy::vector_args va(args, nargs, kwnames); - va.parse("dims", {"n", "sizes"}, {&n, &py_sizes}, 0); - if (!mpy::is_none(py_sizes)) { - sizes = mpy::sequence_view(py_sizes).size(); - specified_ndims = sizes; - } - if (!mpy::is_none(n)) { - specified_ndims = mpy::to_int(n); +template +static PyObject* _dims(PyObject *self, + PyObject *const *args, + Py_ssize_t nargs, + PyObject *kwnames) { + PY_BEGIN + Py_ssize_t specified_ndims = -1; + Py_ssize_t found_ndims = 0; + Py_ssize_t sizes = -1; + mpy::handle n = Py_None; + mpy::handle py_sizes = Py_None; + + if (nargs || kwnames) { + mpy::vector_args va(args, nargs, kwnames); + va.parse("dims", {"n", "sizes"}, {&n, &py_sizes}, 0); + if (!mpy::is_none(py_sizes)) { + sizes = mpy::sequence_view(py_sizes).size(); + specified_ndims = sizes; + } + if (!mpy::is_none(n)) { + specified_ndims = mpy::to_int(n); + } } - } - PyThreadState* state = PyThreadState_GET(); - auto f = mpy::obj::steal(PyThreadState_GetFrame(state)); - auto c = mpy::obj::steal(PyFrame_GetCode(f.ptr())); - auto lasti = PyFrame_GetLasti(f.ptr()); - auto decoder = PyInstDecoder(c.ptr(), lasti); -#if IS_PYTHON_3_11_PLUS - // When py3.11 adapts bytecode lasti points to the precall - // rather than the call instruction after it - if (decoder.opcode() == PRECALL) { + PyThreadState* state = PyThreadState_GET(); + auto f = mpy::obj::steal(PyThreadState_GetFrame(state)); + auto c = mpy::obj::steal(PyFrame_GetCode(f.ptr())); + auto lasti = PyFrame_GetLasti(f.ptr()); + auto decoder = PyInstDecoder(c.ptr(), lasti); + #if IS_PYTHON_3_11_PLUS + // When py3.11 adapts bytecode lasti points to the precall + // rather than the call instruction after it + if (decoder.opcode() == PRECALL) { + decoder.next(); + } + #endif decoder.next(); - } -#endif - decoder.next(); - if (relevant_op(decoder.opcode())) { - found_ndims = 1; - } else if (decoder.opcode() == UNPACK_SEQUENCE) { - found_ndims = decoder.oparg(); - decoder.next(); - } - - if (specified_ndims == -1) { - if (found_ndims == 0) { - mpy::raise_error( - PyExc_SyntaxError, - "dims() must be assigned to a sequence of variable names or have argument n specified"); - } - specified_ndims = found_ndims; - } - if (found_ndims != specified_ndims) { - found_ndims = 0; // avoid taking the wrong names for dimensions - } - - auto genobject = [&](int i) -> mpy::object { - mpy::object name; - if (i < found_ndims) { - name = decoder.name(); - } - if (!name.ptr()) { - name = mpy::unicode_from_format("d%d", i); - found_ndims = 0; // once we fail at finding a name, we can find any more - } else { - decoder.next(); - } - return create_object( - std::move(name), - sizes != -1 ? mpy::sequence_view(py_sizes)[i] : mpy::handle(Py_None)); - }; - if (sizes != -1 && sizes != specified_ndims) { - mpy::raise_error( - PyExc_ValueError, - "expected %d sizes but found %d", - int(specified_ndims), - int(sizes)); - } - if (specified_ndims == 1) { - return genobject(0).release(); - } - mpy::tuple result(specified_ndims); - for (int i = 0; i < specified_ndims; ++i) { - result.set(i, genobject(i)); - } - return result.release(); - PY_END(nullptr) + if (relevant_op(decoder.opcode())) { + found_ndims = 1; + } else if (decoder.opcode() == UNPACK_SEQUENCE) { + found_ndims = decoder.oparg(); + decoder.next(); + } + + if (specified_ndims == -1) { + if (found_ndims == 0) { + mpy::raise_error(PyExc_SyntaxError, "dims() must be assigned to a sequence of variable names or have argument n specified"); + } + specified_ndims = found_ndims; + } + if (found_ndims != specified_ndims) { + found_ndims = 0; // avoid taking the wrong names for dimensions + } + + auto genobject = [&](int i) -> mpy::object { + mpy::object name; + if (i < found_ndims) { + name = decoder.name(); + } + if (!name.ptr()) { + name = mpy::unicode_from_format("d%d", i); + found_ndims = 0; // once we fail at finding a name, we can find any more + } else { + decoder.next(); + } + return create_object(std::move(name), sizes != -1 ? mpy::sequence_view(py_sizes)[i] : mpy::handle(Py_None)); + }; + if (sizes != -1 && sizes != specified_ndims) { + mpy::raise_error(PyExc_ValueError, "expected %d sizes but found %d", int(specified_ndims), int(sizes)); + } + if (specified_ndims == 1) { + return genobject(0).release(); + } + mpy::tuple result(specified_ndims); + for (int i = 0; i < specified_ndims; ++i) { + result.set(i, genobject(i)); + } + return result.release(); + PY_END(nullptr) } struct DotPart { - Slice dims; - size_t total_size = 1; - void append(Arena& A, mpy::hdl d) { - total_size *= d->size(); - dims.append(A, d); - } + Slice dims; + size_t total_size = 1; + void append(Arena& A, mpy::hdl d) { + total_size *= d->size(); + dims.append(A, d); + } }; -template +template static at::ArrayRef as_array_ref(Slice t) { - return at::ArrayRef(t.begin(), t.end()); -} - -static TensorRef dot_prepare( - Arena& A, - std::initializer_list parts, - const TensorInfo& t) { - Slice new_levels; - bool needs_reshape = false; - for (auto p : parts) { - if (p.dims.size() != 1) { - needs_reshape = true; - } - new_levels.extend(A, p.dims); - } - auto r = _match_levels(A, t.tensor, t.levels, new_levels, true); - if (!needs_reshape) { - return r; - } - Slice view; - for (auto p : parts) { - view.append(A, p.total_size); - } - return A.autorelease(r->reshape(at::IntArrayRef(view.begin(), view.end()))); -} - -static mpy::object dot_finish( - Arena& A, - std::initializer_list parts, - at::Tensor r) { - Slice result_levels; - bool needs_reshape = false; - for (auto p : parts) { - if (p.dims.size() != 1) { - needs_reshape = true; - } - result_levels.extend(A, p.dims); - } - if (needs_reshape) { - Slice new_size; - for (auto l : result_levels) { - new_size.append(A, l.dim()->size()); - } - r = r.reshape(at::IntArrayRef(new_size.begin(), new_size.end())); - } - return Tensor::from_positional(A, std::move(r), result_levels, true); -} - -static mpy::object dot( - Arena& A, - TensorInfo lhs, - TensorInfo rhs, - Slice sum) { - auto lhs_strides = lhs.tensor->strides(); - auto rhs_strides = rhs.tensor->strides(); - - DotPart lro_dims; - DotPart lo_dims; - DotPart ro_dims; - DotPart lr_dims; - - auto insert_dim = [&](mpy::hdl d, - std::optional lhs_idx, - std::optional rhs_idx) { - bool reduced = sum.contains(d); - int64_t lhs_stride = lhs_idx ? lhs_strides[*lhs_idx] : 0; - int64_t rhs_stride = rhs_idx ? rhs_strides[*rhs_idx] : 0; - if (reduced) { - // lr - lr_dims.append(A, d); + return at::ArrayRef(t.begin(), t.end()); +} + +static TensorRef dot_prepare(Arena& A, std::initializer_list parts, const TensorInfo& t) { + Slice new_levels; + bool needs_reshape = false; + for (auto p : parts) { + if (p.dims.size() != 1) { + needs_reshape = true; + } + new_levels.extend(A, p.dims); + } + auto r = _match_levels(A, t.tensor, t.levels, new_levels, true); + if (!needs_reshape) { + return r; + } + Slice view; + for (auto p : parts) { + view.append(A, p.total_size); + } + return A.autorelease(r->reshape(at::IntArrayRef(view.begin(), view.end()))); +} + +static mpy::object dot_finish(Arena& A, std::initializer_list parts, at::Tensor r) { + Slice result_levels; + bool needs_reshape = false; + for (auto p : parts) { + if (p.dims.size() != 1) { + needs_reshape = true; + } + result_levels.extend(A, p.dims); + } + if (needs_reshape) { + Slice new_size; + for (auto l : result_levels) { + new_size.append(A, l.dim()->size()); + } + r = r.reshape(at::IntArrayRef(new_size.begin(), new_size.end())); + } + return Tensor::from_positional(A, std::move(r), result_levels, true); +} + + + +static mpy::object dot(Arena& A, TensorInfo lhs, TensorInfo rhs, Slice sum) { + auto lhs_strides = lhs.tensor->strides(); + auto rhs_strides = rhs.tensor->strides(); + + DotPart lro_dims; + DotPart lo_dims; + DotPart ro_dims; + DotPart lr_dims; + + auto insert_dim = [&] (mpy::hdl d, std::optional lhs_idx, std::optional rhs_idx) { + bool reduced = sum.contains(d); + int64_t lhs_stride = lhs_idx ? lhs_strides[*lhs_idx] : 0; + int64_t rhs_stride = rhs_idx ? rhs_strides[*rhs_idx] : 0; + if (reduced) { + // lr + lr_dims.append(A, d); + } else { + if ((lhs_stride == 0) == (rhs_stride == 0)) { + // lro + lro_dims.append(A, d); + } else if (lhs_stride != 0) { + // lo + lo_dims.append(A, d); + } else { + AT_ASSERT(rhs_stride != 0); + ro_dims.append(A, d); + } + } + }; + + + auto rhs_seen = A.allocate(rhs.levels.size()); + std::fill(rhs_seen, rhs_seen + rhs.levels.size(), false); + + for (auto i : lhs.levels.enumerate()) { + auto d = lhs.levels[i]; + auto rhs_idx = rhs.levels.index(d); + if (rhs_idx) { + rhs_seen[*rhs_idx] = true; + } + insert_dim(d.dim(), i, rhs_idx); + } + + for (auto i : rhs.levels.enumerate()) { + if (rhs_seen[i]) { + continue; + } + auto d = rhs.levels[i]; + insert_dim(d.dim(), std::nullopt, i); + } + + if (lr_dims.dims.size() != sum.size()) { + for (auto & d : sum) { + if (!lhs.levels.contains(d) && !rhs.levels.contains(d)) { + mpy::raise_error(DimensionBindError(), "summing over non-existent dimension %S", d.dim().ptr()); + } + } + } + + // std::cout << lhs.levels << " " << rhs.levels << " " << sum << "\n"; + // std::cout << lro_dims.dims << " " << lo_dims.dims << " " << ro_dims.dims << " " << lr_dims.dims << "\n"; + + // no batch, just call mm + if (lro_dims.dims.size() != 0) { + auto lhs_ = dot_prepare(A, {lro_dims, lo_dims, lr_dims}, lhs); + auto rhs_ = dot_prepare(A, {lro_dims, lr_dims, ro_dims}, rhs); + return dot_finish(A, {lro_dims, lo_dims, ro_dims}, at::bmm(*lhs_, *rhs_)); } else { - if ((lhs_stride == 0) == (rhs_stride == 0)) { - // lro - lro_dims.append(A, d); - } else if (lhs_stride != 0) { - // lo - lo_dims.append(A, d); - } else { - AT_ASSERT(rhs_stride != 0); - ro_dims.append(A, d); - } - } - }; - - auto rhs_seen = A.allocate(rhs.levels.size()); - std::fill(rhs_seen, rhs_seen + rhs.levels.size(), false); - - for (auto i : lhs.levels.enumerate()) { - auto d = lhs.levels[i]; - auto rhs_idx = rhs.levels.index(d); - if (rhs_idx) { - rhs_seen[*rhs_idx] = true; - } - insert_dim(d.dim(), i, rhs_idx); - } - - for (auto i : rhs.levels.enumerate()) { - if (rhs_seen[i]) { - continue; - } - auto d = rhs.levels[i]; - insert_dim(d.dim(), std::nullopt, i); - } - - if (lr_dims.dims.size() != sum.size()) { - for (auto& d : sum) { - if (!lhs.levels.contains(d) && !rhs.levels.contains(d)) { - mpy::raise_error( - DimensionBindError(), - "summing over non-existent dimension %S", - d.dim().ptr()); - } - } - } - - // std::cout << lhs.levels << " " << rhs.levels << " " << sum << "\n"; - // std::cout << lro_dims.dims << " " << lo_dims.dims << " " << ro_dims.dims << - // " " << lr_dims.dims << "\n"; - - // no batch, just call mm - if (lro_dims.dims.size() != 0) { - auto lhs_ = dot_prepare(A, {lro_dims, lo_dims, lr_dims}, lhs); - auto rhs_ = dot_prepare(A, {lro_dims, lr_dims, ro_dims}, rhs); - return dot_finish(A, {lro_dims, lo_dims, ro_dims}, at::bmm(*lhs_, *rhs_)); - } else { - auto lhs_ = dot_prepare(A, {lo_dims, lr_dims}, lhs); - auto rhs_ = dot_prepare(A, {lr_dims, ro_dims}, rhs); - return dot_finish(A, {lo_dims, ro_dims}, at::mm(*lhs_, *rhs_)); - } -} - -static PyObject* test_c( - PyObject* self, - PyObject* const* args, - Py_ssize_t nargs, - PyObject* kwnames) { - PY_BEGIN - - Arena A; - Slice s(A, 3, 4, 5); - AT_ASSERT(s.size() == 3 && s.capacity() == 8); - AT_ASSERT(s[0] == 3 && s[1] == 4 && s[2] == 5); - s.append(A, 6); - AT_ASSERT(s[3] == 6); - for (int i : irange(10)) { - s.append(A, i); - } - AT_ASSERT(s[0] == 3 && s.back() == 9 && s.size() == 14 && s.capacity() == 16); - - Slice s2(A, -1, -2, -3); - AT_ASSERT(s2[1] == -2 && s[0] == 3); - - auto ss = s.slice(1, 2); - AT_ASSERT(ss.size() == 1); - AT_ASSERT(ss[0] == 4); - AT_ASSERT(ss.capacity() == 1); - ss.append(A, -4); - AT_ASSERT(ss.size() == 2 && ss[1] == -4); - ss[0] = 3; - AT_ASSERT(s[1] == 4); - - s.insert(A, s.slice(1, 4), ss); - AT_ASSERT(s[1] == 3 && s[2] == -4 && s[3] == 0); - - auto sz = s.size(); - s.insert(A, s.slice(1, 1), 4); - AT_ASSERT(s[1] == 4 && sz + 1 == s.size()); - - Slice d(A, 0, 1, 2, 3, 4); - - Slice b(A, 0, 1, 2, 3, 4); - b.insert(A, b.slice(1, 1), d); - AT_ASSERT(b.size() == 10); - AT_ASSERT(b[1] == 0); - AT_ASSERT(b[5] == 4); - AT_ASSERT(b.back() == 4); - - Py_RETURN_NONE; - - PY_END(nullptr); -} - -static PyObject* order( - PyObject* _, - PyObject* const* args, - Py_ssize_t nargs, - PyObject* kwnames) { - Arena A; - PY_BEGIN - if (kwnames) { - mpy::raise_error( - PyExc_TypeError, "unexpected keyword arguments %S", kwnames); - } - AT_ASSERT(nargs-- > 0); - Slice orig_levels; - Slice levels; - TensorRef data; - mpy::handle self = args++[0]; - bool has_device; - if (Tensor::check_exact(self)) { - auto t = Tensor::unchecked_wrap(self); - orig_levels = t->levels(); - data = t->tensor(A); - has_device = t->has_device(); - } else { - auto d = Dim::unchecked_wrap(self); - orig_levels.append(A, d); - data = d->range(); - has_device = false; - } - - Slice flat_positional_dims; - Slice> to_flatten; - levels.extend(A, orig_levels); - - int orig_ndim = ndim_of_levels(levels); - auto append = [&](DimEntry d) { - auto midx = levels.index(d); - if (!midx) { - if (d.is_positional()) { - mpy::raise_error( - PyExc_ValueError, - "tensor has %d positional dimensions, but %d specified, or it was specified twice", - int(orig_ndim), - int(d.position() + orig_ndim)); - } else { - mpy::raise_error( - PyExc_ValueError, - "tensor of dimensions %R does not contain dim %R or it was specified twice", - levels_to_tuple(orig_levels).ptr(), - d.dim().ptr()); - } - } - levels[*midx] = DimEntry(); - flat_positional_dims.append(A, d); - }; - - int n_new_positional = 0; - for (auto i : irange(nargs)) { - mpy::handle arg = args[i]; - DimEntry entry = _wrap_dim(arg, orig_ndim, false); - if (!entry.is_none()) { - append(entry); - ++n_new_positional; - } else if (DimList::check(arg)) { - auto dl = DimList::unchecked_wrap(arg); - for (mpy::obj& d : dl->dims_) { - append(mpy::hdl(d)); - ++n_new_positional; - } + auto lhs_ = dot_prepare(A, {lo_dims, lr_dims}, lhs); + auto rhs_ = dot_prepare(A, {lr_dims, ro_dims}, rhs); + return dot_finish(A, {lo_dims, ro_dims}, at::mm(*lhs_, *rhs_)); + } + +} + +static PyObject* test_c(PyObject *self, + PyObject *const *args, + Py_ssize_t nargs, + PyObject *kwnames) { + PY_BEGIN + + Arena A; + Slice s(A, 3, 4, 5); + AT_ASSERT(s.size() == 3 && s.capacity() == 8); + AT_ASSERT(s[0] == 3 && s[1] == 4 && s[2] == 5); + s.append(A, 6); + AT_ASSERT(s[3] == 6); + for(int i : irange(10)) { + s.append(A, i); + } + AT_ASSERT(s[0] == 3 && s.back() == 9 && s.size() == 14 && s.capacity() == 16); + + Slice s2(A, -1, -2, -3); + AT_ASSERT(s2[1] == -2 && s[0] == 3); + + auto ss = s.slice(1,2); + AT_ASSERT(ss.size() == 1); + AT_ASSERT(ss[0] == 4); + AT_ASSERT(ss.capacity() == 1); + ss.append(A, -4); + AT_ASSERT(ss.size() == 2 && ss[1] == -4); + ss[0] = 3; + AT_ASSERT(s[1] == 4); + + s.insert(A, s.slice(1, 4), ss); + AT_ASSERT(s[1] == 3 && s[2] == -4 && s[3] == 0); + + auto sz = s.size(); + s.insert(A, s.slice(1, 1), 4); + AT_ASSERT(s[1] == 4 && sz + 1 == s.size()); + + + Slice d(A, 0, 1, 2, 3, 4); + + Slice b(A, 0, 1, 2, 3, 4); + b.insert(A, b.slice(1,1), d); + AT_ASSERT(b.size() == 10); + AT_ASSERT(b[1] == 0); + AT_ASSERT(b[5] == 4); + AT_ASSERT(b.back() == 4); + + Py_RETURN_NONE; + + PY_END(nullptr); +} + + +static PyObject* order(PyObject *_, + PyObject *const *args, + Py_ssize_t nargs, + PyObject *kwnames) { + Arena A; + PY_BEGIN + if (kwnames) { + mpy::raise_error(PyExc_TypeError, "unexpected keyword arguments %S", kwnames); + } + AT_ASSERT(nargs-- > 0); + Slice orig_levels; + Slice levels; + TensorRef data; + mpy::handle self = args++[0]; + bool has_device; + if (Tensor::check_exact(self)) { + auto t = Tensor::unchecked_wrap(self); + orig_levels = t->levels(); + data = t->tensor(A); + has_device = t->has_device(); } else { - ++n_new_positional; - if (!mpy::is_sequence(arg)) { - mpy::raise_error( - PyExc_ValueError, "expected a Dim, List[Dim], or Sequence[Dim]"); - } - mpy::sequence_view sq(arg); - auto N = sq.size(); - to_flatten.append(A, std::make_pair(flat_positional_dims.size(), N)); - for (auto j : irange(N)) { - DimEntry e = _wrap_dim(A.autorelease(sq[j]), orig_ndim, false); - if (e.is_none()) { - mpy::raise_error(PyExc_ValueError, "expected a Dim, or int"); - } - append(e); - } - } - } - - int insert_point = -1; - Slice new_levels; - for (auto l : levels) { - if (l.is_none()) { - continue; - } - if (l.is_positional()) { - if (insert_point == -1) { + auto d = Dim::unchecked_wrap(self); + orig_levels.append(A, d); + data = d->range(); + has_device = false; + } + + Slice flat_positional_dims; + Slice> to_flatten; + levels.extend(A, orig_levels); + + int orig_ndim = ndim_of_levels(levels); + auto append = [&](DimEntry d) { + auto midx = levels.index(d); + if (!midx) { + if (d.is_positional()) { + mpy::raise_error(PyExc_ValueError, "tensor has %d positional dimensions, but %d specified, or it was specified twice", int(orig_ndim), int(d.position() + orig_ndim)); + } else { + mpy::raise_error(PyExc_ValueError, "tensor of dimensions %R does not contain dim %R or it was specified twice", levels_to_tuple(orig_levels).ptr(), d.dim().ptr()); + } + } + levels[*midx] = DimEntry(); + flat_positional_dims.append(A, d); + }; + + int n_new_positional = 0; + for (auto i :irange(nargs)) { + mpy::handle arg = args[i]; + DimEntry entry = _wrap_dim(arg, orig_ndim, false); + if (!entry.is_none()) { + append(entry); + ++n_new_positional; + } else if (DimList::check(arg)) { + auto dl = DimList::unchecked_wrap(arg); + for (mpy::obj & d : dl->dims_) { + append(mpy::hdl(d)); + ++n_new_positional; + } + } else { + ++n_new_positional; + if (!mpy::is_sequence(arg)) { + mpy::raise_error(PyExc_ValueError, "expected a Dim, List[Dim], or Sequence[Dim]"); + } + mpy::sequence_view sq(arg); + auto N = sq.size(); + to_flatten.append(A, std::make_pair(flat_positional_dims.size(), N)); + for (auto j : irange(N)) { + DimEntry e = _wrap_dim(A.autorelease(sq[j]), orig_ndim, false); + if (e.is_none()) { + mpy::raise_error(PyExc_ValueError, "expected a Dim, or int"); + } + append(e); + } + } + } + + int insert_point = -1; + Slice new_levels; + for (auto l : levels) { + if (l.is_none()) { + continue; + } + if (l.is_positional()) { + if (insert_point == -1) { + insert_point = new_levels.size(); + new_levels.extend(A, flat_positional_dims); + } + } + new_levels.append(A, l); + } + if (insert_point == -1) { insert_point = new_levels.size(); new_levels.extend(A, flat_positional_dims); - } } - new_levels.append(A, l); - } - if (insert_point == -1) { - insert_point = new_levels.size(); - new_levels.extend(A, flat_positional_dims); - } - at::Tensor ndata = *_match_levels(A, data, orig_levels, new_levels); + at::Tensor ndata = *_match_levels(A, data, orig_levels, new_levels); - if (to_flatten.size()) { - Slice view; - auto sz = ndata.sizes(); - // before the new positional dims - for (auto i : irange(0, insert_point)) { - view.append(A, sz[i]); - } - int i = 0; - for (auto to_flat : to_flatten) { - for (; i < to_flat.first; ++i) { - view.append(A, sz[insert_point + i]); - } - int64_t new_size = 1; - int last = i + to_flat.second; - for (; i < last; ++i) { - new_size *= sz[insert_point + i]; - } - view.append(A, new_size); - } - for (; i < flat_positional_dims.size(); ++i) { - view.append(A, sz[insert_point + i]); - } - // after the new positional dims - for (auto i : - irange(insert_point + flat_positional_dims.size(), levels.size())) { - view.append(A, sz[i]); - } - // we shorted the number of dimension, so remove them from new levels - // we will renumber them later - auto n_to_remove = flat_positional_dims.size() - n_new_positional; - new_levels.insert( - A, - new_levels.slice(insert_point, insert_point + n_to_remove), - Slice()); - ndata = std::move(ndata).reshape(at::IntArrayRef(view.begin(), view.end())); - } - - // renumber the positional dimension - int seen = 0; - for (auto i : new_levels.reversed_enumerate()) { - if (new_levels[i].is_positional() || - (i >= insert_point && i < insert_point + n_new_positional)) { - new_levels[i] = --seen; - } - } - return Tensor::from_positional(A, std::move(ndata), new_levels, has_device) - .release(); - - PY_END(nullptr) -} - -static PyObject* expand( - PyObject* _, - PyObject* const* args, - Py_ssize_t nargs, - PyObject* kwnames) { - Arena A; - PY_BEGIN - AT_ASSERT(nargs-- > 0); - auto info = TensorInfo::create(A, args++[0], false); - for (auto i : irange(nargs)) { - if (!Dim::check(args[i])) { - maybeInitializeGlobals(); - mpy::vector_args vargs(args - 1, nargs + 1, kwnames); - if (THPVariable_Check(args[-1])) { - return torch_Tensor_expand.call_vector(vargs).release(); - } else { - return __torch_function__(A, torch_Tensor_expand, vargs, false) - .release(); - } - } - } - const at::Tensor& data = *info.tensor; - auto levels = info.levels; - Slice new_levels; - Slice sz; - Slice sd; - for (auto i : irange(nargs)) { - auto d = Dim::unchecked_wrap(args[i]); - if (levels.contains(d) || new_levels.contains(d)) { - mpy::raise_error( - DimensionBindError(), - "expanding dimension %R already exists in tensor with dims", - d.ptr()); - } - new_levels.append(A, d); - sz.append(A, d->size()); - sd.append(A, 0); - } - new_levels.extend(A, levels); - at::IntArrayRef osz = data.sizes(); - at::IntArrayRef osd = data.strides(); - sz.extend(A, osz.begin(), osz.end()); - sd.extend(A, osd.begin(), osd.end()); - at::Tensor ndata = data.as_strided( - at::IntArrayRef(sz.begin(), sz.end()), - at::IntArrayRef(sd.begin(), sd.end()), - data.storage_offset()); - return Tensor::from_positional( - A, std::move(ndata), new_levels, info.has_device) - .release(); - PY_END(nullptr) -} - -static void _bind_dims_to_size( - Arena& A, - int64_t sz, - int64_t sd, - Slice> dims, - Slice& nsz, - Slice& nsd) { - int64_t rhs_prod = 1; - for (auto i : dims.enumerate()) { - if (!dims[i]->is_bound()) { - for (auto j : irange(i + 1, dims.size())) { - if (!dims[j]->is_bound()) { - mpy::raise_error( - DimensionBindError(), - "cannot infer the sizes of two dimensions at once %R and %R", - dims[i].ptr(), - dims[j].ptr()); - } - rhs_prod *= dims[j]->size(); - } - if (sz % rhs_prod != 0) { + if (to_flatten.size()) { + Slice view; + auto sz = ndata.sizes(); + // before the new positional dims + for (auto i : irange(0, insert_point)) { + view.append(A, sz[i]); + } + int i = 0; + for (auto to_flat : to_flatten) { + for (;i < to_flat.first; ++i) { + view.append(A, sz[insert_point + i]); + } + int64_t new_size = 1; + int last = i + to_flat.second; + for (; i < last; ++i) { + new_size *= sz[insert_point + i]; + } + view.append(A, new_size); + } + for (; i < flat_positional_dims.size(); ++i) { + view.append(A, sz[insert_point + i]); + } + // after the new positional dims + for (auto i : irange(insert_point + flat_positional_dims.size(), levels.size())) { + view.append(A, sz[i]); + } + // we shorted the number of dimension, so remove them from new levels + // we will renumber them later + auto n_to_remove = flat_positional_dims.size() - n_new_positional; + new_levels.insert(A, new_levels.slice(insert_point, insert_point + n_to_remove), Slice()); + ndata = std::move(ndata).reshape(at::IntArrayRef(view.begin(), view.end())); + } + + // renumber the positional dimension + int seen = 0; + for (auto i : new_levels.reversed_enumerate()) { + if (new_levels[i].is_positional() || (i >= insert_point && i < insert_point + n_new_positional)) { + new_levels[i] = --seen; + } + } + return Tensor::from_positional(A, std::move(ndata), new_levels, has_device).release(); + + PY_END(nullptr) +} + +static PyObject* expand(PyObject *_, + PyObject *const *args, + Py_ssize_t nargs, + PyObject *kwnames) { + Arena A; + PY_BEGIN + AT_ASSERT(nargs-- > 0); + auto info = TensorInfo::create(A, args++[0], false); + for (auto i : irange(nargs)) { + if (!Dim::check(args[i])) { + maybeInitializeGlobals(); + mpy::vector_args vargs(args - 1, nargs + 1, kwnames); + if (THPVariable_Check(args[-1])) { + return torch_Tensor_expand.call_vector(vargs).release(); + } else { + return __torch_function__(A, torch_Tensor_expand, vargs, false).release(); + } + } + } + const at::Tensor& data = *info.tensor; + auto levels = info.levels; + Slice new_levels; + Slice sz; + Slice sd; + for (auto i : irange(nargs)) { + auto d = Dim::unchecked_wrap(args[i]); + if (levels.contains(d) || new_levels.contains(d)) { + mpy::raise_error(DimensionBindError(), "expanding dimension %R already exists in tensor with dims", d.ptr()); + } + new_levels.append(A, d); + sz.append(A, d->size()); + sd.append(A, 0); + } + new_levels.extend(A, levels); + at::IntArrayRef osz = data.sizes(); + at::IntArrayRef osd = data.strides(); + sz.extend(A, osz.begin(), osz.end()); + sd.extend(A, osd.begin(), osd.end()); + at::Tensor ndata = data.as_strided(at::IntArrayRef(sz.begin(), sz.end()), at::IntArrayRef(sd.begin(), sd.end()), data.storage_offset()); + return Tensor::from_positional(A, std::move(ndata), new_levels, info.has_device).release(); + PY_END(nullptr) +} + + +static void _bind_dims_to_size(Arena & A, int64_t sz, int64_t sd, + Slice> dims, Slice& nsz, Slice& nsd) { + int64_t rhs_prod = 1; + for (auto i : dims.enumerate()) { + if (!dims[i]->is_bound()) { + for (auto j : irange(i + 1, dims.size())) { + if (!dims[j]->is_bound()) { + mpy::raise_error(DimensionBindError(), "cannot infer the sizes of two dimensions at once %R and %R", dims[i].ptr(), dims[j].ptr()); + } + rhs_prod *= dims[j]->size(); + } + if (sz % rhs_prod != 0) { + mpy::tuple tup(dims.size()); + for (auto j : dims.enumerate()) { + tup.set(j, dims[j]->is_bound() ? mpy::from_int(dims[j]->size()) : mpy::unicode_from_string("?")); + } + mpy::raise_error(DimensionBindError(), "inferred dimension does not evenly fit into larger dimension: %d vs %R", (int) sz, tup.ptr()); + } + int64_t inferred_size = sz / rhs_prod; + dims[i]->set_size(inferred_size); + rhs_prod = sz; + break; + } + rhs_prod *= dims[i]->size(); + } + if (rhs_prod != sz) { mpy::tuple tup(dims.size()); for (auto j : dims.enumerate()) { - tup.set( - j, - dims[j]->is_bound() ? mpy::from_int(dims[j]->size()) - : mpy::unicode_from_string("?")); - } - mpy::raise_error( - DimensionBindError(), - "inferred dimension does not evenly fit into larger dimension: %d vs %R", - (int)sz, - tup.ptr()); - } - int64_t inferred_size = sz / rhs_prod; - dims[i]->set_size(inferred_size); - rhs_prod = sz; - break; - } - rhs_prod *= dims[i]->size(); - } - if (rhs_prod != sz) { - mpy::tuple tup(dims.size()); - for (auto j : dims.enumerate()) { - tup.set(j, mpy::object::borrow(dims[j])); - } - mpy::raise_error( - DimensionBindError(), - "Dimension sizes to do not match (%d != %d) when matching dimension pack %R", - (int)sz, - (int)rhs_prod, - tup.ptr()); - } - auto new_strides = A.allocate(dims.size()); - auto prev_stride = sd; - for (auto i : dims.reversed_enumerate()) { - new_strides[i] = prev_stride; - prev_stride = dims[i]->size() * prev_stride; - } - for (auto i : dims.enumerate()) { - nsd.append(A, new_strides[i]); - nsz.append(A, dims[i]->size()); - } + tup.set(j, mpy::object::borrow(dims[j])); + } + mpy::raise_error(DimensionBindError(), "Dimension sizes to do not match (%d != %d) when matching dimension pack %R", (int) sz, (int) rhs_prod, tup.ptr()); + } + auto new_strides = A.allocate(dims.size()); + auto prev_stride = sd; + for (auto i : dims.reversed_enumerate()) { + new_strides[i] = prev_stride; + prev_stride = dims[i]->size()*prev_stride; + } + for (auto i : dims.enumerate()) { + nsd.append(A, new_strides[i]); + nsz.append(A, dims[i]->size()); + } } static bool has_dims(mpy::handle d) { - return Dim::check_exact(d) || Tensor::check_exact(d); + return Dim::check_exact(d) || Tensor::check_exact(d); } struct IndexingInfo { - bool can_call_original; // if true, then it is safe to just call getitem or - // setitem, these objects do not need special handling - bool advanced_indexing; // requires actual lookup - TensorRef self; - Slice flat_inputs; - Slice result_levels; - bool has_device; + bool can_call_original; // if true, then it is safe to just call getitem or setitem, these objects do not need special handling + bool advanced_indexing; // requires actual lookup + TensorRef self; + Slice flat_inputs; + Slice result_levels; + bool has_device; }; -} // namespace - -IndexingInfo getsetitem_flat( - Arena& A, - TensorInfo self_info, - Slice input, - Slice keys, - Slice values, - bool has_dimpacks_or_none); -namespace { +} + +IndexingInfo getsetitem_flat(Arena& A, TensorInfo self_info, Slice input, Slice keys, Slice values, bool has_dimpacks_or_none); +namespace{ Slice as_slice(mpy::tuple_view tv) { - PyObject** begin = &PyTuple_GET_ITEM(tv.ptr(), 0); - return Slice( - (mpy::handle*)begin, (mpy::handle*)(begin + tv.size())); + PyObject** begin = &PyTuple_GET_ITEM(tv.ptr(),0); + return Slice((mpy::handle*)begin, (mpy::handle*) (begin + tv.size())); } Slice as_slice(mpy::list_view tv) { - PyObject** begin = &PyList_GET_ITEM(tv.ptr(), 0); - return Slice( - (mpy::handle*)begin, (mpy::handle*)(begin + tv.size())); -} - -bool maybe_dimpack( - Slice& elements, - mpy::handle s, - bool check_first = true) { - // can we avoid rechecking? - if (mpy::list_view::check(s)) { - mpy::list_view tv(s); - if (!check_first || (tv.size() && Dim::check_exact(tv[0]))) { - elements = as_slice(tv); - return true; - } - } - // can we avoid rechecking? - if (mpy::tuple_view::check(s)) { - mpy::tuple_view tv(s); - if (!check_first || (tv.size() && Dim::check_exact(tv[0]))) { - elements = as_slice(tv); - return true; - } - } - return false; + PyObject** begin = &PyList_GET_ITEM(tv.ptr(),0); + return Slice((mpy::handle*)begin, (mpy::handle*) (begin + tv.size())); +} + + +bool maybe_dimpack(Slice& elements, mpy::handle s, bool check_first=true) { + // can we avoid rechecking? + if (mpy::list_view::check(s)) { + mpy::list_view tv(s); + if (!check_first || (tv.size() && Dim::check_exact(tv[0]))) { + elements = as_slice(tv); + return true; + } + } + // can we avoid rechecking? + if (mpy::tuple_view::check(s)) { + mpy::tuple_view tv(s); + if (!check_first || (tv.size() && Dim::check_exact(tv[0]))) { + elements = as_slice(tv); + return true; + } + } + return false; }; bool is_dimpack(mpy::handle s) { - Slice e; - return maybe_dimpack(e, s); + Slice e; + return maybe_dimpack(e, s); } mpy::object invoke_getitem(Arena& A, const IndexingInfo& iinfo) { - at::Tensor rtensor; - if (iinfo.advanced_indexing) { - auto self_hdl = handle_from_tensor(A, iinfo.self); - auto tup = slice_to_tuple(iinfo.flat_inputs); - // std::cout << "calling original getindex " << self_hdl << " " << tup << - // "\n"; - auto pytensor = mpy::object::checked_steal( - THPVariable_getitem(self_hdl.ptr(), tup.ptr())); - rtensor = THPVariable_Unpack(pytensor.ptr()); - } else { - // std::cout << "skipping original getindex\n"; - rtensor = *iinfo.self; - } - // std::cout << "returning (from_positional)\n"; - return Tensor::from_positional( - A, std::move(rtensor), iinfo.result_levels, iinfo.has_device); -} - -mpy::object index( - Arena& A, - mpy::handle self, - mpy::handle dims, - mpy::handle indices) { - maybeInitializeGlobals(); - Slice dims_list; - Slice indices_list; - // we allow for matching single dims to multiple dims, - // so we first have to normalize everything into the case where there is a - // list on lhs and the rhs - bool lhs_list = mpy::tuple_view::check(dims) || mpy::list_view::check(dims); - bool rhs_list = - mpy::tuple_view::check(indices) || mpy::list_view::check(indices); - if (lhs_list && rhs_list) { - mpy::sequence_view dv(dims); - mpy::sequence_view ind(indices); - Py_ssize_t N = dv.size(); - if (N != ind.size()) { - mpy::raise_error( - PyExc_TypeError, - "dims (%d) and indices (%d) must have the same length", - int(N), - int(ind.size())); - } - for (auto i : irange(N)) { - dims_list.append(A, A.autorelease(dv[i])); - indices_list.append(A, A.autorelease(ind[i])); - } - } else { - dims_list.append(A, dims); - indices_list.append(A, indices); - } - - // dims being indexed can be grouped together into a single index space, and - // we have to flatten them int a single dimension before we can index them... - auto self_info = TensorInfo::create(A, self, false); - auto ndim = self_info.ndim(); - Slice new_levels; - Slice to_flatten; - Slice dims_list_flat; - auto parse_dim_entry = [&](mpy::handle s) -> DimEntry { - auto d = _wrap_dim(s, ndim, false); - if (d.is_none()) { - mpy::raise_error( - PyExc_TypeError, - "expected a dimension specifyer but found %R", - s.ptr()); - } - return d; - }; - auto dim_not_present = [&](DimEntry d) { - if (d.is_positional()) { - mpy::raise_error( - PyExc_TypeError, - "dimension %d not in tensor of %d dimensions", - d.position() + ndim, - ndim); + at::Tensor rtensor; + if (iinfo.advanced_indexing) { + auto self_hdl = handle_from_tensor(A, iinfo.self); + auto tup = slice_to_tuple(iinfo.flat_inputs); + // std::cout << "calling original getindex " << self_hdl << " " << tup << "\n"; + auto pytensor = mpy::object::checked_steal(THPVariable_getitem(self_hdl.ptr(), tup.ptr())); + rtensor = THPVariable_Unpack(pytensor.ptr()); } else { - mpy::raise_error( - PyExc_TypeError, "dimension %R not in tensor", d.dim()->ptr()); - } - }; - - for (auto i : dims_list.enumerate()) { - Slice m; - if (maybe_dimpack(m, dims_list[i], /*check_first=*/false)) { - if (m.size() == 0) { - // plausible semantics work for this to have 0 elements (e.g. the index - // will always be 0) - dims_list_flat.append(A, DimEntry()); // value is just dropped - } - auto first = parse_dim_entry(m[0]); - dims_list_flat.append(A, first); - if (m.size() == 1) { - continue; - } - if (to_flatten.size() == 0) { - new_levels.extend(A, self_info.levels); - } - Slice rest; - for (auto i : irange(1, m.size())) { - auto d = parse_dim_entry(m[i]); - if (!new_levels.remove(A, d)) { - dim_not_present(d); - } - rest.append(A, d); - } - - auto first_idx = new_levels.index(first); - if (!first_idx) { - dim_not_present(first); - } - new_levels.insert( - A, new_levels.slice(*first_idx + 1, *first_idx + 1), rest); - to_flatten.extend(A, rest); + // std::cout << "skipping original getindex\n"; + rtensor = *iinfo.self; + } + // std::cout << "returning (from_positional)\n"; + return Tensor::from_positional(A, std::move(rtensor), iinfo.result_levels, iinfo.has_device); +} + +mpy::object index(Arena& A, mpy::handle self, mpy::handle dims, mpy::handle indices) { + maybeInitializeGlobals(); + Slice dims_list; + Slice indices_list; + // we allow for matching single dims to multiple dims, + // so we first have to normalize everything into the case where there is a list on lhs and the rhs + bool lhs_list = mpy::tuple_view::check(dims) || mpy::list_view::check(dims); + bool rhs_list = mpy::tuple_view::check(indices) || mpy::list_view::check(indices); + if (lhs_list && rhs_list) { + mpy::sequence_view dv(dims); + mpy::sequence_view ind(indices); + Py_ssize_t N = dv.size(); + if (N != ind.size()) { + mpy::raise_error(PyExc_TypeError, "dims (%d) and indices (%d) must have the same length", int(N), int(ind.size())); + } + for (auto i : irange(N)) { + dims_list.append(A, A.autorelease(dv[i])); + indices_list.append(A, A.autorelease(ind[i])); + } } else { - dims_list_flat.append(A, parse_dim_entry(dims_list[i])); - } - } - if (to_flatten.size() > 0) { - TensorRef rearranged = - _match_levels(A, self_info.tensor, self_info.levels, new_levels); - at::IntArrayRef sizes = rearranged->sizes(); - Slice new_sizes; - Slice reshape_levels; - for (auto i : new_levels.enumerate()) { - if (to_flatten.contains(new_levels[i])) { - new_sizes.back() *= sizes[i]; - } else { - new_sizes.append(A, sizes[i]); - reshape_levels.append(A, new_levels[i]); - } - } - self_info.tensor = A.autorelease(rearranged->reshape( - at::IntArrayRef(new_sizes.begin(), new_sizes.end()))); - - self_info.levels = - reshape_levels; // note: we are using the first level in a flattened - // group to represent the group for the rest of the op - // we need to be careful not to rely the dimensions size - // because it doesn't match the size of the whole group - } - bool has_dimpacks = false; - for (auto idx : indices_list) { - if (mpy::tuple_view::check(idx) || mpy::list_view::check(idx)) { - has_dimpacks = true; - break; - } - } - IndexingInfo info = getsetitem_flat( - A, - self_info, - Slice(), - dims_list_flat, - indices_list, - has_dimpacks); - return invoke_getitem(A, info); + dims_list.append(A, dims); + indices_list.append(A, indices); + } + + // dims being indexed can be grouped together into a single index space, and we have to + // flatten them int a single dimension before we can index them... + auto self_info = TensorInfo::create(A, self, false); + auto ndim = self_info.ndim(); + Slice new_levels; + Slice to_flatten; + Slice dims_list_flat; + auto parse_dim_entry = [&](mpy::handle s) -> DimEntry { + auto d = _wrap_dim(s, ndim, false); + if (d.is_none()) { + mpy::raise_error(PyExc_TypeError, "expected a dimension specifyer but found %R", s.ptr()); + } + return d; + }; + auto dim_not_present = [&](DimEntry d) { + if (d.is_positional()) { + mpy::raise_error(PyExc_TypeError, "dimension %d not in tensor of %d dimensions", d.position() + ndim , ndim); + } else { + mpy::raise_error(PyExc_TypeError, "dimension %R not in tensor", d.dim()->ptr()); + } + }; + + for (auto i : dims_list.enumerate()) { + Slice m; + if (maybe_dimpack(m, dims_list[i], /*check_first=*/false)) { + if (m.size() == 0) { + // plausible semantics work for this to have 0 elements (e.g. the index will always be 0) + dims_list_flat.append(A, DimEntry()); // value is just dropped + } + auto first = parse_dim_entry(m[0]); + dims_list_flat.append(A, first); + if (m.size() == 1) { + continue; + } + if (to_flatten.size() == 0) { + new_levels.extend(A, self_info.levels); + } + Slice rest; + for (auto i : irange(1, m.size())) { + auto d = parse_dim_entry(m[i]); + if (!new_levels.remove(A, d)) { + dim_not_present(d); + } + rest.append(A, d); + } + + auto first_idx = new_levels.index(first); + if (!first_idx) { + dim_not_present(first); + } + new_levels.insert(A, new_levels.slice(*first_idx + 1, *first_idx + 1), rest); + to_flatten.extend(A, rest); + } else { + dims_list_flat.append(A, parse_dim_entry(dims_list[i])); + } + } + if (to_flatten.size() > 0) { + TensorRef rearranged = _match_levels(A, self_info.tensor, self_info.levels, new_levels); + at::IntArrayRef sizes = rearranged->sizes(); + Slice new_sizes; + Slice reshape_levels; + for (auto i : new_levels.enumerate()) { + if (to_flatten.contains(new_levels[i])) { + new_sizes.back() *= sizes[i]; + } else { + new_sizes.append(A, sizes[i]); + reshape_levels.append(A, new_levels[i]); + } + } + self_info.tensor = A.autorelease(rearranged->reshape(at::IntArrayRef(new_sizes.begin(), new_sizes.end()))); + + self_info.levels = reshape_levels; // note: we are using the first level in a flattened group to represent the group for the rest of the op + // we need to be careful not to rely the dimensions size because it doesn't match the size of the whole group + } + bool has_dimpacks = false; + for (auto idx : indices_list) { + if (mpy::tuple_view::check(idx) || mpy::list_view::check(idx)) { + has_dimpacks = true; + break; + } + } + IndexingInfo info = getsetitem_flat(A, self_info, Slice(), dims_list_flat, indices_list, has_dimpacks); + return invoke_getitem(A, info); } // true -- the indices were flattened out of a tuple, list or sequence... Slice slice_from_sequence(Arena& A, mpy::handle value) { - if (mpy::tuple_view::check(value)) { - return as_slice(mpy::tuple_view(value)); - } else if (mpy::list_view::check(value)) { - return as_slice(mpy::list_view(value)); - } else { - mpy::sequence_view sv(value); - Slice r; - for (auto i : sv.enumerate()) { - r.append(A, A.autorelease(sv[i])); + if (mpy::tuple_view::check(value)) { + return as_slice(mpy::tuple_view(value)); + } else if (mpy::list_view::check(value)) { + return as_slice(mpy::list_view(value)); + } else { + mpy::sequence_view sv(value); + Slice r; + for (auto i : sv.enumerate()) { + r.append(A, A.autorelease(sv[i])); + } + return r; } - return r; - } } bool extractIndices(Arena& A, mpy::handle index, Slice& indices) { - if (mpy::tuple_view::check(index)) { - indices.extend(A, as_slice(mpy::tuple_view(index))); - return true; - } else if (THPVariable_Check(index.ptr())) { - indices.append(A, index); - return false; - } else if (!mpy::is_sequence(index)) { + if (mpy::tuple_view::check(index)) { + indices.extend(A, as_slice(mpy::tuple_view(index))); + return true; + } else if (THPVariable_Check(index.ptr())) { + indices.append(A, index); + return false; + } else if (!mpy::is_sequence(index)) { + indices.append(A, index); + return false; + } + // a copy of treatSequenceAsTuple modified to add Dim and our wrapped tensors.. + mpy::sequence_view sv(index); + if (sv.size() >= 32) { + indices.extend(A, slice_from_sequence(A, index)); + return true; + } + for (auto i : sv.enumerate()) { + mpy::handle item; + try { + item = sv[i]; + } catch (mpy::exception_set & e) { + PyErr_Clear(); + indices.append(A, index); + return false; + } + if (THPVariable_Check(item.ptr()) || mpy::is_sequence(item) || PySlice_Check(item.ptr()) || item.ptr() == Py_Ellipsis || mpy::is_none(item) || has_dims(item)) { + indices.extend(A, slice_from_sequence(A, index)); + return true; + } + } indices.append(A, index); return false; - } - // a copy of treatSequenceAsTuple modified to add Dim and our wrapped - // tensors.. - mpy::sequence_view sv(index); - if (sv.size() >= 32) { - indices.extend(A, slice_from_sequence(A, index)); - return true; - } - for (auto i : sv.enumerate()) { - mpy::handle item; - try { - item = sv[i]; - } catch (mpy::exception_set& e) { - PyErr_Clear(); - indices.append(A, index); - return false; - } - if (THPVariable_Check(item.ptr()) || mpy::is_sequence(item) || - PySlice_Check(item.ptr()) || item.ptr() == Py_Ellipsis || - mpy::is_none(item) || has_dims(item)) { - indices.extend(A, slice_from_sequence(A, index)); - return true; - } - } - indices.append(A, index); - return false; -} - -IndexingInfo getsetitem( - Arena& A, - mpy::handle self, - mpy::handle index, - bool tensors_have_dims) { - bool can_call_original_getitem = !tensors_have_dims; - - Slice input; - if (has_dims(index)) { - input.append(A, index); - } else { - bool is_sequence = extractIndices(A, index, input); - // nothing about first class dims here, fallback to getitem - if (can_call_original_getitem && !is_sequence) { - return {true}; - } - } - - int64_t dims_indexed = 0; - int64_t expanding_object = -1; - DimList* unbound_dim_list = nullptr; - auto check_expanding = [&](int64_t i) { - if (expanding_object != -1) { - mpy::raise_error( - DimensionBindError(), - "at most one ... or unbound dimension list can exist in indexing list but found 2 at offsets %d and %d", - (int)expanding_object, - (int)i); - } - expanding_object = i; - }; - Slice dimlists; - - // calculate how many dimensioned have been indexed in order to compute the - // size of ... or expand a potentially unbound dimension list. - - bool has_dimpacks_or_none = false; - for (auto i : input.enumerate()) { - mpy::handle s = input[i]; - if (Dim::check_exact(s) || Tensor::check_exact(s)) { - can_call_original_getitem = false; - ++dims_indexed; - } else if (s.ptr() == Py_Ellipsis) { - check_expanding(i); - } else if (DimList::check(s)) { - can_call_original_getitem = false; - auto dl = DimList::unchecked_wrap(s); - if (!dl->is_bound()) { - check_expanding(i); - unbound_dim_list = dl.ptr(); - } else { - dims_indexed += dl->dims_.size(); - } - dimlists.append(A, i); - } else if (mpy::is_none(s)) { - has_dimpacks_or_none = true; - } else if (is_dimpack(s)) { - can_call_original_getitem = false; - has_dimpacks_or_none = true; - ++dims_indexed; - } else { - ++dims_indexed; - } - } - - // at this point if we haven't seen any Dim objects, we also can fallback to - // the original getitem. - if (can_call_original_getitem) { - return {true}; - } - - // std::cout << "__getitem__ " << self << " " << index << "\n"; - - TensorInfo self_info = TensorInfo::create(A, self, false, true); - auto ndim = self_info.ndim(); - if (dims_indexed > ndim) { - mpy::raise_error( - PyExc_ValueError, - "at least %d indices were supplied but the tensor only has %d dimensions", - (int)dims_indexed, - (int)ndim); - } - // expand any unbound dimension list, or expand ... into individual : slices. - auto expanding_dims = ndim - dims_indexed; - if (expanding_object != -1) { - if (unbound_dim_list) { - unbound_dim_list->bind_len(expanding_dims); - } else { - // ... - Slice no_slices; - for (auto i : irange(expanding_dims)) { - (void)i; - no_slices.append(A, no_slice); - } - input.insert( - A, input.slice(expanding_object, expanding_object + 1), no_slices); - } - } - - // flatten out any dimensions stored in dimlist elements directly into the - // inputs std::cout << dimlists << " <- dim lists!\n"; - for (int64_t i = dimlists.size() - 1; i >= 0; --i) { - auto idx = dimlists[i]; - // we added more elements to input because of ... - // so we need to also adjust the index to get back to where the - // dimlist existed - if (!unbound_dim_list && expanding_object != -1 && idx > expanding_object) { - idx += expanding_dims; - } - auto dl = DimList::unchecked_wrap(input[idx]); - // XXX would be better if we used an OwnedSlice in DimList - Slice more_dims( - (mpy::handle*)&*dl->dims_.begin(), (mpy::handle*)&*dl->dims_.end()); - input.insert(A, input.slice(idx, idx + 1), more_dims); - } - - return getsetitem_flat( - A, - self_info, - input, - Slice(), - Slice(), - has_dimpacks_or_none); -} -} // namespace -IndexingInfo getsetitem_flat( - Arena& A, - TensorInfo self_info, - Slice input, - Slice keys, - Slice values, - bool has_dimpacks_or_none) { - // At this point: - // ..., DimList have been eliminated - // Dim, Tensor, Tuple[Dim,...], int, slice still remain - - // we have to count how many times we see a dimension. - // A[i,j] is a simple binding operation, but A[i, i+j] or A[i, i] requires - // advanced indexing. - Slice> seen_dims; - Slice seen_dims_nuses; - auto add_dim = [&](mpy::hdl entry) { - auto midx = seen_dims.index(entry); - if (!midx) { - seen_dims.append(A, entry); - seen_dims_nuses.append(A, 1); +} + +IndexingInfo getsetitem(Arena & A, mpy::handle self, mpy::handle index, bool tensors_have_dims) { + bool can_call_original_getitem = !tensors_have_dims; + + Slice input; + if (has_dims(index)) { + input.append(A, index); } else { - ++seen_dims_nuses[*midx]; + bool is_sequence = extractIndices(A, index, input); + // nothing about first class dims here, fallback to getitem + if (can_call_original_getitem && !is_sequence) { + return { true }; + } } - }; - Slice input_it = input; + int64_t dims_indexed = 0; + int64_t expanding_object = -1; + DimList* unbound_dim_list = nullptr; + auto check_expanding = [&](int64_t i) { + if (expanding_object != -1) { + mpy::raise_error(DimensionBindError(), "at most one ... or unbound dimension list can exist in indexing list but found 2 at offsets %d and %d", (int) expanding_object, (int) i); + } + expanding_object = i; + }; + Slice dimlists; + + // calculate how many dimensioned have been indexed in order to compute the size of ... + // or expand a potentially unbound dimension list. + + bool has_dimpacks_or_none = false; + for (auto i : input.enumerate()) { + mpy::handle s = input[i]; + if (Dim::check_exact(s) || Tensor::check_exact(s)) { + can_call_original_getitem = false; + ++dims_indexed; + } else if (s.ptr() == Py_Ellipsis) { + check_expanding(i); + } else if (DimList::check(s)) { + can_call_original_getitem = false; + auto dl = DimList::unchecked_wrap(s); + if (!dl->is_bound()) { + check_expanding(i); + unbound_dim_list = dl.ptr(); + } else { + dims_indexed += dl->dims_.size(); + } + dimlists.append(A, i); + } else if (mpy::is_none(s)) { + has_dimpacks_or_none = true; + } else if (is_dimpack(s)) { + can_call_original_getitem = false; + has_dimpacks_or_none = true; + ++dims_indexed; + } else { + ++dims_indexed; + } + } - Slice flat_inputs; - // flat inputs will start with an empty mpy::handle if the - // actual value is in the tensor-like object in the tensor info - Slice tensor_inputs; + // at this point if we haven't seen any Dim objects, we also can fallback to the original getitem. + if (can_call_original_getitem) { + return {true}; + } + + // std::cout << "__getitem__ " << self << " " << index << "\n"; + + TensorInfo self_info = TensorInfo::create(A, self, false, true); + auto ndim = self_info.ndim(); + if (dims_indexed > ndim) { + mpy::raise_error(PyExc_ValueError, "at least %d indices were supplied but the tensor only has %d dimensions", (int) dims_indexed, (int) ndim); + } + // expand any unbound dimension list, or expand ... into individual : slices. + auto expanding_dims = ndim - dims_indexed; + if (expanding_object != -1) { + if (unbound_dim_list) { + unbound_dim_list->bind_len(expanding_dims); + } else { + // ... + Slice no_slices; + for (auto i : irange(expanding_dims)) { + (void) i; + no_slices.append(A, no_slice); + } + input.insert(A, input.slice(expanding_object, expanding_object + 1), no_slices); + } + } - auto append_flat_handle = [&](mpy::handle h) { - flat_inputs.append(A, h); - tensor_inputs.append(A, TensorInfo()); - }; - TensorRef device_holding_tensor; - auto append_tensor_input = [&](TensorInfo ti) { - flat_inputs.append(A, mpy::handle()); - tensor_inputs.append(A, ti); - if (ti.has_device && !device_holding_tensor) { - device_holding_tensor = ti.tensor; + // flatten out any dimensions stored in dimlist elements directly into the inputs + // std::cout << dimlists << " <- dim lists!\n"; + for (int64_t i = dimlists.size() - 1; i >=0; --i) { + auto idx = dimlists[i]; + // we added more elements to input because of ... + // so we need to also adjust the index to get back to where the + // dimlist existed + if (!unbound_dim_list && expanding_object != -1 && idx > expanding_object) { + idx += expanding_dims; + } + auto dl = DimList::unchecked_wrap(input[idx]); + // XXX would be better if we used an OwnedSlice in DimList + Slice more_dims((mpy::handle*) &*dl->dims_.begin(), (mpy::handle*) &*dl->dims_.end()); + input.insert(A, input.slice(idx, idx + 1), more_dims); } - }; - Slice nsz; - Slice nsd; - at::IntArrayRef sz = self_info.tensor->sizes(); - at::IntArrayRef sd = self_info.tensor->strides(); + return getsetitem_flat(A, self_info, input, Slice(), Slice(), has_dimpacks_or_none); +} +} +IndexingInfo getsetitem_flat(Arena& A, TensorInfo self_info, Slice input, Slice keys, Slice values, bool has_dimpacks_or_none) { + // At this point: + // ..., DimList have been eliminated + // Dim, Tensor, Tuple[Dim,...], int, slice still remain + - auto append_size = [&](int i) { - if (has_dimpacks_or_none) { - nsz.append(A, sz[i]); - nsd.append(A, sd[i]); - } - }; - // std::cout << "self levels: " << self_info.levels << "\n"; - - auto parse_nones = [&]() { - while (input_it.size() && mpy::is_none(input_it[0])) { - append_flat_handle(no_slice); - nsz.append(A, 1); - nsd.append(A, 0); - input_it = input_it.slice(1); - } - }; - - auto append_item = [&](int i, mpy::handle arg) { - if (Dim::check_exact(arg)) { - auto d = Dim::unchecked_wrap(arg); - d->set_size(sz[i]); - add_dim(d); - append_size(i); - append_flat_handle(arg); - return; - } - auto info = TensorInfo::create(A, arg, false, false); - if (info) { - append_size(i); - append_tensor_input(info); - for (auto il : info.levels) { - if (!il.is_positional()) { - add_dim(il.dim()); - } - } - return; + // we have to count how many times we see a dimension. + // A[i,j] is a simple binding operation, but A[i, i+j] or A[i, i] requires advanced indexing. + Slice> seen_dims; + Slice seen_dims_nuses; + auto add_dim = [&](mpy::hdl entry) { + auto midx = seen_dims.index(entry); + if (!midx) { + seen_dims.append(A, entry); + seen_dims_nuses.append(A, 1); + } else { + ++seen_dims_nuses[*midx]; + } + }; + + Slice input_it = input; + + Slice flat_inputs; + // flat inputs will start with an empty mpy::handle if the + // actual value is in the tensor-like object in the tensor info + Slice tensor_inputs; + + auto append_flat_handle = [&](mpy::handle h) { + flat_inputs.append(A, h); + tensor_inputs.append(A, TensorInfo()); + }; + TensorRef device_holding_tensor; + auto append_tensor_input = [&](TensorInfo ti) { + flat_inputs.append(A, mpy::handle()); + tensor_inputs.append(A, ti); + if (ti.has_device && !device_holding_tensor) { + device_holding_tensor = ti.tensor; + } + }; + + Slice nsz; + Slice nsd; + at::IntArrayRef sz = self_info.tensor->sizes(); + at::IntArrayRef sd = self_info.tensor->strides(); + + auto append_size = [&](int i) { + if (has_dimpacks_or_none) { + nsz.append(A, sz[i]); + nsd.append(A, sd[i]); + } + }; + // std::cout << "self levels: " << self_info.levels << "\n"; + + auto parse_nones = [&]() { + while (input_it.size() && mpy::is_none(input_it[0])) { + append_flat_handle(no_slice); + nsz.append(A, 1); + nsd.append(A, 0); + input_it = input_it.slice(1); + } + }; + + + auto append_item = [&](int i, mpy::handle arg) { + if (Dim::check_exact(arg)) { + auto d = Dim::unchecked_wrap(arg); + d->set_size(sz[i]); + add_dim(d); + append_size(i); + append_flat_handle(arg); + return; + } + auto info = TensorInfo::create(A, arg, false, false); + if (info) { + append_size(i); + append_tensor_input(info); + for (auto il : info.levels) { + if (!il.is_positional()) { + add_dim(il.dim()); + } + } + return; + } + + if (has_dimpacks_or_none) { + Slice mp; + if (maybe_dimpack(mp, arg)) { + // dim pack + Slice> dim_pack; + for (auto d : mp) { + dim_pack.append(A, Dim::wrap(d)); + add_dim(dim_pack.back()); + append_flat_handle(dim_pack.back()); + } + _bind_dims_to_size(A, sz[i], sd[i], dim_pack, nsz, nsd); + return; + } + } + + append_size(i); + append_flat_handle(arg); + }; + + // pair up the indexing expressions with dimension of self it indexes + // self may have first-class dims, which do not participate the indexing. + for (auto i : self_info.levels.enumerate()) { + auto l = self_info.levels[i]; + auto idx = keys.index(l); + if (idx) { + append_item(i, values[*idx]); + } else if (l.is_positional()) { + // grab and index from the positional list + parse_nones(); + if (!input_it.size()) { + // we might have fewer indices than tensor dimensions, + // which implicitly indexes the remaining dimensions with : + append_flat_handle(no_slice); + append_size(i); + } else { + mpy::handle arg = input_it[0]; + input_it = input_it.slice(1); + append_item(i, arg); + } + } else { + add_dim(l.dim()); + append_flat_handle(l.dim()); + append_size(i); + } } + // any training Nones may have no existing dimension associated with them in self. + parse_nones(); + // we have to restride the tensor to collapse dimension packs and introduce our none dimensions. if (has_dimpacks_or_none) { - Slice mp; - if (maybe_dimpack(mp, arg)) { - // dim pack - Slice> dim_pack; - for (auto d : mp) { - dim_pack.append(A, Dim::wrap(d)); - add_dim(dim_pack.back()); - append_flat_handle(dim_pack.back()); - } - _bind_dims_to_size(A, sz[i], sd[i], dim_pack, nsz, nsd); - return; - } - } - - append_size(i); - append_flat_handle(arg); - }; - - // pair up the indexing expressions with dimension of self it indexes - // self may have first-class dims, which do not participate the indexing. - for (auto i : self_info.levels.enumerate()) { - auto l = self_info.levels[i]; - auto idx = keys.index(l); - if (idx) { - append_item(i, values[*idx]); - } else if (l.is_positional()) { - // grab and index from the positional list - parse_nones(); - if (!input_it.size()) { - // we might have fewer indices than tensor dimensions, - // which implicitly indexes the remaining dimensions with : - append_flat_handle(no_slice); - append_size(i); - } else { - mpy::handle arg = input_it[0]; - input_it = input_it.slice(1); - append_item(i, arg); - } - } else { - add_dim(l.dim()); - append_flat_handle(l.dim()); - append_size(i); - } - } - // any training Nones may have no existing dimension associated with them in - // self. - parse_nones(); - - // we have to restride the tensor to collapse dimension packs and introduce - // our none dimensions. - if (has_dimpacks_or_none) { - self_info.tensor = A.autorelease(self_info.tensor->as_strided( - at::IntArrayRef(nsz.begin(), nsz.end()), - at::IntArrayRef(nsd.begin(), nsd.end()), - self_info.tensor->storage_offset())); - } - - // figure out what the shape of the indexing tensors will be - // and what the shape of the resulting tensor will be - Slice result_levels; - Slice index_levels; - int64_t tensor_insert_point = -1; - bool requires_getindex = false; - auto mark_tensor_index = [&] { - if (tensor_insert_point == -1) { - tensor_insert_point = result_levels.size(); - } else if (tensor_insert_point != result_levels.size()) { - tensor_insert_point = 0; - } - }; - for (auto i : flat_inputs.enumerate()) { - auto inp = flat_inputs[i]; - if (tensor_inputs[i]) { - requires_getindex = true; - mark_tensor_index(); - for (auto l : tensor_inputs[i].levels) { - // std::cout << "Consider to add " << l << "\n"; - if (!index_levels.contains(l)) { - index_levels.append(A, l); - } - } - } else if (Dim::check_exact(inp)) { - auto d = Dim::unchecked_wrap(inp); - // dimensions used once are just binding operations - if (1 == seen_dims_nuses[*seen_dims.index(d)]) { - flat_inputs[i] = no_slice; - result_levels.append(A, d); - } else { - requires_getindex = true; - flat_inputs[i] = mpy::handle(); - tensor_inputs[i] = TensorInfo{ - d->range(), Slice(A, DimEntry(d)), false, TensorRef()}; - if (!index_levels.contains(d)) { - index_levels.append(A, d); - } - mark_tensor_index(); - } - } else { - if (inp.ptr() != no_slice.ptr()) { - requires_getindex = true; - } - if (!mpy::is_int(inp)) { - // note: actual positional indexes are accurately computed later - result_levels.append(A, -1); - } - } - } - - // indexing dimensions appear in the tensor at the _first use of a tensor_ in - // the indexing. So insert the indexing leveles into the result klevels at - // this spot - if (tensor_insert_point != -1) { - result_levels.insert( - A, - result_levels.slice(tensor_insert_point, tensor_insert_point), - index_levels); - } - - // std::cout << "flat inputs: " << flat_inputs << "\n"; - // std::cout << "result_levels: " << result_levels << "\n"; - // std::cout << "index_levels: " << index_levels << "\n"; - - // get all the tensors to be the right shape for indexing - if (requires_getindex) { + self_info.tensor = A.autorelease(self_info.tensor->as_strided(at::IntArrayRef(nsz.begin(), nsz.end()),at::IntArrayRef(nsd.begin(), nsd.end()), self_info.tensor->storage_offset())); + } + + + // figure out what the shape of the indexing tensors will be + // and what the shape of the resulting tensor will be + Slice result_levels; + Slice index_levels; + int64_t tensor_insert_point = -1; + bool requires_getindex = false; + auto mark_tensor_index = [&] { + if (tensor_insert_point == -1) { + tensor_insert_point = result_levels.size(); + } else if (tensor_insert_point != result_levels.size()) { + tensor_insert_point = 0; + } + }; for (auto i : flat_inputs.enumerate()) { - if (tensor_inputs[i]) { - AT_ASSERT(!flat_inputs[i].ptr()); - // std::cout << "tensor " << i << " " << tensor_inputs[i].levels << - // "\n"; - TensorRef t = tensor_inputs[i].tensor; - if (!tensor_inputs[i].has_device && device_holding_tensor) { - t = A.autorelease(t->to(device_holding_tensor->device())); - } - flat_inputs[i] = handle_from_tensor( - A, _match_levels(A, t, tensor_inputs[i].levels, index_levels)); - } - } - } - - // previously we didn't know how many positional dimensions there would be so - // we couldn't number them right so fill it in now. - auto seen_positionals = 0; - for (auto i : result_levels.reversed_enumerate()) { - if (result_levels[i].is_positional()) { - result_levels[i] = -(++seen_positionals); - } - } - - return IndexingInfo{ - false, - requires_getindex, - self_info.tensor, - flat_inputs, - result_levels, - self_info.has_device}; -} -namespace { -mpy::object __getitem__(Arena& A, mpy::handle self, mpy::handle index) { - maybeInitializeGlobals(); - auto iinfo = getsetitem(A, self, index, has_dims(self)); - if (iinfo.can_call_original) { - return mpy::object::checked_steal( - THPVariable_getitem(self.ptr(), index.ptr())); - } - - return invoke_getitem(A, iinfo); -} - -void __setitem__( - Arena& A, - mpy::handle self, - mpy::handle index, - mpy::handle rhs) { - maybeInitializeGlobals(); - auto iinfo = getsetitem(A, self, index, has_dims(self) || has_dims(rhs)); - if (iinfo.can_call_original) { - if (-1 == THPVariable_setitem(self.ptr(), index.ptr(), rhs.ptr())) { - throw mpy::exception_set(); - } - return; - } - - auto rhs_info = TensorInfo::create(A, rhs, false, false); - if (rhs_info) { // otherwise rhs can be a scalar... - for (auto l : rhs_info.levels) { - if (!iinfo.result_levels.contains(l)) { - if (l.is_positional()) { - mpy::raise_error( - DimensionBindError(), - "rhs contains too many dimensions (%d) compared to indexed value (%d)", - ndim_of_levels(iinfo.result_levels), - rhs_info.ndim()); - } else { - auto tup = levels_to_tuple(iinfo.result_levels); - mpy::raise_error( - DimensionBindError(), - "rhs of setitem contains dimension %R which is not in the dimension on the left (%R)", - l.dim().ptr(), - tup.ptr()); + auto inp = flat_inputs[i]; + if(tensor_inputs[i]) { + requires_getindex = true; + mark_tensor_index(); + for (auto l : tensor_inputs[i].levels) { + // std::cout << "Consider to add " << l << "\n"; + if (!index_levels.contains(l)) { + index_levels.append(A, l); + } + } + } else if (Dim::check_exact(inp)) { + auto d = Dim::unchecked_wrap(inp); + // dimensions used once are just binding operations + if (1 == seen_dims_nuses[*seen_dims.index(d)]) { + flat_inputs[i] = no_slice; + result_levels.append(A, d); + } else { + requires_getindex = true; + flat_inputs[i] = mpy::handle(); + tensor_inputs[i] = TensorInfo {d->range(), Slice(A, DimEntry(d)), false, TensorRef()}; + if (!index_levels.contains(d)) { + index_levels.append(A, d); + } + mark_tensor_index(); + } + } else { + if (inp.ptr() != no_slice.ptr()) { + requires_getindex = true; + } + if (!mpy::is_int(inp)) { + // note: actual positional indexes are accurately computed later + result_levels.append(A, -1); + } + } + } + + // indexing dimensions appear in the tensor at the _first use of a tensor_ in the indexing. So insert + // the indexing leveles into the result klevels at this spot + if (tensor_insert_point != -1) { + result_levels.insert(A, result_levels.slice(tensor_insert_point, tensor_insert_point), index_levels); + } + + // std::cout << "flat inputs: " << flat_inputs << "\n"; + // std::cout << "result_levels: " << result_levels << "\n"; + // std::cout << "index_levels: " << index_levels << "\n"; + + // get all the tensors to be the right shape for indexing + if (requires_getindex) { + for (auto i : flat_inputs.enumerate()) { + if (tensor_inputs[i]) { + AT_ASSERT(!flat_inputs[i].ptr()); + // std::cout << "tensor " << i << " " << tensor_inputs[i].levels << "\n"; + TensorRef t = tensor_inputs[i].tensor; + if (!tensor_inputs[i].has_device && device_holding_tensor) { + t = A.autorelease(t->to(device_holding_tensor->device())); + } + flat_inputs[i] = handle_from_tensor(A, _match_levels(A, t, tensor_inputs[i].levels, index_levels)); + } + } + } + + // previously we didn't know how many positional dimensions there would be so we couldn't number them right + // so fill it in now. + auto seen_positionals = 0; + for (auto i : result_levels.reversed_enumerate()) { + if (result_levels[i].is_positional()) { + result_levels[i] = -(++seen_positionals); + } + } + + return IndexingInfo {false, requires_getindex, self_info.tensor, flat_inputs, result_levels, self_info.has_device}; +} +namespace{ +mpy::object __getitem__(Arena & A, mpy::handle self, mpy::handle index) { + maybeInitializeGlobals(); + auto iinfo = getsetitem(A, self, index, has_dims(self)); + if (iinfo.can_call_original) { + return mpy::object::checked_steal(THPVariable_getitem(self.ptr(), index.ptr())); + } + + return invoke_getitem(A, iinfo); +} + + + +void __setitem__(Arena & A, mpy::handle self, mpy::handle index, mpy::handle rhs) { + maybeInitializeGlobals(); + auto iinfo = getsetitem(A, self, index, has_dims(self) || has_dims(rhs)); + if (iinfo.can_call_original) { + if (-1 == THPVariable_setitem(self.ptr(), index.ptr(), rhs.ptr())) { + throw mpy::exception_set(); } - } + return; + } + + auto rhs_info = TensorInfo::create(A, rhs, false, false); + if (rhs_info) { // otherwise rhs can be a scalar... + for (auto l : rhs_info.levels) { + if (!iinfo.result_levels.contains(l)) { + if (l.is_positional()) { + mpy::raise_error(DimensionBindError(), "rhs contains too many dimensions (%d) compared to indexed value (%d)", ndim_of_levels(iinfo.result_levels), rhs_info.ndim()); + } else { + auto tup = levels_to_tuple(iinfo.result_levels); + mpy::raise_error(DimensionBindError(), "rhs of setitem contains dimension %R which is not in the dimension on the left (%R)", l.dim().ptr(), tup.ptr()); + } + } + } + auto rhs_matched = _match_levels(A, rhs_info.tensor, rhs_info.levels, iinfo.result_levels); + rhs = handle_from_tensor(A, rhs_matched); } - auto rhs_matched = - _match_levels(A, rhs_info.tensor, rhs_info.levels, iinfo.result_levels); - rhs = handle_from_tensor(A, rhs_matched); - } - self = handle_from_tensor(A, iinfo.self); + self = handle_from_tensor(A, iinfo.self); - if (iinfo.advanced_indexing) { - auto tup = slice_to_tuple(iinfo.flat_inputs); - if (-1 == THPVariable_setitem(self.ptr(), tup.ptr(), rhs.ptr())) { - throw mpy::exception_set(); + if (iinfo.advanced_indexing) { + auto tup = slice_to_tuple(iinfo.flat_inputs); + if (-1 == THPVariable_setitem(self.ptr(), tup.ptr(), rhs.ptr())) { + throw mpy::exception_set(); + } + } else { + torch_Tensor_copy_.call(self, rhs); } - } else { - torch_Tensor_copy_.call(self, rhs); - } } -} // namespace +} PyObject* Tensor_getitem(PyObject* self, PyObject* index) { - Arena A; - PY_BEGIN - return __getitem__(A, self, index).release(); - PY_END(nullptr); + Arena A; + PY_BEGIN + return __getitem__(A, self, index).release(); + PY_END(nullptr); } int Tensor_setitem(PyObject* self, PyObject* index, PyObject* value) { - Arena A; - PY_BEGIN - __setitem__(A, self, index, value); - return 0; - PY_END(-1); -} - -namespace { -PyObject* py___getitem__( - PyObject* _, - PyObject* const* args, - Py_ssize_t nargs, - PyObject* kwnames) { - Arena A; - PY_BEGIN - AT_ASSERT(nargs == 2); - return __getitem__(A, args[0], args[1]).release(); - PY_END(nullptr) -} - -PyObject* py___setitem__( - PyObject* _, - PyObject* const* args, - Py_ssize_t nargs, - PyObject* kwnames) { - Arena A; - PY_BEGIN - AT_ASSERT(nargs == 3); - __setitem__(A, args[0], args[1], args[2]); - Py_RETURN_NONE; - PY_END(nullptr) -} - -PyObject* py_index( - PyObject* _, - PyObject* const* args, - Py_ssize_t nargs, - PyObject* kwnames) { - Arena A; - PY_BEGIN - mpy::vector_args va(args, nargs, kwnames); - mpy::handle self, dims, indices; - va.parse("index", {"self", "dims", "indices"}, {&self, &dims, &indices}, 3); - return index(A, self, dims, indices).release(); - PY_END(nullptr) -} - -PyObject* py_stack( - PyObject* _, - PyObject* const* args, - Py_ssize_t nargs, - PyObject* kwnames) { - Arena A; - PY_BEGIN - mpy::vector_args va(args, nargs, kwnames); - mpy::handle tensors, new_dim, dim; - va.parse( - "stack", {"tensors", "new_dim", "dim"}, {&tensors, &new_dim, &dim}, 2); - - Slice result_levels; - Slice infos; - mpy::sequence_view sv(tensors); - auto new_dim_d = Dim::wrap(new_dim); - for (auto i : sv.enumerate()) { - infos.append(A, TensorInfo::create(A, A.autorelease(sv[i]), false)); - for (auto l : infos.back().levels) { - if (!result_levels.contains(l)) { - result_levels.append(A, l); - } - } - } - new_dim_d->set_size(infos.size()); - std::vector inputs; - inputs.reserve(infos.size()); - for (auto in : infos) { - inputs.emplace_back(*_match_levels(A, in.tensor, in.levels, result_levels)); - } - auto ndim = ndim_of_levels(result_levels); - int64_t rawdim = 0; - if (dim.ptr()) { - auto d = _wrap_dim(dim, ndim, false); - auto idx = result_levels.index(d); - if (!idx) { - mpy::raise_error( - PyExc_TypeError, "Dimension %R does not exist in inputs", dim.ptr()); - } - rawdim = *idx; - } - auto result = at::stack(inputs, rawdim); - result_levels.insert(A, rawdim, new_dim_d); - return Tensor::from_positional(A, std::move(result), result_levels, true) - .release(); - PY_END(nullptr) -} - -PyObject* py_split( - PyObject* _, - PyObject* const* args, - Py_ssize_t nargs, - PyObject* kwnames) { - Arena A; - PY_BEGIN - maybeInitializeGlobals(); - mpy::vector_args va(args, nargs, kwnames); - mpy::handle self, split_size_or_sections, dim; - va.parse( - "split", - {"self", "split_size_or_sections", "dim"}, - {&self, &split_size_or_sections, &dim}, - 2); - bool dim_is_object = dim.ptr() && Dim::check_exact(dim); - Slice sizes; - - bool all_dims = true; - bool all_ints = true; - - if (!mpy::is_int(split_size_or_sections)) { - mpy::sequence_view sv(split_size_or_sections); + Arena A; + PY_BEGIN + __setitem__(A, self, index, value); + return 0; + PY_END(-1); +} + +namespace{ +PyObject* py___getitem__(PyObject *_, + PyObject *const *args, + Py_ssize_t nargs, + PyObject *kwnames) { + Arena A; + PY_BEGIN + AT_ASSERT(nargs == 2); + return __getitem__(A, args[0], args[1]).release(); + PY_END(nullptr) +} + +PyObject* py___setitem__(PyObject *_, + PyObject *const *args, + Py_ssize_t nargs, + PyObject *kwnames) { + Arena A; + PY_BEGIN + AT_ASSERT(nargs == 3); + __setitem__(A, args[0], args[1], args[2]); + Py_RETURN_NONE; + PY_END(nullptr) +} + + +PyObject* py_index(PyObject *_, + PyObject *const *args, + Py_ssize_t nargs, + PyObject *kwnames) { + Arena A; + PY_BEGIN + mpy::vector_args va(args, nargs, kwnames); + mpy::handle self, dims, indices; + va.parse("index", {"self", "dims", "indices"}, {&self, &dims, &indices}, 3); + return index(A, self, dims, indices).release(); + PY_END(nullptr) +} + + +PyObject* py_stack(PyObject *_, + PyObject *const *args, + Py_ssize_t nargs, + PyObject *kwnames) { + Arena A; + PY_BEGIN + mpy::vector_args va(args, nargs, kwnames); + mpy::handle tensors, new_dim, dim; + va.parse("stack", {"tensors", "new_dim", "dim"}, {&tensors, &new_dim, &dim}, 2); + + Slice result_levels; + Slice infos; + mpy::sequence_view sv(tensors); + auto new_dim_d = Dim::wrap(new_dim); for (auto i : sv.enumerate()) { - sizes.append(A, A.autorelease(sv[i])); - if (Dim::check_exact(sizes.back())) { - all_ints = false; - } else { - all_dims = false; - } - } - } - if (all_ints) { - if (dim_is_object) { - mpy::raise_error( - PyExc_TypeError, - "when dim is specified as a Dim object, split sizes must also be dimensions."); - } - // call original split (if self has dimensions this will use torch function - // to do the split) - return torch_Tensor_split - .call_vector(mpy::vector_args(args, nargs, kwnames)) - .release(); - } - if (!all_dims) { - mpy::raise_error( - PyExc_TypeError, "split list must be ints or dims but got a mix"); - } - - auto self_info = TensorInfo::create(A, self, false); - auto ndim = self_info.ndim(); - if (!dim_is_object && ndim == 0) { - mpy::raise_error( - PyExc_TypeError, "split expects at least a 1-dimension tensor"); - } - DimEntry dim_l = dim.ptr() ? _wrap_dim(dim, ndim, false) : -ndim; - - auto idx = self_info.levels.index(dim_l); - if (!idx) { - if (!dim.ptr()) { - dim = A.autorelease(mpy::from_int(0)); - } - mpy::raise_error( - PyExc_TypeError, "tensor does not contain dimension %R", dim.ptr()); - } - Slice indices; - - int64_t total_size = 0; - Slice unbound; - for (auto i : sizes.enumerate()) { - auto d = Dim::unchecked_wrap(sizes[i]); - if (d->is_bound()) { - indices.append(A, d->size()); - total_size += indices.back(); - } else { - indices.append(A, 0); - unbound.append(A, i); - } - } - auto tensor_size = self_info.tensor->sizes()[*idx]; - - if (unbound.size()) { - if (total_size > tensor_size) { - mpy::raise_error( - PyExc_TypeError, - "sizes of target dimensions add up to more (%d) than source dim (%d)", - int(total_size), - int(tensor_size)); - } - auto remaining_size = tensor_size - total_size; - auto chunk_size = (remaining_size + unbound.size() - 1) / unbound.size(); - for (auto u : unbound) { - auto sz = std::min(chunk_size, remaining_size); - Dim::unchecked_wrap(sizes[u])->set_size(sz); - indices[u] = sz; - remaining_size -= sz; - } - } else if (tensor_size != total_size) { - mpy::raise_error( - PyExc_TypeError, - "sum of sizes of target dimensions (%d) do not match the than source dim (%d)", - int(total_size), - int(tensor_size)); - } - - auto result_tensors = self_info.tensor->split_with_sizes( - at::IntArrayRef(indices.begin(), indices.end()), *idx); - mpy::tuple result(result_tensors.size()); - Slice new_levels; - new_levels.extend(A, self_info.levels); - for (auto i : sizes.enumerate()) { - new_levels[*idx] = Dim::unchecked_wrap(sizes[i]); - result.set( - i, - Tensor::from_positional( - A, std::move(result_tensors[i]), new_levels, true)); - } - - return result.release(); - - PY_END(nullptr) + infos.append(A, TensorInfo::create(A, A.autorelease(sv[i]), false)); + for (auto l : infos.back().levels) { + if (!result_levels.contains(l)) { + result_levels.append(A, l); + } + } + } + new_dim_d->set_size(infos.size()); + std::vector inputs; + inputs.reserve(infos.size()); + for (auto in : infos) { + inputs.emplace_back(*_match_levels(A, in.tensor, in.levels, result_levels)); + } + auto ndim = ndim_of_levels(result_levels); + int64_t rawdim = 0; + if (dim.ptr()) { + auto d = _wrap_dim(dim, ndim, false); + auto idx = result_levels.index(d); + if (!idx) { + mpy::raise_error(PyExc_TypeError, "Dimension %R does not exist in inputs", dim.ptr()); + } + rawdim = *idx; + } + auto result = at::stack(inputs, rawdim); + result_levels.insert(A, rawdim, new_dim_d); + return Tensor::from_positional(A, std::move(result), result_levels, true).release(); + PY_END(nullptr) +} + +PyObject* py_split(PyObject *_, + PyObject *const *args, + Py_ssize_t nargs, + PyObject *kwnames) { + Arena A; + PY_BEGIN + maybeInitializeGlobals(); + mpy::vector_args va(args, nargs, kwnames); + mpy::handle self, split_size_or_sections, dim; + va.parse("split", {"self", "split_size_or_sections", "dim"}, {&self, &split_size_or_sections, &dim}, 2); + bool dim_is_object = dim.ptr() && Dim::check_exact(dim); + Slice sizes; + + bool all_dims = true; + bool all_ints = true; + + if (!mpy::is_int(split_size_or_sections)) { + mpy::sequence_view sv(split_size_or_sections); + for (auto i : sv.enumerate()) { + sizes.append(A, A.autorelease(sv[i])); + if (Dim::check_exact(sizes.back())) { + all_ints = false; + } else { + all_dims = false; + } + } + } + if (all_ints) { + if (dim_is_object) { + mpy::raise_error(PyExc_TypeError, "when dim is specified as a Dim object, split sizes must also be dimensions."); + } + // call original split (if self has dimensions this will use torch function to do the split) + return torch_Tensor_split.call_vector(mpy::vector_args(args, nargs, kwnames)).release(); + } + if (!all_dims) { + mpy::raise_error(PyExc_TypeError, "split list must be ints or dims but got a mix"); + } + + auto self_info = TensorInfo::create(A, self, false); + auto ndim = self_info.ndim(); + if (!dim_is_object&& ndim == 0) { + mpy::raise_error(PyExc_TypeError, "split expects at least a 1-dimension tensor"); + } + DimEntry dim_l = dim.ptr() ? _wrap_dim(dim, ndim, false) : -ndim; + + auto idx = self_info.levels.index(dim_l); + if (!idx) { + if (!dim.ptr()) { + dim = A.autorelease(mpy::from_int(0)); + } + mpy::raise_error(PyExc_TypeError, "tensor does not contain dimension %R", dim.ptr()); + } + Slice indices; + + int64_t total_size = 0; + Slice unbound; + for (auto i : sizes.enumerate()) { + auto d = Dim::unchecked_wrap(sizes[i]); + if (d->is_bound()) { + indices.append(A, d->size()); + total_size += indices.back(); + } else { + indices.append(A, 0); + unbound.append(A, i); + } + } + auto tensor_size = self_info.tensor->sizes()[*idx]; + + if (unbound.size()) { + if (total_size > tensor_size) { + mpy::raise_error(PyExc_TypeError, "sizes of target dimensions add up to more (%d) than source dim (%d)", int(total_size), int(tensor_size)); + } + auto remaining_size = tensor_size - total_size; + auto chunk_size = (remaining_size + unbound.size() - 1) / unbound.size(); + for (auto u : unbound) { + auto sz = std::min(chunk_size, remaining_size); + Dim::unchecked_wrap(sizes[u])->set_size(sz); + indices[u] = sz; + remaining_size -= sz; + } + } else if (tensor_size != total_size) { + mpy::raise_error(PyExc_TypeError, "sum of sizes of target dimensions (%d) do not match the than source dim (%d)", int(total_size), int(tensor_size)); + } + + auto result_tensors = self_info.tensor->split_with_sizes(at::IntArrayRef(indices.begin(), indices.end()), *idx); + mpy::tuple result(result_tensors.size()); + Slice new_levels; + new_levels.extend(A, self_info.levels); + for (auto i : sizes.enumerate()) { + new_levels[*idx] = Dim::unchecked_wrap(sizes[i]); + result.set(i, Tensor::from_positional(A, std::move(result_tensors[i]), new_levels, true)); + } + + return result.release(); + + PY_END(nullptr) } Slice _wrap_dims(Arena& A, mpy::handle d, size_t N, bool keepdim) { - auto de = _wrap_dim(d, N, keepdim); - Slice r; - if (!de.is_none()) { - r.append(A, de); - } else { - mpy::sequence_view sq(d); - for (auto i : sq.enumerate()) { - r.append(A, _wrap_dim(A.autorelease(sq[i]), N, keepdim)); + auto de = _wrap_dim(d, N, keepdim); + Slice r; + if (!de.is_none()) { + r.append(A, de); + } else { + mpy::sequence_view sq(d); + for (auto i : sq.enumerate()) { + r.append(A, _wrap_dim(A.autorelease(sq[i]), N, keepdim)); + } } - } - return r; + return r; } struct WrappedOperator : public mpy::base { - mpy::object orig; - PyMethodDef method_def; - mpy::object name, doc; - - bool is_pointwise = false; - int64_t dim_offset = 0; - int64_t keepdim_offset = 1; - std::string dim_name; - bool single_dim = false; - bool reduce = true; - - static PyTypeObject Type; - - void init( - mpy::object orig_, - PyCFunction wrapper_implementation, - std::string dim_name_ = "") { - orig = std::move(orig_); - method_def.ml_meth = wrapper_implementation; - name = orig.attr("__name__"); - doc = orig.attr("__doc__"); - dim_name = std::move(dim_name_); - if (!mpy::is_none(doc) && !dim_name.empty()) { - doc = mpy::unicode_from_format( - "%S\nArgument '%s' can be either an integer or a torchdim.Dim object.\n", - doc.ptr(), - dim_name.c_str()); - } - method_def.ml_name = mpy::is_none(name) ? "" : PyUnicode_AsUTF8(name.ptr()); - method_def.ml_doc = mpy::is_none(doc) ? "" : PyUnicode_AsUTF8(doc.ptr()); - method_def.ml_flags = METH_FASTCALL | METH_KEYWORDS; - } - - mpy::object function() { - return mpy::object::checked_steal(PyCFunction_New(&method_def, ptr())); - } + mpy::object orig; + PyMethodDef method_def; + mpy::object name, doc; + + bool is_pointwise = false; + int64_t dim_offset = 0; + int64_t keepdim_offset = 1; + std::string dim_name; + bool single_dim = false; + bool reduce = true; + + static PyTypeObject Type; + + void init(mpy::object orig_, PyCFunction wrapper_implementation, std::string dim_name_="") { + orig = std::move(orig_); + method_def.ml_meth = wrapper_implementation; + name = orig.attr("__name__"); + doc = orig.attr("__doc__"); + dim_name = std::move(dim_name_); + if (!mpy::is_none(doc) && !dim_name.empty()) { + doc = mpy::unicode_from_format("%S\nArgument '%s' can be either an integer or a torchdim.Dim object.\n", doc.ptr(), dim_name.c_str()); + } + method_def.ml_name = mpy::is_none(name) ? "" : PyUnicode_AsUTF8(name.ptr()); + method_def.ml_doc = mpy::is_none(doc) ? "" : PyUnicode_AsUTF8(doc.ptr()); + method_def.ml_flags = METH_FASTCALL | METH_KEYWORDS; + } + + mpy::object function() { + return mpy::object::checked_steal(PyCFunction_New(&method_def, ptr())); + } + }; -} // namespace +} PyTypeObject WrappedOperator::Type = { PyVarObject_HEAD_INIT(NULL, 0) - "_C.WrappedOperator", /* tp_name */ - sizeof(WrappedOperator), /* tp_basicsize */ - 0, /* tp_itemsize */ - WrappedOperator::dealloc_stub, /* tp_dealloc */ - 0, /* tp_vectorcall_offset */ - 0, /* tp_getattr */ - 0, /* tp_setattr */ - 0, /* tp_as_async */ - 0, /* tp_repr */ - 0, /* tp_as_number */ - 0, /* tp_as_sequence */ - 0, /* tp_as_mapping */ - 0, /* tp_hash */ - 0, /* tp_call */ - 0, /* tp_str */ - 0, /* tp_getattro */ - 0, /* tp_setattro */ - 0, /* tp_as_buffer */ + "_C.WrappedOperator", /* tp_name */ + sizeof(WrappedOperator), /* tp_basicsize */ + 0, /* tp_itemsize */ + WrappedOperator::dealloc_stub, /* tp_dealloc */ + 0, /* tp_vectorcall_offset */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_as_async */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ Py_TPFLAGS_DEFAULT, /* tp_flags */ - "Wrapped Object Holder", /* tp_doc */ - 0, /* tp_traverse */ - 0, /* tp_clear */ - 0, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - 0, /* tp_iter */ - 0, /* tp_iternext */ - 0, /* tp_methods */ - 0, /* tp_members */ - 0, /* tp_getset */ - 0, /* tp_base */ - 0, /* tp_dict */ - 0, /* tp_descr_get */ - 0, /* tp_descr_set */ - 0, /* tp_dictoffset */ - 0, /* tp_init */ - 0, /* tp_alloc */ - WrappedOperator::new_stub, /* tp_new */ + "Wrapped Object Holder", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + 0, /* tp_methods */ + 0, /* tp_members */ + 0, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + 0, /* tp_init */ + 0, /* tp_alloc */ + WrappedOperator::new_stub, /* tp_new */ }; -namespace { -PyObject* patched_dim_method( - PyObject* self_, - PyObject* const* args, - Py_ssize_t nargs, - PyObject* kwnames) { - Arena A; - auto self = WrappedOperator::unchecked_wrap(self_); - PY_BEGIN - - mpy::vector_args va(args, nargs, kwnames); - - auto _getarg = [&](const char* name, int64_t offset_) -> mpy::handle { - auto offset = offset_ + 1; // do not include self - auto idx = va.index(name, offset); - return idx == -1 ? mpy::handle() : va[idx]; - }; - Slice patched_args; - patched_args.extend(A, va.begin(), va.end()); - auto _patcharg = [&](const char* name, int64_t offset_, mpy::handle value) { - auto offset = offset_ + 1; // do not include self - auto idx = va.index(name, offset); - if (idx == -1) { - mpy::raise_error(PyExc_ValueError, "Missing argument %s", name); - } - patched_args[idx] = value; - }; - - auto dim = _getarg(self->dim_name.c_str(), self->dim_offset); - if (!dim.ptr()) { - auto info = TensorInfo::create(A, args[0], true); - EnableAllLayers l(A, info.levels); - l.inplace_update_layers(info.batchedtensor, info.levels); - patched_args[0] = handle_from_tensor(A, info.batchedtensor); +namespace{ +PyObject* patched_dim_method(PyObject * self_, + PyObject *const *args, + Py_ssize_t nargs, + PyObject *kwnames) { + Arena A; + auto self = WrappedOperator::unchecked_wrap(self_); + PY_BEGIN + + mpy::vector_args va(args, nargs, kwnames); + + auto _getarg = [&](const char* name, int64_t offset_) -> mpy::handle { + auto offset = offset_ + 1; // do not include self + auto idx = va.index(name, offset); + return idx == -1 ? mpy::handle() : va[idx]; + }; + Slice patched_args; + patched_args.extend(A, va.begin(), va.end()); + auto _patcharg = [&](const char* name, int64_t offset_, mpy::handle value) { + auto offset = offset_ + 1; // do not include self + auto idx = va.index(name, offset); + if (idx == -1) { + mpy::raise_error(PyExc_ValueError, "Missing argument %s", name); + } + patched_args[idx] = value; + }; + + auto dim = _getarg(self->dim_name.c_str(), self->dim_offset); + if (!dim.ptr()) { + auto info = TensorInfo::create(A, args[0], true); + EnableAllLayers l(A, info.levels); + l.inplace_update_layers(info.batchedtensor, info.levels); + patched_args[0] = handle_from_tensor(A, info.batchedtensor); + auto r = self->orig.call_vector(patched_args.begin(), nargs, kwnames); + return l.from_batched(A, THPVariable_Unpack(r.ptr()), info.has_device).release(); + } + + auto info = TensorInfo::create(A, args[0]); + auto keepdim = false; + if (self->reduce) { + auto py_keepdim = _getarg("keepdim", self->keepdim_offset); + if (py_keepdim.ptr()) { + keepdim = mpy::to_bool(py_keepdim); + } + } + + auto ndim = info.ndim(); + auto dims = _wrap_dims(A, dim, ndim, keepdim); + Slice dim_indices; + auto seen = A.allocate(info.levels.size()); + std::fill(seen, seen + info.levels.size(), false); + + for (auto d : dims) { + auto midx = info.levels.index(d); + if (!midx) { + auto tup = levels_to_tuple(info.levels); + mpy::raise_error(PyExc_ValueError, "Tensor with dimensions %R does not contain one of %R\n", tup.ptr(), dim.ptr()); + } + seen[*midx] = true; + dim_indices.append(A, *midx); + } + Slice new_levels; + if (self->reduce && !keepdim) { + for (auto i : info.levels.enumerate()) { + if (!seen[i]) { + new_levels.append(A, info.levels[i]); + } + } + } else { + new_levels = info.levels; + } + mpy::object py_indices; + if (dim_indices.size() == 1) { + py_indices = mpy::from_int(dim_indices[0]); + } else { + mpy::tuple tup(dim_indices.size()); + for (auto i : dim_indices.enumerate()) { + tup.set(i, mpy::from_int(dim_indices[i])); + } + py_indices = std::move(tup); + } + _patcharg(self->dim_name.c_str(), self->dim_offset, py_indices); + patched_args[0] = handle_from_tensor(A, info.tensor); auto r = self->orig.call_vector(patched_args.begin(), nargs, kwnames); - return l.from_batched(A, THPVariable_Unpack(r.ptr()), info.has_device) - .release(); - } - - auto info = TensorInfo::create(A, args[0]); - auto keepdim = false; - if (self->reduce) { - auto py_keepdim = _getarg("keepdim", self->keepdim_offset); - if (py_keepdim.ptr()) { - keepdim = mpy::to_bool(py_keepdim); - } - } - - auto ndim = info.ndim(); - auto dims = _wrap_dims(A, dim, ndim, keepdim); - Slice dim_indices; - auto seen = A.allocate(info.levels.size()); - std::fill(seen, seen + info.levels.size(), false); - - for (auto d : dims) { - auto midx = info.levels.index(d); - if (!midx) { - auto tup = levels_to_tuple(info.levels); - mpy::raise_error( - PyExc_ValueError, - "Tensor with dimensions %R does not contain one of %R\n", - tup.ptr(), - dim.ptr()); - } - seen[*midx] = true; - dim_indices.append(A, *midx); - } - Slice new_levels; - if (self->reduce && !keepdim) { - for (auto i : info.levels.enumerate()) { - if (!seen[i]) { - new_levels.append(A, info.levels[i]); - } - } - } else { - new_levels = info.levels; - } - mpy::object py_indices; - if (dim_indices.size() == 1) { - py_indices = mpy::from_int(dim_indices[0]); - } else { - mpy::tuple tup(dim_indices.size()); - for (auto i : dim_indices.enumerate()) { - tup.set(i, mpy::from_int(dim_indices[i])); - } - py_indices = std::move(tup); - } - _patcharg(self->dim_name.c_str(), self->dim_offset, py_indices); - patched_args[0] = handle_from_tensor(A, info.tensor); - auto r = self->orig.call_vector(patched_args.begin(), nargs, kwnames); - auto wrap = [&](mpy::handle h) { - if (THPVariable_Check(h.ptr())) { - return A.autorelease(Tensor::from_positional( - A, THPVariable_Unpack(h.ptr()), new_levels, info.has_device)); - } - return h; - }; - return tree_map(A, wrap, r).release(); - PY_END(nullptr) -} - -PyObject* _wrap( - PyObject* self_, - PyObject* const* args, - Py_ssize_t nargs, - PyObject* kwnames) { - Arena A; - PY_BEGIN - -#define ARGS(_) \ - _(mpy::handle, orig) \ - _(mpy::handle, dim_offset) _(mpy::handle, keepdim_offset) \ - _(mpy::handle, dim_name) _(mpy::handle, single_dim) \ - _(mpy::handle, reduce) - MPY_PARSE_ARGS_KWNAMES("O|OOOOO", ARGS) - - std::string dim_name_str; - if (dim_name.ptr()) { - dim_name_str = PyUnicode_AsUTF8(dim_name.ptr()); - } else { - dim_name_str = "dim"; - } - auto info = WrappedOperator::create( - mpy::object::borrow(orig), - (PyCFunction)(void*)patched_dim_method, - std::move(dim_name_str)); - if (dim_offset.ptr()) { - info->dim_offset = mpy::to_int(dim_offset); - } - if (keepdim_offset.ptr()) { - info->keepdim_offset = mpy::to_int(keepdim_offset); - } - - if (single_dim.ptr()) { - info->single_dim = mpy::to_bool(single_dim); - } - if (reduce.ptr()) { - info->reduce = mpy::to_bool(reduce); - } - return info->function().release(); -#undef ARGS - - PY_END(nullptr) -} - -PyObject* call_torch_function( - PyObject* self, - PyObject* const* args, - Py_ssize_t nargs, - PyObject* kwnames) { - PY_BEGIN - Arena A; - maybeInitializeGlobals(); - auto info = WrappedOperator::unchecked_wrap(self); - return __torch_function__( - A, - info->orig, - mpy::vector_args(args, nargs, kwnames), - info->is_pointwise) - .release(); - PY_END(nullptr) -} - -PyObject* _wrap_method( - PyObject* self, - PyObject* const* args, - Py_ssize_t nargs, - PyObject* kwnames) { - PY_BEGIN - AT_ASSERT(nargs == 2); - // XXX - ignore python function wrapped, we will call torch function directly - mpy::handle orig = args[0]; - if (!pointwise.ptr()) { - auto dim = mpy::import("functorch.dim"); - pointwise = dim.attr("pointwise"); - } - auto info = WrappedOperator::create( - mpy::object::borrow(orig), (PyCFunction)(void*)call_torch_function); - info->is_pointwise = pointwise.contains(orig); - return PyInstanceMethod_New(info->function().release()); - PY_END(nullptr); -} - -PyObject* Tensor_sum( - PyObject* self_, - PyObject* const* args, - Py_ssize_t nargs, - PyObject* kwnames) { - Arena A; - PY_BEGIN - maybeInitializeGlobals(); - mpy::vector_args va(args, nargs, kwnames); - auto self_ = Tensor::unchecked_wrap(args[0]); - auto d = self_->delayed(); - if (!d) { - return _Tensor_sum.call_vector(va).release(); - } - mpy::handle self, dim, keepdim, dtype; - va.parse( - "sum", - {"self", "dim", "keepdim", "dtype"}, - {&self, &dim, &keepdim, &dtype}, - 1, - 1); - - if (dtype.ptr() || (keepdim.ptr() && mpy::to_bool(keepdim))) { - // std::cout << "SKIPPING fusion because dtype or keepdim=True specified\n"; - return _Tensor_sum.call_vector(va).release(); - } - auto levels = self_->levels(); - - auto N = ndim_of_levels(levels); - auto reduced_dims = _wrap_dims(A, dim, N, false); - - return dot(A, - TensorInfo::create(A, d->args[0], false), - TensorInfo::create(A, d->args[1], false), - reduced_dims) - .release(); - PY_END(nullptr) -} - -PyObject* _parse_test( - PyObject* self_, - PyObject* const* args, - Py_ssize_t nargs, - PyObject* kwnames) { - PY_BEGIN - maybeInitializeGlobals(); - - int required = mpy::to_int(args[0]); - int kwonly = mpy::to_int(args[1]); - - mpy::vector_args va(args + 2, nargs - 2, kwnames); - - mpy::handle a, b, c, d; - va.parse( - "_parse_test", {"a", "b", "c", "d"}, {&a, &b, &c, &d}, required, kwonly); - mpy::tuple r(4); - r.set(0, mpy::object::borrow(a.ptr() ? a : Py_None)); - r.set(1, mpy::object::borrow(b.ptr() ? b : Py_None)); - r.set(2, mpy::object::borrow(c.ptr() ? c : Py_None)); - r.set(3, mpy::object::borrow(d.ptr() ? d : Py_None)); - return r.release(); - - PY_END(nullptr) -} - -PyObject* _set_pointwise_optimize( - PyObject* self_, - PyObject* const* args, - Py_ssize_t nargs, - PyObject* kwnames) { - PY_BEGIN - mpy::handle value; - mpy::vector_args va(args, nargs, kwnames); - va.parse("_set_pointwise_optimization", {"value"}, {&value}, 1); - pointwise_optimize = mpy::to_bool(value); - Py_RETURN_NONE; - PY_END(nullptr) -} - -PyObject* _patch_tensor_class( - PyObject* self_, - PyObject* const* args, - Py_ssize_t nargs, - PyObject* kwnames) { - PY_BEGIN - - auto torch = mpy::import("torch"); - auto py_TensorBase = torch.attr("_C").attr("TensorBase"); - replaceMappingIfMatches(py_TensorBase); - - Py_RETURN_NONE; - PY_END(nullptr) + auto wrap = [&](mpy::handle h) { + if (THPVariable_Check(h.ptr())) { + return A.autorelease(Tensor::from_positional(A, THPVariable_Unpack(h.ptr()), new_levels, info.has_device)); + } + return h; + }; + return tree_map(A, wrap, r).release(); + PY_END(nullptr) } +PyObject* _wrap(PyObject * self_, + PyObject *const *args, + Py_ssize_t nargs, + PyObject *kwnames) { + Arena A; + PY_BEGIN + + #define ARGS(_) _(mpy::handle, orig) _(mpy::handle, dim_offset) _(mpy::handle, keepdim_offset) \ + _(mpy::handle, dim_name) _(mpy::handle, single_dim) _(mpy::handle, reduce) + MPY_PARSE_ARGS_KWNAMES("O|OOOOO", ARGS) + + std::string dim_name_str; + if (dim_name.ptr()) { + dim_name_str = PyUnicode_AsUTF8(dim_name.ptr()); + } else { + dim_name_str = "dim"; + } + auto info = WrappedOperator::create(mpy::object::borrow(orig), (PyCFunction)(void*) patched_dim_method, std::move(dim_name_str)); + if (dim_offset.ptr()) { + info->dim_offset = mpy::to_int(dim_offset); + } + if (keepdim_offset.ptr()) { + info->keepdim_offset = mpy::to_int(keepdim_offset); + } + + if (single_dim.ptr()) { + info->single_dim = mpy::to_bool(single_dim); + } + if (reduce.ptr()) { + info->reduce = mpy::to_bool(reduce); + } + return info->function().release(); + #undef ARGS + + PY_END(nullptr) +} + +PyObject* call_torch_function(PyObject *self, + PyObject *const *args, + Py_ssize_t nargs, + PyObject *kwnames) { + PY_BEGIN + Arena A; + maybeInitializeGlobals(); + auto info = WrappedOperator::unchecked_wrap(self); + return __torch_function__(A, info->orig, mpy::vector_args(args, nargs, kwnames), info->is_pointwise).release(); + PY_END(nullptr) +} + +PyObject* _wrap_method(PyObject *self, + PyObject *const *args, + Py_ssize_t nargs, + PyObject *kwnames) { + PY_BEGIN + AT_ASSERT(nargs == 2); + // XXX - ignore python function wrapped, we will call torch function directly + mpy::handle orig = args[0]; + if (!pointwise.ptr()) { + auto dim = mpy::import("functorch.dim"); + pointwise = dim.attr("pointwise"); + } + auto info = WrappedOperator::create(mpy::object::borrow(orig), (PyCFunction)(void*) call_torch_function); + info->is_pointwise = pointwise.contains(orig); + return PyInstanceMethod_New(info->function().release()); + PY_END(nullptr); +} + + +PyObject* Tensor_sum(PyObject * self_, + PyObject *const *args, + Py_ssize_t nargs, + PyObject *kwnames) { + Arena A; + PY_BEGIN + maybeInitializeGlobals(); + mpy::vector_args va(args, nargs, kwnames); + auto self_ = Tensor::unchecked_wrap(args[0]); + auto d = self_->delayed(); + if (!d) { + return _Tensor_sum.call_vector(va).release(); + } + mpy::handle self, dim, keepdim, dtype; + va.parse("sum", {"self", "dim", "keepdim", "dtype"}, {&self, &dim, &keepdim, &dtype}, 1, 1); + + if (dtype.ptr() || (keepdim.ptr() && mpy::to_bool(keepdim))) { + // std::cout << "SKIPPING fusion because dtype or keepdim=True specified\n"; + return _Tensor_sum.call_vector(va).release(); + } + auto levels = self_->levels(); + + auto N = ndim_of_levels(levels); + auto reduced_dims = _wrap_dims(A, dim, N, false); + + return dot(A, TensorInfo::create(A, d->args[0], false), TensorInfo::create(A, d->args[1], false), reduced_dims).release(); + PY_END(nullptr) +} + +PyObject* _parse_test(PyObject * self_, + PyObject *const *args, + Py_ssize_t nargs, + PyObject *kwnames) { + PY_BEGIN + maybeInitializeGlobals(); + + int required = mpy::to_int(args[0]); + int kwonly = mpy::to_int(args[1]); + + mpy::vector_args va(args + 2, nargs - 2, kwnames); + + + mpy::handle a, b, c, d; + va.parse("_parse_test", {"a", "b", "c", "d"}, {&a, &b, &c, &d}, required, kwonly); + mpy::tuple r(4); + r.set(0, mpy::object::borrow(a.ptr() ? a : Py_None)); + r.set(1, mpy::object::borrow(b.ptr() ? b : Py_None)); + r.set(2, mpy::object::borrow(c.ptr() ? c : Py_None)); + r.set(3, mpy::object::borrow(d.ptr() ? d : Py_None)); + return r.release(); + + PY_END(nullptr) +} + +PyObject* _set_pointwise_optimize(PyObject * self_, + PyObject *const *args, + Py_ssize_t nargs, + PyObject *kwnames) { + PY_BEGIN + mpy::handle value; + mpy::vector_args va(args, nargs, kwnames); + va.parse("_set_pointwise_optimization", {"value"}, {&value}, 1); + pointwise_optimize = mpy::to_bool(value); + Py_RETURN_NONE; + PY_END(nullptr) +} + +PyObject* _patch_tensor_class(PyObject * self_, + PyObject *const *args, + Py_ssize_t nargs, + PyObject *kwnames) { + PY_BEGIN + + auto torch = mpy::import("torch"); + auto py_TensorBase = torch.attr("_C").attr("TensorBase"); + replaceMappingIfMatches(py_TensorBase); + + Py_RETURN_NONE; + PY_END(nullptr) +} + + const char* dims_doc = R"""( dims(n=None, sizes=None) -> torchdim.Dim or Tuple[torchdim.Dim, ...] @@ -3579,79 +3196,54 @@ Example:: )"""; PyMethodDef methods[] = { - {"dims", - (PyCFunction)(void*)_dims, - METH_FASTCALL | METH_KEYWORDS, - dims_doc}, - {"dimlists", - (PyCFunction)(void*)_dims, - METH_FASTCALL | METH_KEYWORDS}, - {"_test_c", (PyCFunction)(void*)test_c, METH_FASTCALL | METH_KEYWORDS}, - {"_wrap_method", - (PyCFunction)(void*)_wrap_method, - METH_FASTCALL | METH_KEYWORDS}, - {"Tensor_from_positional", - (PyCFunction)(void*)py_Tensor_from_positional, - METH_FASTCALL | METH_KEYWORDS}, - {"__torch_function__", - (PyCFunction)(void*)py___torch_function__, - METH_FASTCALL | METH_KEYWORDS}, - {"tree_flatten", - (PyCFunction)(void*)py_tree_flatten, - METH_FASTCALL | METH_KEYWORDS}, - {"order", (PyCFunction)(void*)order, METH_FASTCALL | METH_KEYWORDS}, - {"index", (PyCFunction)(void*)py_index, METH_FASTCALL | METH_KEYWORDS}, - {"stack", (PyCFunction)(void*)py_stack, METH_FASTCALL | METH_KEYWORDS}, - {"split", (PyCFunction)(void*)py_split, METH_FASTCALL | METH_KEYWORDS}, - {"expand", (PyCFunction)(void*)expand, METH_FASTCALL | METH_KEYWORDS}, - {"__getitem__", - (PyCFunction)(void*)py___getitem__, - METH_FASTCALL | METH_KEYWORDS}, - {"__setitem__", - (PyCFunction)(void*)py___setitem__, - METH_FASTCALL | METH_KEYWORDS}, - {"_wrap", (PyCFunction)(void*)_wrap, METH_FASTCALL | METH_KEYWORDS}, - {"Tensor_sum", - (PyCFunction)(void*)Tensor_sum, - METH_FASTCALL | METH_KEYWORDS}, - {"_parse_test", - (PyCFunction)(void*)_parse_test, - METH_FASTCALL | METH_KEYWORDS}, - {"_set_pointwise_optimize", - (PyCFunction)(void*)_set_pointwise_optimize, - METH_FASTCALL | METH_KEYWORDS}, - {"_patch_tensor_class", - (PyCFunction)(void*)_patch_tensor_class, - METH_FASTCALL | METH_KEYWORDS}, - {NULL, NULL, 0, NULL} /* Sentinel */ + {"dims", (PyCFunction)(void*) _dims, METH_FASTCALL | METH_KEYWORDS, dims_doc}, + {"dimlists", (PyCFunction)(void*) _dims, METH_FASTCALL | METH_KEYWORDS}, + {"_test_c", (PyCFunction)(void*) test_c, METH_FASTCALL | METH_KEYWORDS}, + {"_wrap_method", (PyCFunction)(void*) _wrap_method, METH_FASTCALL | METH_KEYWORDS}, + {"Tensor_from_positional", (PyCFunction)(void*) py_Tensor_from_positional, METH_FASTCALL | METH_KEYWORDS}, + {"__torch_function__", (PyCFunction)(void*) py___torch_function__, METH_FASTCALL | METH_KEYWORDS}, + {"tree_flatten", (PyCFunction)(void*) py_tree_flatten, METH_FASTCALL | METH_KEYWORDS}, + {"order", (PyCFunction)(void*) order, METH_FASTCALL | METH_KEYWORDS}, + {"index", (PyCFunction)(void*) py_index, METH_FASTCALL | METH_KEYWORDS}, + {"stack", (PyCFunction)(void*) py_stack, METH_FASTCALL | METH_KEYWORDS}, + {"split", (PyCFunction)(void*) py_split, METH_FASTCALL | METH_KEYWORDS}, + {"expand", (PyCFunction)(void*) expand, METH_FASTCALL | METH_KEYWORDS}, + {"__getitem__", (PyCFunction)(void*) py___getitem__, METH_FASTCALL | METH_KEYWORDS}, + {"__setitem__", (PyCFunction)(void*) py___setitem__, METH_FASTCALL | METH_KEYWORDS}, + {"_wrap", (PyCFunction)(void*) _wrap, METH_FASTCALL | METH_KEYWORDS}, + {"Tensor_sum", (PyCFunction)(void*) Tensor_sum, METH_FASTCALL | METH_KEYWORDS}, + {"_parse_test", (PyCFunction)(void*) _parse_test, METH_FASTCALL | METH_KEYWORDS}, + {"_set_pointwise_optimize", (PyCFunction)(void*) _set_pointwise_optimize, METH_FASTCALL | METH_KEYWORDS}, + {"_patch_tensor_class", (PyCFunction)(void*) _patch_tensor_class, METH_FASTCALL | METH_KEYWORDS}, + {NULL, NULL, 0, NULL} /* Sentinel */ }; struct PyModuleDef module_def = { PyModuleDef_HEAD_INIT, - "_C", /* name of module */ + "_C", /* name of module */ NULL, /* module documentation, may be NULL */ - -1, /* size of per-interpreter state of the module, - or -1 if the module keeps state in global variables. */ - methods}; -} // namespace + -1, /* size of per-interpreter state of the module, + or -1 if the module keeps state in global variables. */ + methods +}; +} PyObject* Dim_init() { - Arena A; - try { - mpy::object mod = mpy::object::checked_steal(PyModule_Create(&module_def)); - Dim::ready(mod, "Dim"); - DimList::ready(mod, "DimList"); - Tensor::ready(mod, "Tensor"); - WrappedOperator::ready(mod, "_WrappedOperator"); - Py_INCREF(&PyInstanceMethod_Type); - PyModule_AddObject( - mod.ptr(), "_instancemethod", (PyObject*)&PyInstanceMethod_Type); - - initializeGlobals(A); - return mod.release(); - } catch (mpy::exception_set& err) { - return nullptr; - } + Arena A; + try { + mpy::object mod = mpy::object::checked_steal(PyModule_Create(&module_def)); + Dim::ready(mod, "Dim"); + DimList::ready(mod, "DimList"); + Tensor::ready(mod, "Tensor"); + WrappedOperator::ready(mod, "_WrappedOperator"); + Py_INCREF(&PyInstanceMethod_Type); + PyModule_AddObject(mod.ptr(), "_instancemethod", (PyObject *)&PyInstanceMethod_Type); + + initializeGlobals(A); + return mod.release(); + } catch(mpy::exception_set& err) { + return nullptr; + } } #endif diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 0eadce79c05fc..0b222c16bb787 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -583,6 +583,7 @@ "torch._C._dispatch_has_kernel", "torch._C._dispatch_is_alias_key", "torch._C._dispatch_is_included_in_alias", + "torch._C._dispatch_is_main_interpreter", "torch._C._dispatch_isTensorSubclassLike", "torch._C._dispatch_key_for_device", "torch._C._dispatch_key_name", diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 9497296c1a4c0..15efa62ae978a 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -407,10 +407,10 @@ static PyObject* THPModule_swap_tensor_impl(PyObject* _unused, PyObject* args) { // associated with the TensorImpl. Swap this field as well. std::optional mb_obj_a = a->cdata->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj( - /*ignore_hermetic_tls=*/false); + getPyInterpreter(), /*ignore_hermetic_tls=*/false); std::optional mb_obj_b = b->cdata->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj( - /*ignore_hermetic_tls=*/false); + getPyInterpreter(), /*ignore_hermetic_tls=*/false); TORCH_INTERNAL_ASSERT( mb_obj_a.has_value() && mb_obj_b.has_value(), "Both tensors should have PyObjects tagged by the current python interpreter"); @@ -420,8 +420,10 @@ static PyObject* THPModule_swap_tensor_impl(PyObject* _unused, PyObject* args) { a->cdata = b->cdata; b->cdata = tmp; - a->cdata->unsafeGetTensorImpl()->pyobj_slot()->init_pyobj(a_); - b->cdata->unsafeGetTensorImpl()->pyobj_slot()->init_pyobj(b_); + a->cdata->unsafeGetTensorImpl()->pyobj_slot()->init_pyobj( + getPyInterpreter(), a_, c10::impl::PyInterpreterStatus::TAGGED_BY_US); + b->cdata->unsafeGetTensorImpl()->pyobj_slot()->init_pyobj( + getPyInterpreter(), b_, c10::impl::PyInterpreterStatus::TAGGED_BY_US); Py_RETURN_NONE; END_HANDLE_TH_ERRORS diff --git a/torch/csrc/PyInterpreter.cpp b/torch/csrc/PyInterpreter.cpp index f289a286b19c7..f944bb5c5461e 100644 --- a/torch/csrc/PyInterpreter.cpp +++ b/torch/csrc/PyInterpreter.cpp @@ -586,7 +586,7 @@ static void set_tensor_attr_with_capsule( py::capsule& capsule, const char* attr_name) { std::optional mb_obj = tensor->pyobj_slot()->check_pyobj( - /*ignore_hermetic_tls=*/false); + getPyInterpreter(), /*ignore_hermetic_tls=*/false); TORCH_CHECK( mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value"); auto obj = mb_obj.value(); @@ -987,3 +987,7 @@ py::handle getTorchApiFunction(const c10::OperatorHandle& op) { c10::impl::PyInterpreter* getPyInterpreter() { return torch::detail::self_interpreter.get(); } + +bool isMainPyInterpreter() { + return torch::detail::self_interpreter.is_main_interpreter(); +} diff --git a/torch/csrc/PyInterpreter.h b/torch/csrc/PyInterpreter.h index 0ff9f79d02c27..82ca11e2c5d0c 100644 --- a/torch/csrc/PyInterpreter.h +++ b/torch/csrc/PyInterpreter.h @@ -10,4 +10,4 @@ TORCH_PYTHON_API py::handle getTorchApiFunction(const c10::OperatorHandle& op); // TODO: Move these to a proper namespace TORCH_PYTHON_API c10::impl::PyInterpreter* getPyInterpreter(); -TORCH_PYTHON_API void initializeGlobalPyInterpreter(); +TORCH_PYTHON_API bool isMainPyInterpreter(); diff --git a/torch/csrc/PyInterpreterHooks.cpp b/torch/csrc/PyInterpreterHooks.cpp deleted file mode 100644 index fd1c997be0a08..0000000000000 --- a/torch/csrc/PyInterpreterHooks.cpp +++ /dev/null @@ -1,20 +0,0 @@ -#include -#include - -namespace torch::detail { - -PyInterpreterHooks::PyInterpreterHooks(c10::impl::PyInterpreterHooksArgs) {} - -c10::impl::PyInterpreter* PyInterpreterHooks::getPyInterpreter() const { - // Delegate to the existing implementation - return ::getPyInterpreter(); -} - -} // namespace torch::detail - -// Sigh, the registry doesn't support namespaces :( -using c10::impl::PyInterpreterHooksRegistry; -using c10::impl::RegistererPyInterpreterHooksRegistry; -using PyInterpreterHooks = torch::detail::PyInterpreterHooks; -// Register the implementation -REGISTER_PYTHON_HOOKS(PyInterpreterHooks); diff --git a/torch/csrc/PyInterpreterHooks.h b/torch/csrc/PyInterpreterHooks.h deleted file mode 100644 index 1def7b8c55ae6..0000000000000 --- a/torch/csrc/PyInterpreterHooks.h +++ /dev/null @@ -1,15 +0,0 @@ -#pragma once - -#include - -namespace torch::detail { - -// Concrete implementation of PyInterpreterHooks -class PyInterpreterHooks : public c10::impl::PyInterpreterHooksInterface { - public: - explicit PyInterpreterHooks(c10::impl::PyInterpreterHooksArgs); - - c10::impl::PyInterpreter* getPyInterpreter() const override; -}; - -} // namespace torch::detail diff --git a/torch/csrc/Storage.cpp b/torch/csrc/Storage.cpp index 08112b41aaaed..cc682a2644af2 100644 --- a/torch/csrc/Storage.cpp +++ b/torch/csrc/Storage.cpp @@ -35,6 +35,7 @@ PyTypeObject* THPStorageClass = nullptr; PyObject* THPStorage_NewWithStorage( PyTypeObject* type, c10::Storage _storage, + c10::impl::PyInterpreterStatus status, bool allow_preexisting_pyobj) { TORCH_CHECK( PyType_IsSubtype(type, &THPStorageType), @@ -42,7 +43,7 @@ PyObject* THPStorage_NewWithStorage( "Storage is not possible. Make sure your class inherits from Storage."); auto maybe_pyobj = _storage.unsafeGetStorageImpl()->pyobj_slot()->check_pyobj( - /*ignore_hermetic_tls=*/false); + getPyInterpreter(), /*ignore_hermetic_tls=*/false); if (maybe_pyobj.has_value() && maybe_pyobj.value()) { TORCH_CHECK( allow_preexisting_pyobj, @@ -77,7 +78,8 @@ PyObject* THPStorage_NewWithStorage( if (!c10::impl::HermeticPyObjectTLS::get_state()) { s->is_hermetic = false; const auto& storage = THPStorage_Unpack(s); - storage.unsafeGetStorageImpl()->pyobj_slot()->init_pyobj(obj); + storage.unsafeGetStorageImpl()->pyobj_slot()->init_pyobj( + getPyInterpreter(), obj, status); } else { s->is_hermetic = true; } @@ -89,12 +91,17 @@ PyObject* THPStorage_NewWithStorage( PyObject* THPStorage_Wrap(c10::Storage storage) { c10::StorageImpl* storage_impl = storage.unsafeGetStorageImpl(); if (c10::impl::HermeticPyObjectTLS::get_state()) { - return THPStorage_NewWithStorage(THPStorageClass, std::move(storage)); + return THPStorage_NewWithStorage( + THPStorageClass, + std::move(storage), + c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED); } c10::impl::PyObjectSlot* pyobj_slot = storage_impl->pyobj_slot(); std::optional maybe_pyobj = pyobj_slot->check_pyobj( - /*ignore_hermetic_tls=*/false); + getPyInterpreter(), /*ignore_hermetic_tls=*/false); + c10::impl::PyInterpreterStatus status = + c10::impl::PyInterpreterStatus::TAGGED_BY_US; if (maybe_pyobj.has_value()) { auto obj = *maybe_pyobj; if (obj) { @@ -113,8 +120,15 @@ PyObject* THPStorage_Wrap(c10::Storage storage) { return obj; } } + status = c10::impl::PyInterpreterStatus::TAGGED_BY_US; + } else { + if (storage.use_count() <= 1) { + status = c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED; + } else { + status = c10::impl::PyInterpreterStatus::MAYBE_UNINITIALIZED; + } } - return THPStorage_NewWithStorage(THPStorageClass, std::move(storage)); + return THPStorage_NewWithStorage(THPStorageClass, std::move(storage), status); } static bool THPStorage_isPreservable(THPStorage* self) { @@ -128,7 +142,8 @@ static bool THPStorage_isPreservable(THPStorage* self) { } if (storage.unsafeGetStorageImpl()->pyobj_slot()->check_pyobj( - /*ignore_hermetic_tls=*/true) != (PyObject*)self) { + getPyInterpreter(), /*ignore_hermetic_tls=*/true) != + (PyObject*)self) { return false; } if (storage.use_count() <= 1) { @@ -146,10 +161,11 @@ static bool THPStorage_tryPreserve(THPStorage* self) { c10::StorageImpl* storage_impl = storage.unsafeGetStorageImpl(); auto maybe_pyobj = storage_impl->pyobj_slot()->check_pyobj( + getPyInterpreter(), /*ignore_hermetic_tls=*/true); // NOTE: It is possible to just set the PyObjectSlot here, but the point is - // that we should have already set PyObjectSlot when the storage PyObject - // was created. + // that we should have already set PyObjectSlot when the storage PyObject was + // created. TORCH_INTERNAL_ASSERT( maybe_pyobj.has_value(), "Trying to preserve a Python storage whose PyObjectSlot does not have a PyObject"); @@ -357,7 +373,8 @@ static PyObject* THPStorage_pynew( at::DataPtr(), allocator, /*resizable=*/true, - device_opt)); + device_opt), + c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED); // torch.Storage(size, *, ...) } else if (r.idx == 1) { @@ -370,7 +387,8 @@ static PyObject* THPStorage_pynew( at::DataPtr(), allocator, /*resizable=*/true, - device_opt)); + device_opt), + c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED); // torch.Storage(sequence, *, ...) } else if (r.idx == 2) { @@ -394,7 +412,8 @@ static PyObject* THPStorage_pynew( at::DataPtr(), allocator, /*resizable=*/true, - device_opt)); + device_opt), + c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED); THPObjectPtr item; try { const auto& storage = THPStorage_Unpack(self); @@ -490,8 +509,10 @@ static PyObject* THPStorage_get(THPStorage* self, PyObject* index) { /* resizable */ false, device_opt); - PyObject* _ret = - THPStorage_NewWithStorage(Py_TYPE(self), std::move(new_storage_impl)); + PyObject* _ret = THPStorage_NewWithStorage( + Py_TYPE(self), + std::move(new_storage_impl), + c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED); return _ret; } diff --git a/torch/csrc/Storage.h b/torch/csrc/Storage.h index 698cd80548efa..ce86475d6a952 100644 --- a/torch/csrc/Storage.h +++ b/torch/csrc/Storage.h @@ -19,6 +19,7 @@ TORCH_PYTHON_API PyObject* THPStorage_Wrap(c10::Storage storage); TORCH_PYTHON_API PyObject* THPStorage_NewWithStorage( PyTypeObject* type, c10::Storage _storage, + c10::impl::PyInterpreterStatus status, bool allow_preexisting_pyobj = false); TORCH_PYTHON_API extern PyTypeObject* THPStorageClass; diff --git a/torch/csrc/StorageMethods.cpp b/torch/csrc/StorageMethods.cpp index da64bcfbd5008..8e5a99e4da7f7 100644 --- a/torch/csrc/StorageMethods.cpp +++ b/torch/csrc/StorageMethods.cpp @@ -390,7 +390,10 @@ static PyObject* THPStorage_fromFile( storage->set_nbytes(actual_nbytes); } - return THPStorage_NewWithStorage(THPStorageClass, std::move(storage)); + return THPStorage_NewWithStorage( + THPStorageClass, + std::move(storage), + c10::impl::PyInterpreterStatus::TAGGED_BY_US); END_HANDLE_TH_ERRORS } diff --git a/torch/csrc/StorageSharing.cpp b/torch/csrc/StorageSharing.cpp index e58865bb60a8a..9f7d667613dc5 100644 --- a/torch/csrc/StorageSharing.cpp +++ b/torch/csrc/StorageSharing.cpp @@ -86,7 +86,8 @@ static PyObject* THPStorage_pyNewFilenameStorage( THManagedMapAllocator::makeDataPtr( "", handle.c_str(), flags, static_cast(size)), /*allocator=*/nullptr, - /*resizable=*/false)); + /*resizable=*/false), + c10::impl::PyInterpreterStatus::TAGGED_BY_US); END_HANDLE_TH_ERRORS } @@ -181,7 +182,8 @@ static PyObject* THPStorage_newSharedFilename( THManagedMapAllocator::makeDataPtr( manager_handle, object_handle, flags, size), /*allocator=*/nullptr, - /*resizable=*/false)); + /*resizable=*/false), + c10::impl::PyInterpreterStatus::TAGGED_BY_US); END_HANDLE_TH_ERRORS } @@ -195,7 +197,9 @@ static PyObject* THPStorage_pyNewFdStorage(PyObject* _unused, PyObject* args) { return nullptr; } return THPStorage_NewWithStorage( - THPStorageClass, at::new_shm_fd_storage(size)); + THPStorageClass, + at::new_shm_fd_storage(size), + c10::impl::PyInterpreterStatus::TAGGED_BY_US); END_HANDLE_TH_ERRORS } @@ -274,7 +278,8 @@ static PyObject* THPStorage_newSharedFd(PyObject* _unused, PyObject* args) { at::MapAllocator::makeDataPtr( at::WITH_FD, "", fd, flags, size, nullptr), /*allocator=*/nullptr, - /*resizable=*/false)); + /*resizable=*/false), + c10::impl::PyInterpreterStatus::TAGGED_BY_US); END_HANDLE_TH_ERRORS } @@ -555,7 +560,10 @@ static PyObject* THPStorage_newSharedCuda(PyObject* _unused, PyObject* args) { base->set_resizable(false); base->set_received_cuda(true); - return THPStorage_NewWithStorage(THPStorageClass, std::move(base)); + return THPStorage_NewWithStorage( + THPStorageClass, + std::move(base), + c10::impl::PyInterpreterStatus::TAGGED_BY_US); #else TORCH_CHECK(false, "CUDA is not available"); #endif diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp index c184dd63d2949..b0235da869fbc 100644 --- a/torch/csrc/autograd/python_variable.cpp +++ b/torch/csrc/autograd/python_variable.cpp @@ -209,6 +209,7 @@ PyObject* ParameterClass = nullptr; static PyObject* THPVariable_NewWithVar( PyTypeObject* type, const at::TensorBase& _var, + c10::impl::PyInterpreterStatus status, bool allow_preexisting_pyobj = false); // clang-tidy gets confused by static const @@ -260,12 +261,16 @@ PyObject* THPVariable_Wrap(const at::TensorBase& var) { } if (c10::impl::HermeticPyObjectTLS::get_state()) { - return THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, var); + return THPVariable_NewWithVar( + (PyTypeObject*)THPVariableClass, + var, + c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED); } std::optional mb_obj = var.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj( - /*ignore_hermetic_tls=*/false); + getPyInterpreter(), /*ignore_hermetic_tls=*/false); + c10::impl::PyInterpreterStatus status{}; if (mb_obj.has_value()) { auto obj = *mb_obj; if (obj) { @@ -290,17 +295,27 @@ PyObject* THPVariable_Wrap(const at::TensorBase& var) { // (https://github.com/pytorch/pytorch/pull/56017). Prior to this PR // being a thing, the PyObject field will get cleared when all references // to the Python object are removed. + status = c10::impl::PyInterpreterStatus::TAGGED_BY_US; + } else { + // Assumption: if a Tensor has been shared across threads, this induces + // a refcount bump. Therefore, if the use count 1, we are the sole thread + // with access to this tensor and no race is possible. + if (var.use_count() <= 1) { + status = c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED; + } else { + status = c10::impl::PyInterpreterStatus::MAYBE_UNINITIALIZED; + } } if (C10_LIKELY(var.device().type() != c10::kXLA)) { - return THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, var); + return THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, var, status); } if (auto clazz = getPythonTensorClass(var.device())) { - return THPVariable_NewWithVar((PyTypeObject*)clazz, var); + return THPVariable_NewWithVar((PyTypeObject*)clazz, var, status); } - return THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, var); + return THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, var, status); } static bool isResurrectable(THPVariable* self) { @@ -329,7 +344,8 @@ static bool isResurrectable(THPVariable* self) { } // Check if this is hermetic. If it is, no resurrection. if (tensor.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj( - /*ignore_hermetic_tls=*/false) != (PyObject*)self) { + getPyInterpreter(), /*ignore_hermetic_tls=*/false) != + (PyObject*)self) { return false; } return true; @@ -355,6 +371,7 @@ static bool THPVariable_tryResurrect(THPVariable* self) { c10::TensorImpl* tensor_impl = tensor.unsafeGetTensorImpl(); auto maybe_pyobj = tensor_impl->pyobj_slot()->check_pyobj( + getPyInterpreter(), /*ignore_hermetic_tls=*/false); TORCH_INTERNAL_ASSERT( @@ -570,7 +587,10 @@ static PyObject* THPVariable_as_subclass( // stack torch_dispatch_mode::StashTorchDispatchStackGuard td_g; c10::impl::DisablePythonDispatcher dpd_g; - return THPVariable_NewWithVar((PyTypeObject*)cls, self.alias()); + return THPVariable_NewWithVar( + (PyTypeObject*)cls, + self.alias(), + c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED); END_HANDLE_TH_ERRORS } @@ -622,7 +642,10 @@ static PyObject* THPVariable_make_subclass( data.unsafeGetTensorImpl()->_change_backend_component_keys(r.device(6)); } - return THPVariable_NewWithVar((PyTypeObject*)cls, data); + return THPVariable_NewWithVar( + (PyTypeObject*)cls, + data, + c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED); END_HANDLE_TH_ERRORS } @@ -767,7 +790,10 @@ static PyObject* THPVariable_make_wrapper_subclass( tensor.unsafeGetTensorImpl()->set_python_custom_layout(true); } - return THPVariable_NewWithVar((PyTypeObject*)cls, tensor); + return THPVariable_NewWithVar( + (PyTypeObject*)cls, + tensor, + c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED); END_HANDLE_TH_ERRORS } @@ -1795,6 +1821,7 @@ PyObject* THPVariable_pynew( return THPVariable_NewWithVar( type, tensor, + c10::impl::PyInterpreterStatus::MAYBE_UNINITIALIZED, /*allow_preexisting_pyobj=*/true); END_HANDLE_TH_ERRORS } @@ -1847,7 +1874,8 @@ static int THPVariable_subclass_clear(THPVariable* self) { if (!self->cdata.unsafeIsBorrowed() && tensor.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj( - /*ignore_hermetic_tls=*/false) == (PyObject*)self) { + getPyInterpreter(), /*ignore_hermetic_tls=*/false) == + (PyObject*)self) { // TODO: empirically, on OS X this assert appears to be untrue // In test_py_tensors_multi_async_call - ProcessGroupRpcTestWithSpawn // distributed/rpc/test_process_group_agent.py @@ -2019,10 +2047,17 @@ static void THPVariable_subclass_dealloc(PyObject* self) { Py_DECREF(type); } -// Creates a new Python object for a Variable. +// Creates a new Python object for a Variable. The status parameter +// specifies what the interpreter tag status on the object is; for +// example, if you ran check_pyobj, the return optional of this object +// tells you if the tensor was already tagged or not so you can pass +// TAGGED_BY_US or MAYBE_UNINITIALIZED; in other cases, you know where +// var came from and can directly assert that it's DEFINITELY_UNINITIALIZED. +// It's ALWAYS safe (albeit slower) to call this with MAYBE_UNINITIALIZED. static PyObject* THPVariable_NewWithVar( PyTypeObject* type, const at::TensorBase& _var, + c10::impl::PyInterpreterStatus status, bool allow_preexisting_pyobj) { // Make sure that the reinterpret into a THPVariable* will be valid TORCH_CHECK( @@ -2033,7 +2068,7 @@ static PyObject* THPVariable_NewWithVar( // This function overwrite the Tensor's pyobj field without extra checks // Make sure it is not set otherwise we would leak memory auto mb_obj = _var.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj( - /*ignore_hermetic_tls=*/false); + getPyInterpreter(), /*ignore_hermetic_tls=*/false); // Under some circumstances, we may attempt to create a new Python // object for a variable that already has a Python object. The most common @@ -2115,7 +2150,8 @@ static PyObject* THPVariable_NewWithVar( // Normal codepath v->cdata = MaybeOwned::owned(Variable(_var)); const auto& var = THPVariable_Unpack(v); - var.unsafeGetTensorImpl()->pyobj_slot()->init_pyobj(obj); + var.unsafeGetTensorImpl()->pyobj_slot()->init_pyobj( + getPyInterpreter(), obj, status); if (check_has_torch_dispatch(obj)) { var.unsafeGetTensorImpl()->set_python_dispatch(true); } diff --git a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp index 019ce2070634d..b2b0e848a7e79 100644 --- a/torch/csrc/utils/python_dispatch.cpp +++ b/torch/csrc/utils/python_dispatch.cpp @@ -209,10 +209,12 @@ class PythonKernelHolder : public c10::OperatorKernel { } }; -// @todo sahanp: Afait only register is used in the codebase. This can be -// removed / simplified static torch::_RegisterOrVerify register_or_verify() { - return torch::_RegisterOrVerify::REGISTER; + if (isMainPyInterpreter()) { + return torch::_RegisterOrVerify::REGISTER; + } else { + return torch::_RegisterOrVerify::VERIFY; + } } static py::object ophandle_call_boxed( @@ -285,6 +287,7 @@ void initDispatchBindings(PyObject* module) { .def( "reset", [](const py::object& self) { + TORCH_INTERNAL_ASSERT(isMainPyInterpreter()); self.cast().reset(); return; }, @@ -294,6 +297,7 @@ void initDispatchBindings(PyObject* module) { .def( "def_", [](py::object self, const char* schema, const char* alias) { + TORCH_INTERNAL_ASSERT(isMainPyInterpreter()); self.cast().def( torch::schema(schema, parseAliasAnalysisKind(alias))); return self; @@ -307,6 +311,7 @@ void initDispatchBindings(PyObject* module) { .def( "def_legacy", [](py::object self, const char* schema) { + TORCH_INTERNAL_ASSERT(isMainPyInterpreter()); self.cast().def(torch::jit::parseSchema(schema)); return self; }, @@ -326,6 +331,7 @@ void initDispatchBindings(PyObject* module) { const char* name, const char* dispatch, const char* debug) { + TORCH_INTERNAL_ASSERT(isMainPyInterpreter()); self.cast().def( name, dispatch_str(dispatch, [](const at::Tensor& a) { return a; @@ -343,6 +349,7 @@ void initDispatchBindings(PyObject* module) { const char* dispatch, const char* alias, const char* debug) { + TORCH_INTERNAL_ASSERT(isMainPyInterpreter()); self.cast().def( torch::schema(schema, parseAliasAnalysisKind(alias)), dispatch_str(dispatch, [](const at::Tensor& a) { @@ -363,6 +370,7 @@ void initDispatchBindings(PyObject* module) { const char* name, const char* dispatch, const char* debug) { + TORCH_INTERNAL_ASSERT(isMainPyInterpreter()); self.cast().impl( name, dispatch_str(dispatch, [](const at::Tensor& a) { return a; @@ -457,6 +465,7 @@ void initDispatchBindings(PyObject* module) { .def( "fallback_fallthrough", [](py::object self, const char* dispatch) { + TORCH_INTERNAL_ASSERT(isMainPyInterpreter()); self.cast().fallback( dispatch_str(dispatch, CppFunction::makeFallthrough())); return self; @@ -471,6 +480,7 @@ void initDispatchBindings(PyObject* module) { bool with_keyset) { HANDLE_TH_ERRORS auto& lib = self.cast(); + TORCH_INTERNAL_ASSERT(isMainPyInterpreter()); if (func.is(py::module::import("torch.library") .attr("fallthrough_kernel"))) { lib.fallback( @@ -903,6 +913,8 @@ void initDispatchBindings(PyObject* module) { handle.setReportErrorCallback_(std::move(callback_obj)); }); + m.def( + "_dispatch_is_main_interpreter", []() { return isMainPyInterpreter(); }); m.def("_dispatch_pystub", [](const char* name, const char* overload) { return c10::Dispatcher::singleton().getPyStub( c10::OperatorName(name, overload)); From 99cc3633f69c7830435ad0b18469dc08a7dcde45 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 21 Jul 2025 23:17:38 +0000 Subject: [PATCH 358/457] Revert "[BE] Modify PyObjectSlot the assume only a single interpreter is in use (#158407)" This reverts commit d9426a81d2ab54f809a3b32a6ab2e606073fe66f. Reverted https://github.com/pytorch/pytorch/pull/158407 on behalf of https://github.com/ZainRizvi due to Sorry but this is breaking internally, see D78496147 for details. To validate your fixes internally, you can follow the instructions here: https://fburl.com/fixing-ghfirst-reverts ([comment](https://github.com/pytorch/pytorch/pull/158288#issuecomment-3099826158)) --- c10/core/impl/PyObjectSlot.cpp | 14 ++++++- c10/core/impl/PyObjectSlot.h | 74 +++++++++++++++++++++++++++++++--- torch/csrc/Storage.cpp | 11 +++++ 3 files changed, 92 insertions(+), 7 deletions(-) diff --git a/c10/core/impl/PyObjectSlot.cpp b/c10/core/impl/PyObjectSlot.cpp index 62af2eae8e37a..400903bc7a651 100644 --- a/c10/core/impl/PyObjectSlot.cpp +++ b/c10/core/impl/PyObjectSlot.cpp @@ -44,7 +44,19 @@ PyInterpreter& PyObjectSlot::load_pyobj_interpreter() const { if (interpreter) { return *interpreter; } - TORCH_CHECK(false, "cannot access PyObject for Tensor - no interpreter set"); + TORCH_CHECK( + false, + "cannot access PyObject for Tensor on interpreter ", + (*pyobj_interpreter_.load())->name()); +} + +bool PyObjectSlot::check_interpreter(PyInterpreter* interpreter) { + return interpreter == pyobj_interpreter(); +} + +bool PyObjectSlot::has_pyobj_nonhermetic() { + return check_pyobj(pyobj_interpreter(), /*ignore_hermetic_tls=*/true) + .has_value(); } bool PyObjectSlot::owns_pyobj() { diff --git a/c10/core/impl/PyObjectSlot.h b/c10/core/impl/PyObjectSlot.h index af8b9fa4d0ec7..4b9bcf1e4a1c3 100644 --- a/c10/core/impl/PyObjectSlot.h +++ b/c10/core/impl/PyObjectSlot.h @@ -28,7 +28,48 @@ struct C10_API PyObjectSlot { PyInterpreter* self_interpreter, PyObject* pyobj, PyInterpreterStatus status) { - pyobj_interpreter_.store(self_interpreter, std::memory_order_relaxed); + impl::PyInterpreter* expected = nullptr; + switch (status) { + case impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED: + // caller guarantees there is no multithreaded access; if there is + // no data race OK to do a relaxed store + pyobj_interpreter_.store(self_interpreter, std::memory_order_relaxed); + break; + case impl::PyInterpreterStatus::TAGGED_BY_US: + // no tagging is necessary, the tag is already correct + break; + case impl::PyInterpreterStatus::MAYBE_UNINITIALIZED: + // attempt to claim this TensorImpl with the specified interpreter + // tag + if (pyobj_interpreter_.compare_exchange_strong( + expected, self_interpreter, std::memory_order_acq_rel)) { + break; + } + // test if, actually, it was already tagged by us! this situation can't + // be caused by a race, but it could be caused by a situation + // where someone conservatively tagged the tensor as MAYBE_UNINITIALIZED + // (because they didn't pre-check the tag) when actually it was + // owned by the interpreter + if (expected == self_interpreter) { + break; + } + // fallthrough, we lost the race. We are guaranteed not to lose the + // race with ourself, as calls to init_pyobj with the same interpreter + // ID must be sequentialized by the GIL + [[fallthrough]]; + case impl::PyInterpreterStatus::TAGGED_BY_OTHER: + TORCH_CHECK( + false, + "cannot allocate PyObject for Tensor on interpreter ", + self_interpreter, + " that has already been used by another torch deploy interpreter ", + pyobj_interpreter_.load()); + } + + // we are the ONLY thread that can have gotten to this point. It is not + // possible to conflict with another zero interpreter as access is protected + // by GIL + // NB: owns_pyobj tag is initially false pyobj_ = pyobj; } @@ -56,16 +97,30 @@ struct C10_API PyObjectSlot { std::optional check_pyobj( PyInterpreter* self_interpreter, bool ignore_hermetic_tls = false) const { + // Note [Memory ordering on Python interpreter tag] impl::PyInterpreter* interpreter = pyobj_interpreter_.load(std::memory_order_acquire); if (interpreter == nullptr) { + // NB: This never returns DEFINITELY_UNINITIALIZED because there is + // always the possibility that another thread races to initialize + // after we query here. The only time when we can conclude a tensor + // is definitely uninitialized is when we have just allocated it and + // it cannot have escaped to other threads yet return std::nullopt; - } - - if (!ignore_hermetic_tls && c10::impl::HermeticPyObjectTLS::get_state()) { - return std::nullopt; + } else if (interpreter == self_interpreter) { + // NB: pyobj_ could still be null! + if (!ignore_hermetic_tls && c10::impl::HermeticPyObjectTLS::get_state()) { + return std::nullopt; + } else { + return _unchecked_untagged_pyobj(); + } } else { - return _unchecked_untagged_pyobj(); + TORCH_CHECK( + false, + "cannot access PyObject for Tensor on interpreter ", + (*self_interpreter)->name(), + " that has already been used by another torch deploy interpreter ", + (*pyobj_interpreter_.load())->name()); } } @@ -75,6 +130,13 @@ struct C10_API PyObjectSlot { PyInterpreter& load_pyobj_interpreter() const; + // Check if the PyObjectSlot's interpreter is the same as the specified + // interpreter + bool check_interpreter(PyInterpreter* interpreter); + + // Check if the PyObjectSlot is holding a PyObject, owned or non-owned + bool has_pyobj_nonhermetic(); + bool owns_pyobj(); void set_owns_pyobj(bool b); diff --git a/torch/csrc/Storage.cpp b/torch/csrc/Storage.cpp index cc682a2644af2..d566dc666ebfe 100644 --- a/torch/csrc/Storage.cpp +++ b/torch/csrc/Storage.cpp @@ -98,6 +98,17 @@ PyObject* THPStorage_Wrap(c10::Storage storage) { } c10::impl::PyObjectSlot* pyobj_slot = storage_impl->pyobj_slot(); + // If the StorageImpl has a PyObject that is managed by a different + // interpreter than the current one, create a new StorageImpl that points to + // the same data and then create the Python storage from that. + // NOTE: This is only supposed to happen in MultiPy // codespell:ignore + if (pyobj_slot->has_pyobj_nonhermetic() && + !pyobj_slot->check_interpreter(getPyInterpreter())) { + return THPStorage_NewWithStorage( + THPStorageClass, + c10::newStorageImplFromRefcountedDataPtr(storage), + c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED); + } std::optional maybe_pyobj = pyobj_slot->check_pyobj( getPyInterpreter(), /*ignore_hermetic_tls=*/false); c10::impl::PyInterpreterStatus status = From 920f26c7617c3ae65142d69c3c65b1f6f111fd46 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 21 Jul 2025 23:17:38 +0000 Subject: [PATCH 359/457] Revert "[BE] Remove __reduce_deploy__ (#158291)" This reverts commit 0b9fb91f17edfbc51ae36584dcb8350b2d8bb23b. Reverted https://github.com/pytorch/pytorch/pull/158291 on behalf of https://github.com/ZainRizvi due to Sorry but this is breaking internally, see D78496147 for details. To validate your fixes internally, you can follow the instructions here: https://fburl.com/fixing-ghfirst-reverts ([comment](https://github.com/pytorch/pytorch/pull/158288#issuecomment-3099826158)) --- docs/source/conf.py | 1 + ...t-fx_backcompat_function_signatures.expect | 1 + torch/_dynamo/trace_rules.py | 1 + torch/fx/_lazy_graph_module.py | 5 +++++ torch/fx/graph_module.py | 21 +++++++++++++++++++ 5 files changed, 29 insertions(+) diff --git a/docs/source/conf.py b/docs/source/conf.py index 8b2112c165e8a..34d8e9876b172 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -1086,6 +1086,7 @@ "z3op", "z3str", # torch.fx.graph_module + "reduce_deploy_graph_module", "reduce_graph_module", "reduce_package_graph_module", # torch.fx.node diff --git a/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect b/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect index 67ed33950249d..fab0dbd066761 100644 --- a/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect +++ b/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect @@ -29,6 +29,7 @@ torch.fx.graph_module.GraphModule.add_submodule(self, target: str, m: torch.nn.m torch.fx.graph_module.GraphModule.delete_all_unused_submodules(self) -> None torch.fx.graph_module.GraphModule.delete_submodule(self, target: str) -> bool torch.fx.graph_module.GraphModule.recompile(self) -> torch.fx.graph.PythonCode +torch.fx.graph_module.reduce_deploy_graph_module(importer: Callable, body: Dict[Any, Any], import_block: str) -> torch.nn.modules.module.Module torch.fx.graph_module.reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.modules.module.Module torch.fx.graph_module.reduce_package_graph_module(importer: Callable, body: Dict[Any, Any], generated_module_name: str) -> torch.nn.modules.module.Module torch.fx.interpreter.Interpreter.__init__(self, module: torch.nn.modules.module.Module, garbage_collect_values: bool = True, graph: Optional[torch.fx.graph.Graph] = None) diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 0b222c16bb787..3889771334d4b 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -3472,6 +3472,7 @@ def _module_dir(m: types.ModuleType) -> Optional[str]: "torch._custom_op", "torch._custom_ops", "torch._decomp", + "torch._deploy", "torch._dispatch", "torch._dynamo", "torch._export", diff --git a/torch/fx/_lazy_graph_module.py b/torch/fx/_lazy_graph_module.py index 83ce51fddd040..377faf327fc9d 100644 --- a/torch/fx/_lazy_graph_module.py +++ b/torch/fx/_lazy_graph_module.py @@ -127,6 +127,11 @@ def _lazy_forward(self, *args, **kwargs): forward = _lazy_forward + # TODO: we should handle __reduce_deploy__ the same way as __reduce_package__, + # or __reduce__ by calling _real_recompile. But I don't find a good way + # to test __reduce_deploy__ out. Also it's very unlikely that LazyGraphModule + # will be used in torch::deploy. So it's skipped for now. + def __reduce_package__(self, exporter: PackageExporter): """ Follow GraphModule.__reduce__ but call 'self._real_recompile' rather diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py index 065cf82983e53..2e1a0963f53b6 100644 --- a/torch/fx/graph_module.py +++ b/torch/fx/graph_module.py @@ -30,6 +30,7 @@ __all__ = [ "reduce_graph_module", "reduce_package_graph_module", + "reduce_deploy_graph_module", "GraphModule", ] @@ -146,6 +147,18 @@ def reduce_package_graph_module( return _deserialize_graph_module(forward, body) +@compatibility(is_backward_compatible=True) +def reduce_deploy_graph_module( + importer: PackageImporter, body: dict[Any, Any], import_block: str +) -> torch.nn.Module: + ns = {} + ns["__builtins__"] = importer.patched_builtins + fn_src = body.get("_code") + assert fn_src is not None + forward = _forward_from_src(import_block + fn_src, ns) + return _deserialize_graph_module(forward, body) + + # We create a dummy class here because symbolic_trace pulls the forward() # function off of the class, rather than the instance. This class is used # in _deserialize_graph_module() below. @@ -840,6 +853,14 @@ def call_wrapped(self, *args, **kwargs): # Passing Tracer as argument allows subclasses extending fx.GraphModule # define their own Tracer (extending fx.Tracer). + def __reduce_deploy__(self, importer: Importer): + dict_without_graph = self.__dict__.copy() + dict_without_graph["_graphmodule_cls_name"] = self.__class__.__name__ + del dict_without_graph["_graph"] + + python_code = self.recompile() + import_block = _format_import_block(python_code.globals, importer) + return (reduce_deploy_graph_module, (dict_without_graph, import_block)) def __reduce_package__(self, exporter: PackageExporter): dict_without_graph = self.__dict__.copy() From 4c18e85300e2157762b446d4831872987dcef39e Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 21 Jul 2025 23:17:39 +0000 Subject: [PATCH 360/457] Revert "[BE] Remove torch deploy | remove torch deploy specific files (#158290)" This reverts commit a6de309ca15cda6b2792fc74e82814dc8d2f9dd9. Reverted https://github.com/pytorch/pytorch/pull/158290 on behalf of https://github.com/ZainRizvi due to Sorry but this is breaking internally, see D78496147 for details. To validate your fixes internally, you can follow the instructions here: https://fburl.com/fixing-ghfirst-reverts ([comment](https://github.com/pytorch/pytorch/pull/158288#issuecomment-3099826158)) --- docs/source/deploy.md | 8 + test/test_deploy.py | 43 +++++ tools/lldb/deploy_debugger.py | 38 +++++ torch/_deploy.py | 104 ++++++++++++ torch/csrc/deploy/README.md | 2 + torch/utils/_freeze.py | 292 ++++++++++++++++++++++++++++++++++ 6 files changed, 487 insertions(+) create mode 100644 docs/source/deploy.md create mode 100644 test/test_deploy.py create mode 100644 tools/lldb/deploy_debugger.py create mode 100644 torch/_deploy.py create mode 100644 torch/csrc/deploy/README.md create mode 100644 torch/utils/_freeze.py diff --git a/docs/source/deploy.md b/docs/source/deploy.md new file mode 100644 index 0000000000000..ef5131717bf7b --- /dev/null +++ b/docs/source/deploy.md @@ -0,0 +1,8 @@ +--- +orphan: true +--- + +# torch::deploy has been moved to pytorch/multipy + + +``torch::deploy`` has been moved to its new home at [https://github.com/pytorch/multipy](https://github.com/pytorch/multipy). diff --git a/test/test_deploy.py b/test/test_deploy.py new file mode 100644 index 0000000000000..b852802c0c20f --- /dev/null +++ b/test/test_deploy.py @@ -0,0 +1,43 @@ +# Owner(s): ["oncall: package/deploy"] + +import textwrap +import types + +from torch.testing._internal.common_utils import run_tests, TestCase +from torch.utils._freeze import Freezer, PATH_MARKER + + +class TestFreezer(TestCase): + """Tests the freeze.py script""" + + def test_compile_string(self): + freezer = Freezer(True) + code_str = textwrap.dedent( + """ + class MyCls: + def __init__(self) -> None: + pass + """ + ) + co = freezer.compile_string(code_str) + num_co = 0 + + def verify_filename(co: types.CodeType): + nonlocal num_co + + if not isinstance(co, types.CodeType): + return + + self.assertEqual(PATH_MARKER, co.co_filename) + num_co += 1 + + for nested_co in co.co_consts: + verify_filename(nested_co) + + verify_filename(co) + # there is at least one nested code object besides the top level one + self.assertTrue(num_co >= 2) + + +if __name__ == "__main__": + run_tests() diff --git a/tools/lldb/deploy_debugger.py b/tools/lldb/deploy_debugger.py new file mode 100644 index 0000000000000..7a28c72a6caf2 --- /dev/null +++ b/tools/lldb/deploy_debugger.py @@ -0,0 +1,38 @@ +import lldb # type: ignore[import] + + +# load into lldb instance with: +# command script import tools/lldb/deploy_debugger.py + +target = lldb.debugger.GetSelectedTarget() +bp = target.BreakpointCreateByRegex("__deploy_register_code") +bp.SetScriptCallbackBody( + """\ +process = frame.thread.GetProcess() +target = process.target +symbol_addr = frame.module.FindSymbol("__deploy_module_info").GetStartAddress() +info_addr = symbol_addr.GetLoadAddress(target) +e = lldb.SBError() +ptr_size = 8 +str_addr = process.ReadPointerFromMemory(info_addr, e) +file_addr = process.ReadPointerFromMemory(info_addr + ptr_size, e) +file_size = process.ReadPointerFromMemory(info_addr + 2*ptr_size, e) +load_bias = process.ReadPointerFromMemory(info_addr + 3*ptr_size, e) +name = process.ReadCStringFromMemory(str_addr, 512, e) +r = process.ReadMemory(file_addr, file_size, e) +from tempfile import NamedTemporaryFile +from pathlib import Path +stem = Path(name).stem +with NamedTemporaryFile(prefix=stem, suffix='.so', delete=False) as tf: + tf.write(r) + print("torch_deploy registering debug information for ", tf.name) + cmd1 = f"target modules add {tf.name}" + # print(cmd1) + lldb.debugger.HandleCommand(cmd1) + cmd2 = f"target modules load -f {tf.name} -s {hex(load_bias)}" + # print(cmd2) + lldb.debugger.HandleCommand(cmd2) + +return False +""" +) diff --git a/torch/_deploy.py b/torch/_deploy.py new file mode 100644 index 0000000000000..0443a2447d00d --- /dev/null +++ b/torch/_deploy.py @@ -0,0 +1,104 @@ +# mypy: allow-untyped-defs +import io + +import torch +from torch.package import Importer, OrderedImporter, PackageImporter, sys_importer +from torch.package._package_pickler import create_pickler +from torch.package._package_unpickler import PackageUnpickler +from torch.serialization import _maybe_decode_ascii + + +def _save_storages(importer, obj): + serialized_storages = [] + serialized_dtypes = [] + + importer = importer if isinstance(importer, torch.package.PackageImporter) else None + importers: Importer + if importer is not None: + importers = OrderedImporter(importer, sys_importer) + else: + importers = sys_importer + + def persistent_id(obj): + if torch.is_storage(obj) or isinstance(obj, torch.storage.TypedStorage): + if isinstance(obj, torch.storage.TypedStorage): + # TODO: Once we decide to break serialization FC, we can + # remove this case + dtype = obj.dtype + else: + dtype = torch.uint8 + + serialized_storages.append(obj) + serialized_dtypes.append(dtype) + return ("storage", len(serialized_storages) - 1) + + if hasattr(obj, "__reduce_deploy__"): + if _serialized_reduces.get(id(obj)) is None: + _serialized_reduces[id(obj)] = ( + "reduce_deploy", + id(obj), + *obj.__reduce_deploy__(importers), + ) + return _serialized_reduces[id(obj)] + + return None + + # Write the pickle data for `obj` + data_buf = io.BytesIO() + pickler = create_pickler(data_buf, importers) + pickler.persistent_id = persistent_id + pickler.dump(obj) + data_value = data_buf.getvalue() + return ( + data_value, + serialized_storages, + serialized_dtypes, + importer.zip_reader if importer else None, + ) + + +def _load_storages(id, zip_reader, obj_bytes, serialized_storages, serialized_dtypes): + def persistent_load(saved_id): + assert isinstance(saved_id, tuple) + typename = _maybe_decode_ascii(saved_id[0]) + data = saved_id[1:] + + if typename == "storage": + # TODO: Once we decide to break serialization FC, we can + # stop wrapping with TypedStorage + storage = serialized_storages[data[0]] + dtype = serialized_dtypes[data[0]] + return torch.storage.TypedStorage( + wrap_storage=storage.untyped(), dtype=dtype + ) + + if typename == "reduce_deploy": + reduce_id, func, args = data + if reduce_id not in _loaded_reduces: + _loaded_reduces[reduce_id] = func(_raw_packages[zip_reader], *args) + return _loaded_reduces[reduce_id] + + return None + + importer: Importer + if zip_reader is not None: + importer = OrderedImporter(_get_package(zip_reader), sys_importer) + else: + importer = sys_importer + + unpickler = PackageUnpickler(importer, io.BytesIO(obj_bytes)) + unpickler.persistent_load = persistent_load # type: ignore[method-assign] + result = _deploy_objects[id] = unpickler.load() + return result + + +def _get_package(zip_reader): + if zip_reader not in _raw_packages: + _raw_packages[zip_reader] = PackageImporter(zip_reader) + return _raw_packages[zip_reader] + + +_raw_packages: dict = {} +_deploy_objects: dict = {} +_serialized_reduces: dict = {} +_loaded_reduces: dict = {} diff --git a/torch/csrc/deploy/README.md b/torch/csrc/deploy/README.md new file mode 100644 index 0000000000000..2d40ca8361ff4 --- /dev/null +++ b/torch/csrc/deploy/README.md @@ -0,0 +1,2 @@ +# torch::deploy has been moved to pytorch/multipy +Please check out [https://github.com/pytorch/multipy](https://github.com/pytorch/multipy) to find the new home for torch::deploy. diff --git a/torch/utils/_freeze.py b/torch/utils/_freeze.py new file mode 100644 index 0000000000000..8696065adb9f9 --- /dev/null +++ b/torch/utils/_freeze.py @@ -0,0 +1,292 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +""" +Freeze Python packages. + + + + +Freezing makes it possible to ship arbitrary Python modules as part of a C++ +library. The Python source of the module is compiled to bytecode and written +to `.c` files, to be imported by Python's built-in FrozenImporter. + +In a normal Python installation, FrozenImporter is only used to bootstrap the +initialization of the import machinery. Python's importers are defined in +Python (see `_bootstrap.py` and `_bootstrap_external.py`) but need to be +retrieved before any importers are available. Freezing the module bytecode +resolves this circular dependency. + +This script will freeze the Python standard library. It produces two things: +- Bytecode files: A set of `.c` that define C variables containing Python bytecode. +- Main file: A `main.c` file listing all of these modules in the right form to be + consumed by FrozenImporter. + +The library that wishes to these modules make them available to the local +Python instance by extending `PyImport_FrozenModules` appropriately (see +https://docs.python.org/3/c-api/import.html#c.PyImport_FrozenModules). +""" + +import argparse +import functools +import itertools +import marshal +import os +import types +from dataclasses import dataclass +from pathlib import Path + + +PATH_MARKER = "" +MAIN_INCLUDES = """#include + +""" + +MAIN_PREFIX_TEMPLATE = """ +// Compiled standard library modules. These should be appended to the existing +// `PyImport_FrozenModules` that ships with CPython. +struct _frozen {}[] = {{ +""" + +FAKE_PREFIX = MAIN_PREFIX_TEMPLATE.format("_PyImport_FrozenModules") + +MAIN_SUFFIX = """\ + {0, 0, 0} /* sentinel */ +}; +""" + +# Exclude some standard library modules to: +# 1. Slim down the final frozen lib. +# 2. Remove functionality we don't want to support. +DENY_LIST = [ + # Interface to unix databases + "dbm", + # ncurses bindings (terminal interfaces) + "curses", + # Tcl/Tk GUI + "tkinter", + "tkinter", + # Tests for the standard library + "test", + "tests", + "idle_test", + "__phello__.foo.py", + # importlib frozen modules. These are already baked into CPython. + "_bootstrap.py", + "_bootstrap_external.py", +] + +NUM_BYTECODE_FILES = 5 + + +def indent_msg(fn): + @functools.wraps(fn) + def wrapper(*args, **kwargs): + args[0].indent += 1 + ret = fn(*args, **kwargs) + args[0].indent -= 1 + return ret + + return wrapper + + +@dataclass +class FrozenModule: + # The fully qualified module name, e.g. 'foo.bar.baz' + module_name: str + # The name of the C variable that holds the bytecode, e.g. 'M_foo__bar__baz' + c_name: str + # The size of the C variable. Negative if this module is a package. + size: int + # The frozen bytecode + bytecode: bytes + + +class Freezer: + def __init__(self, verbose: bool): + self.frozen_modules: list[FrozenModule] = [] + self.indent: int = 0 + self.verbose: bool = verbose + + def msg(self, path: Path, code: str): + if not self.verbose: + return + # P: package dir + # F: python file + # S: skipped (not a package dir) + # X: skipped (deny-listed) + # N: skipped (not a python file) + print(" " * self.indent, end="") + print(f"{code} {path}") + + def write_bytecode(self, install_root): + """ + Write the `.c` files containing the frozen bytecode. + + Shared frozen modules evenly across the files. + """ + bytecode_file_names = [f"bytecode_{i}.c" for i in range(NUM_BYTECODE_FILES)] + bytecode_files = [ + open(os.path.join(install_root, name), "w") for name in bytecode_file_names + ] + it = itertools.cycle(bytecode_files) + for m in self.frozen_modules: + self.write_frozen(m, next(it)) + + for f in bytecode_files: + f.close() + + def write_main(self, install_root, oss, symbol_name): + """Write the `main.c` file containing a table enumerating all the frozen modules.""" + with open(os.path.join(install_root, "main.c"), "w") as outfp: + outfp.write(MAIN_INCLUDES) + for m in self.frozen_modules: + outfp.write(f"extern unsigned char {m.c_name}[];\n") + + outfp.write(MAIN_PREFIX_TEMPLATE.format(symbol_name)) + for m in self.frozen_modules: + outfp.write(f'\t{{"{m.module_name}", {m.c_name}, {m.size}}},\n') + outfp.write(MAIN_SUFFIX) + if oss: + outfp.write(FAKE_PREFIX) + outfp.write(MAIN_SUFFIX) + + def write_frozen(self, m: FrozenModule, outfp): + """Write a single frozen module's bytecode out to a C variable.""" + outfp.write(f"unsigned char {m.c_name}[] = {{") + for i in range(0, len(m.bytecode), 16): + outfp.write("\n\t") + for c in bytes(m.bytecode[i : i + 16]): + outfp.write(f"{c:d},") + outfp.write("\n};\n") + + def compile_path(self, path: Path, top_package_path: Path): + """Entry point for compiling a Path object.""" + if path.is_dir(): + self.compile_package(path, top_package_path) + else: + self.compile_file(path, top_package_path) + + @indent_msg + def compile_package(self, path: Path, top_package_path: Path): + """Compile all the files within a Python package dir.""" + assert path.is_dir() + if path.name in DENY_LIST: + self.msg(path, "X") + return + + # Python packages are directories that have __init__.py in them. + is_package_dir = any(child.name == "__init__.py" for child in path.iterdir()) + if not is_package_dir: + self.msg(path, "S") + return + + self.msg(path, "P") + # Recursively compile all children in this dir + for child in path.iterdir(): + self.compile_path(child, top_package_path) + + def get_module_qualname(self, file_path: Path, top_package_path: Path) -> list[str]: + # `path` looks like 'Lib/foo/bar/baz.py' + + # chop off 'Lib/' to get something that represents a Python module hierarchy. + # e.g. 'foo/bar/baz.py', which maps to 'foo.bar.baz' + normalized_path = file_path.relative_to(top_package_path.parent) + + if normalized_path.name == "__init__.py": + # Special handling for `__init__.py`. In this case, this file + # specifies that the containing directory should be treated as a package. + # For 'foo/bar/baz/__init__.py': + # - The module name is 'baz' + module_basename = normalized_path.parent.name + # - The parent is foo.bar (need to shave off the 'baz') + module_parent = normalized_path.parent.parent.parts + else: + module_basename = normalized_path.stem + module_parent = normalized_path.parent.parts + return list(module_parent) + [module_basename] + + def compile_string(self, file_content: str) -> types.CodeType: + # instead of passing in the real build time path to 'compile', we + # pass in a marker instead. This prevents the build time path being + # leaked to runtime. That path may not be available at runtime. + # Setting the path to a mark make sure it's a hard error rather + # than a flaky error when inspect module tries to retrieve python source + # code during torchscripting. + path_marker = PATH_MARKER + return compile(file_content, path_marker, "exec") + + @indent_msg + def compile_file(self, path: Path, top_package_path: Path): + """ + Compile a Python source file to frozen bytecode. + + Append the result to `self.frozen_modules`. + """ + assert path.is_file() + if path.suffix != ".py": + self.msg(path, "N") + return + + if path.name in DENY_LIST: + self.msg(path, "X") + return + + self.msg(path, "F") + module_qualname = self.get_module_qualname(path, top_package_path) + module_mangled_name = "__".join(module_qualname) + c_name = "M_" + module_mangled_name + + with open(path) as src_file: + co = self.compile_string(src_file.read()) + + bytecode = marshal.dumps(co) + size = len(bytecode) + if path.name == "__init__.py": + # Python packages are signified by negative size. + size = -size + self.frozen_modules.append( + FrozenModule(".".join(module_qualname), c_name, size, bytecode) + ) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Compile py source") + parser.add_argument("paths", nargs="*", help="Paths to freeze.") + parser.add_argument("--verbose", action="store_true", help="Print debug logs") + parser.add_argument( + "--install-dir", "--install_dir", help="Root directory for all output files" + ) + parser.add_argument( + "--oss", + action="store_true", + help="If it's OSS build, add a fake _PyImport_FrozenModules", + ) + parser.add_argument( + "--symbol-name", + "--symbol_name", + help="The name of the frozen module array symbol to generate", + default="_PyImport_FrozenModules_torch", + ) + + args = parser.parse_args() + + f = Freezer(args.verbose) + + for p in args.paths: + path = Path(p) + if path.is_dir() and not Path.exists(path / "__init__.py"): + # this 'top level path p' is a standard directory containing modules, + # not a module itself + # each 'mod' could be a dir containing __init__.py or .py file + # NB: sorted to make sure this is deterministic + for mod in sorted(path.glob("*")): + f.compile_path(mod, mod) + else: + f.compile_path(path, path) + + f.write_bytecode(args.install_dir) + f.write_main(args.install_dir, args.oss, args.symbol_name) + + +if __name__ == "__main__": + main() # pragma: no cover From ee5a434f8ce96a8f13b8c655356222137483c4db Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 21 Jul 2025 23:17:39 +0000 Subject: [PATCH 361/457] Revert "[BE] remove torch deploy - conditionals (#158288)" This reverts commit 1a4268b8113d5160d71225bab980f03c2318a0a4. Reverted https://github.com/pytorch/pytorch/pull/158288 on behalf of https://github.com/ZainRizvi due to Sorry but this is breaking internally, see D78496147 for details. To validate your fixes internally, you can follow the instructions here: https://fburl.com/fixing-ghfirst-reverts ([comment](https://github.com/pytorch/pytorch/pull/158288#issuecomment-3099826158)) --- test/test_custom_ops.py | 56 +++ test/test_sparse_csr.py | 8 +- torch/__init__.py | 47 ++- .../_dynamo/_trace_wrapped_higher_order_op.py | 81 ++-- torch/_dynamo/trace_rules.py | 1 + torch/_inductor/test_operators.py | 43 ++- torch/_library/custom_ops.py | 4 + torch/_library/utils.py | 10 + torch/_ops.py | 3 + torch/_utils_internal.py | 12 +- torch/csrc/lazy/python/init.cpp | 4 + torch/csrc/utils/python_dispatch.cpp | 9 + torch/cuda/__init__.py | 3 + torch/distributed/_functional_collectives.py | 132 ++++--- torch/distributed/_tools/fake_collectives.py | 7 +- .../fsdp/_fully_shard/_fsdp_common.py | 36 +- .../fsdp/_fully_shard/_fsdp_param.py | 3 +- torch/distributed/tensor/_collective_utils.py | 31 +- torch/library.py | 21 + torch/utils/__init__.py | 8 +- torch/utils/_import_utils.py | 8 +- torch/utils/collect_env.py | 362 ++++++++---------- 22 files changed, 507 insertions(+), 382 deletions(-) diff --git a/test/test_custom_ops.py b/test/test_custom_ops.py index d4f5c2a7c0523..1f3e670c15396 100644 --- a/test/test_custom_ops.py +++ b/test/test_custom_ops.py @@ -544,6 +544,62 @@ def test_assert_raises_regex(self, device): class TestCustomOp(CustomOpTestCaseBase): test_ns = "_test_custom_op" + def test_deploy_interaction(self): + # run in a different process to avoid parallel issues when we monkeypatch torch._running_with_deploy + script = """ +import torch +torch._running_with_deploy = lambda: True + +# creating the library is a no-op, so you can DEF multiple times +m1 = torch.library.Library("mylib4392", "DEF") # noqa: TOR901 +m2 = torch.library.Library("mylib4392", "DEF") # noqa: TOR901 + +m = torch.library.Library("aten", "FRAGMENT") # noqa: TOR901 + +# define is a no-op +m.define("foobarbaz9996(Tensor x) -> Tensor") +assert not hasattr(torch.ops.aten, "foobarbaz9996"), "m.define should have been a noop" + +def sin_override(x): + raise AssertionError("m.impl should have been a noop") + +# impl is a no-op +m.impl("sin", sin_override, "CompositeImplicitAutograd") +x = torch.randn(3) +y = torch.sin(x) + +# should be a no-op +@torch.library.custom_op("mylib::foobar", mutates_args={}) +def foobar(x: torch.Tensor) -> torch.Tensor: + return x.sin() + +# should be a no-op +@foobar.register_fake +def _(x): + return torch.empty_like(x) + +# should be a no-op +m2.define("foobarbaz9996(Tensor x) -> Tensor") + +# should be a no-op +@torch.library.register_fake("mylib4392::foobarbaz9996") +def _(x): + return torch.empty_like(x) + """ + script = script.strip() + env = os.environ.copy() + try: + subprocess.check_output( + [sys.executable, "-c", script], + stderr=subprocess.STDOUT, + # On Windows, opening the subprocess with the default CWD makes `import torch` + # fail, so just set CWD to this script's directory + cwd=os.path.dirname(os.path.realpath(__file__)), + env=env, + ) + except subprocess.CalledProcessError as e: + self.fail(msg=("Subprocess exception:\n" + e.output.decode("utf-8"))) + @requires_compile def test_functionalize_error(self): with torch.library._scoped_library(self.test_ns, "FRAGMENT") as lib: diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py index 8fb490e1b5bc7..cc313c586a090 100644 --- a/test/test_sparse_csr.py +++ b/test/test_sparse_csr.py @@ -3603,8 +3603,8 @@ def test_triton_bsr_softmax(self, device, dtype): @onlyCUDA @dtypes(torch.half, torch.bfloat16, torch.float) @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float) - @unittest.skipIf((not TEST_WITH_TORCHINDUCTOR) or (IS_FBCODE and IS_REMOTE_GPU), - "Skipped for internal with remote GPUs") + @unittest.skipIf((not TEST_WITH_TORCHINDUCTOR) or (IS_FBCODE and IS_REMOTE_GPU) or torch._running_with_deploy(), + "Skipped for deploy and internal with remote GPUs") def test_triton_bsr_dense_bmm(self, device, dtype, index_dtype, block_size): from functools import partial from torch.sparse._triton_ops import bsr_dense_mm @@ -3680,8 +3680,8 @@ def kernel_impl(*args, **kwargs): @onlyCUDA @dtypes(torch.half) - @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, - "Skipped for internal with remote GPUs") + @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU or torch._running_with_deploy(), + "Skipped for deploy and internal with remote GPUs") def test_triton_bsr_dense_bmm_error_messages(self, device, dtype): from torch.sparse._triton_ops import bsr_dense_mm diff --git a/torch/__init__.py b/torch/__init__.py index f124d1a5a1d6c..99cb83db84b81 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -34,10 +34,20 @@ ) from typing_extensions import ParamSpec as _ParamSpec, TypeIs as _TypeIs +from . import version + if TYPE_CHECKING: from .types import Device, IntLikeType + +# multipy/deploy is setting this import before importing torch, this is the most # codespell:ignore multipy +# reliable way we have to detect if we're running within deploy. +# https://github.com/pytorch/multipy/blob/d60f34ad38c371e441fe7ffdb77a3c3dda5a5d19/multipy/runtime/interpreter/interpreter_impl.cpp#L134-L137 # codespell:ignore multipy # noqa: B950 +def _running_with_deploy() -> builtins.bool: + return sys.modules.get("torch._meta_registrations", None) is object + + from torch._utils import ( _functionalize_sync as _sync, _import_dotted_name, @@ -50,9 +60,14 @@ USE_GLOBAL_DEPS, USE_RTLD_GLOBAL_WITH_LIBTORCH, ) -from torch.torch_version import __version__ as __version__ +# TODO(torch_deploy) figure out how to freeze version.py in fbcode build +if _running_with_deploy(): + __version__ = "torch-deploy-1.8" +else: + from torch.torch_version import __version__ as __version__ + __all__ = [ "BoolStorage", "BoolTensor", @@ -302,7 +317,7 @@ def _preload_cuda_deps(lib_folder: str, lib_name: str) -> None: # See Note [Global dependencies] def _load_global_deps() -> None: - if platform.system() == "Windows": + if _running_with_deploy() or platform.system() == "Windows": return # Determine the file extension based on the platform @@ -366,7 +381,7 @@ def _load_global_deps() -> None: if (USE_RTLD_GLOBAL_WITH_LIBTORCH or os.getenv("TORCH_USE_RTLD_GLOBAL")) and ( - platform.system() != "Windows" + _running_with_deploy() or platform.system() != "Windows" ): # Do it the hard way. You might want to load libtorch with RTLD_GLOBAL in a # few circumstances: @@ -2067,7 +2082,7 @@ def _dtype(self): # Shared memory manager needs to know the exact location of manager executable def _manager_path(): - if platform.system() == "Windows": + if _running_with_deploy() or platform.system() == "Windows": return b"" path = get_file_path("torch", "bin", "torch_shm_manager") prepare_multiprocessing_environment(get_file_path("torch")) @@ -2672,21 +2687,21 @@ def _register_device_module(device_type, module): # Register MPS specific decomps torch.backends.mps._init() -from torch import compiler as compiler - +if not _running_with_deploy(): + from torch import compiler as compiler -class _TritonLibrary: - lib = torch.library.Library("triton", "DEF") - ops_table: dict[tuple[str, str], _Callable] = {} + class _TritonLibrary: + lib = torch.library.Library("triton", "DEF") + ops_table: dict[tuple[str, str], _Callable] = {} - @classmethod - def registerOp(cls, op_key, full_schema, op_impl, dispatch_key): - if (op_key, dispatch_key) not in cls.ops_table: - cls.lib.define(full_schema) - cls.lib.impl("triton::" + op_key, op_impl, dispatch_key) - cls.ops_table[(op_key, dispatch_key)] = op_impl + @classmethod + def registerOp(cls, op_key, full_schema, op_impl, dispatch_key): + if (op_key, dispatch_key) not in cls.ops_table: + cls.lib.define(full_schema) + cls.lib.impl("triton::" + op_key, op_impl, dispatch_key) + cls.ops_table[(op_key, dispatch_key)] = op_impl - return cls.ops_table[(op_key, dispatch_key)] + return cls.ops_table[(op_key, dispatch_key)] # Deprecated attributes diff --git a/torch/_dynamo/_trace_wrapped_higher_order_op.py b/torch/_dynamo/_trace_wrapped_higher_order_op.py index 17b664fc5e0ed..8fab0b2005491 100644 --- a/torch/_dynamo/_trace_wrapped_higher_order_op.py +++ b/torch/_dynamo/_trace_wrapped_higher_order_op.py @@ -49,46 +49,47 @@ __all__ = ["trace_wrapped"] -@torch.library.custom_op("flex_lib::zeros_and_scatter", mutates_args=()) # type: ignore[misc] -def zeros_and_scatter( - shape: list[int], - indices: list[Tensor], - vals: Tensor, -) -> Tensor: - """Custom Op so that we can register a custom lowering for the new_output + scatter in the backwards pass""" - grad = torch.zeros(shape, device=vals.device, dtype=vals.dtype) - return torch.ops.aten.index_put(grad, indices, vals, accumulate=True) - - -@zeros_and_scatter.register_fake # type: ignore[misc] -def _( - shape: list[int], - indices: list[Tensor], - vals: Tensor, -) -> Tensor: - return vals.new_empty(shape) - - -@zeros_and_scatter.register_vmap # type: ignore[misc] -def _(info, indims, shape, indices, value): # type: ignore[no-untyped-def] - """The batching rule is special in that it returns a tensor that is not batched""" - indices_indims = indims[1] - expanded_indices = [] - for idx, idx_indim in zip(indices, indices_indims): - # The index is not a being batched, we should unsqueeze and expand to val - if idx_indim is None: - expanded_indices.append(idx.expand(value.shape)) - else: - # the index is being part of the vmap batch, it should be the same size as val - assert idx.shape == value.shape - expanded_indices.append(idx) - - out = torch.ops.flex_lib.zeros_and_scatter( - shape, - expanded_indices, - value, - ) - return out, None +if not torch._running_with_deploy(): + # torch.library.custom_op does not work with torch.deploy/multipy # codespell:ignore + + @torch.library.custom_op("flex_lib::zeros_and_scatter", mutates_args=()) # type: ignore[misc] + def zeros_and_scatter( + shape: list[int], + indices: list[Tensor], + vals: Tensor, + ) -> Tensor: + """Custom Op so that we can register a custom lowering for the new_output + scatter in the backwards pass""" + grad = torch.zeros(shape, device=vals.device, dtype=vals.dtype) + return torch.ops.aten.index_put(grad, indices, vals, accumulate=True) + + @zeros_and_scatter.register_fake # type: ignore[misc] + def _( + shape: list[int], + indices: list[Tensor], + vals: Tensor, + ) -> Tensor: + return vals.new_empty(shape) + + @zeros_and_scatter.register_vmap # type: ignore[misc] + def _(info, indims, shape, indices, value): # type: ignore[no-untyped-def] + """The batching rule is special in that it returns a tensor that is not batched""" + indices_indims = indims[1] + expanded_indices = [] + for idx, idx_indim in zip(indices, indices_indims): + # The index is not a being batched, we should unsqueeze and expand to val + if idx_indim is None: + expanded_indices.append(idx.expand(value.shape)) + else: + # the index is being part of the vmap batch, it should be the same size as val + assert idx.shape == value.shape + expanded_indices.append(idx) + + out = torch.ops.flex_lib.zeros_and_scatter( + shape, + expanded_indices, + value, + ) + return out, None class ModIndex(torch.autograd.Function): diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 3889771334d4b..4ff88a25bce3d 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -2409,6 +2409,7 @@ "torch._lowrank.svd_lowrank", "torch._preload_cuda_deps", "torch._register_device_module", + "torch._running_with_deploy", "torch._utils._dummy_type", "torch._utils._flatten_dense_tensors", "torch._utils._unflatten_dense_tensors", diff --git a/torch/_inductor/test_operators.py b/torch/_inductor/test_operators.py index d3d2705f8c788..bf49f3f5d04a1 100644 --- a/torch/_inductor/test_operators.py +++ b/torch/_inductor/test_operators.py @@ -5,24 +5,25 @@ from torch.autograd import Function -_test_lib_def = torch.library.Library("_inductor_test", "DEF") -_test_lib_def.define("realize(Tensor self) -> Tensor", tags=torch.Tag.pt2_compliant_tag) - -_test_lib_impl = torch.library.Library("_inductor_test", "IMPL") -for dispatch_key in ("CPU", "CUDA", "MPS", "Meta"): - _test_lib_impl.impl("realize", lambda x: x.clone(), dispatch_key) - - -class Realize(Function): - @staticmethod - def forward(ctx: object, x: Tensor) -> Tensor: - return torch.ops._inductor_test.realize(x) - - @staticmethod - # types need to stay consistent with _SingleLevelFunction - def backward(ctx: Any, *grad_output: Any) -> Any: - return grad_output[0] - - -def realize(x: Tensor) -> Tensor: - return Realize.apply(x) +if not torch._running_with_deploy(): + _test_lib_def = torch.library.Library("_inductor_test", "DEF") + _test_lib_def.define( + "realize(Tensor self) -> Tensor", tags=torch.Tag.pt2_compliant_tag + ) + + _test_lib_impl = torch.library.Library("_inductor_test", "IMPL") + for dispatch_key in ("CPU", "CUDA", "MPS", "Meta"): + _test_lib_impl.impl("realize", lambda x: x.clone(), dispatch_key) + + class Realize(Function): + @staticmethod + def forward(ctx: object, x: Tensor) -> Tensor: + return torch.ops._inductor_test.realize(x) + + @staticmethod + # types need to stay consistent with _SingleLevelFunction + def backward(ctx: Any, *grad_output: Any) -> Any: + return grad_output[0] + + def realize(x: Tensor) -> Tensor: + return Realize.apply(x) diff --git a/torch/_library/custom_ops.py b/torch/_library/custom_ops.py index 1d8d0fc5377b1..547d305c47afd 100644 --- a/torch/_library/custom_ops.py +++ b/torch/_library/custom_ops.py @@ -595,6 +595,10 @@ def register_autograd( self._setup_context_fn = setup_context def _register_to_dispatcher(self, tags: Sequence[_C.Tag]) -> None: + if torch._running_with_deploy(): + utils.warn_deploy(stacklevel=5) + return + lib = self._lib schema_str = self._name + self._schema cpp_schema = _C.parse_schema(schema_str) diff --git a/torch/_library/utils.py b/torch/_library/utils.py index 9403185204520..17e128bdbe0f3 100644 --- a/torch/_library/utils.py +++ b/torch/_library/utils.py @@ -2,6 +2,7 @@ import dataclasses import inspect import sys +import warnings from collections.abc import Iterable, Iterator from typing import Any, Callable, Union @@ -11,6 +12,15 @@ from torch._ops import OpOverload +def warn_deploy(stacklevel=3): + warnings.warn( + "Python torch.library APIs do nothing under torch::deploy (multipy). " # codespell:ignore multipy + "Please instead use C++ custom operator registration APIs.", + RuntimeWarning, + stacklevel=stacklevel, + ) + + @dataclasses.dataclass class Kernel: """Models a (function, source location)""" diff --git a/torch/_ops.py b/torch/_ops.py index 83a5dc0e57a5e..fecfebaeaa53b 100644 --- a/torch/_ops.py +++ b/torch/_ops.py @@ -1478,6 +1478,9 @@ def load_library(self, path): Args: path (str): A path to a shared library to load. """ + if torch._running_with_deploy(): + return + path = _utils_internal.resolve_library_path(path) with dl_open_guard(): # Import the shared library into the process, thus running its diff --git a/torch/_utils_internal.py b/torch/_utils_internal.py index e067a587497b1..1833b918e180e 100644 --- a/torch/_utils_internal.py +++ b/torch/_utils_internal.py @@ -33,10 +33,16 @@ # use is the FB build environment, where this source file is replaced # by an equivalent. -if os.path.basename(os.path.dirname(__file__)) == "shared": - torch_parent = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) +if torch._running_with_deploy(): + # __file__ is meaningless in the context of frozen torch used in torch deploy. + # setting empty torch_parent should allow below functions to operate without crashing, + # but it's unclear if there is a valid use case for them in the context of deploy. + torch_parent = "" else: - torch_parent = os.path.dirname(os.path.dirname(__file__)) + if os.path.basename(os.path.dirname(__file__)) == "shared": + torch_parent = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) + else: + torch_parent = os.path.dirname(os.path.dirname(__file__)) def get_file_path(*path_components: str) -> str: diff --git a/torch/csrc/lazy/python/init.cpp b/torch/csrc/lazy/python/init.cpp index 4807aa6a4c7d1..f2b14cbfd7bb4 100644 --- a/torch/csrc/lazy/python/init.cpp +++ b/torch/csrc/lazy/python/init.cpp @@ -331,9 +331,13 @@ void initLazyBindings(PyObject* module) { // So far this problem has only been observed internally, so we will just // block it off there. +#if !(defined(USE_DEPLOY)) + // When libtorch_python is loaded, we register the python frame getter // otherwise, debug util simply omits python frames GetPythonFramesFunction() = GetPythonFrames; + +#endif // USE_DEPLOY } } // namespace torch::lazy diff --git a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp index b2b0e848a7e79..34fbfec49c919 100644 --- a/torch/csrc/utils/python_dispatch.cpp +++ b/torch/csrc/utils/python_dispatch.cpp @@ -187,6 +187,15 @@ class PythonKernelHolder : public c10::OperatorKernel { auto arguments = torch::jit::pop(*stack, op.schema().arguments().size()); py::gil_scoped_acquire g; + // Jan 2024: We're slated to get rid of multipy, // codespell:ignore multipy + // so stop forcing hermetic mode unconditionally in all situations when + // you're using multipy. // codespell:ignore multipy + // Eventually just delete this entirely. (Note that you may break + // multipy anyway this way with dispatcher // codespell:ignore multipy + // registered functions that require hermetic to be off.) +#if defined(USE_DEPLOY) + EnableHermeticPyObject g2; +#endif auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments); auto func = py::reinterpret_borrow(func_.ptr(getPyInterpreter())); diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py index 01bc4d73a4595..6a2d62bd424cb 100644 --- a/torch/cuda/__init__.py +++ b/torch/cuda/__init__.py @@ -1693,6 +1693,9 @@ def __call__(self, *args, **kwargs): def _register_triton_kernels(): + if torch._running_with_deploy(): + return + @_WrappedTritonKernel def kernel_impl(*args, **kwargs): from torch.sparse._triton_ops import bsr_dense_mm diff --git a/torch/distributed/_functional_collectives.py b/torch/distributed/_functional_collectives.py index 73cdcf4217895..0ffae8a9c9fe3 100644 --- a/torch/distributed/_functional_collectives.py +++ b/torch/distributed/_functional_collectives.py @@ -19,17 +19,23 @@ from torch.utils._pytree import tree_map_only # type: ignore[no-redef] -try: - from torch.compiler import is_dynamo_compiling as is_torchdynamo_compiling -except Exception: - warnings.warn( - "Unable to import torchdynamo util `is_torchdynamo_compiling`, so won't support torchdynamo correctly" - ) +if torch._running_with_deploy(): - def is_torchdynamo_compiling(): # type: ignore[misc] - return False + def is_torchdynamo_compiling(): + """Can't import torchdynamo in torchdeploy builds currently.""" return False +else: + try: + from torch.compiler import is_dynamo_compiling as is_torchdynamo_compiling + except Exception: + warnings.warn( + "Unable to import torchdynamo util `is_torchdynamo_compiling`, so won't support torchdynamo correctly" + ) + + def is_torchdynamo_compiling(): + return False + """ New traceable, functional collectives. @@ -981,58 +987,66 @@ def _reduce_scatter_tensor_coalesced_native_meta( ] -# Library MUST be defined at module scope or it doesn't work -lib_impl = torch.library.Library("_c10d_functional", "IMPL") -lib_impl.impl("all_reduce", _all_reduce_meta, "Meta") -lib_impl.impl("all_reduce_", _all_reduce__meta, "Meta") -lib_impl.impl("all_reduce_coalesced", _all_reduce_coalesced_meta, "Meta") -lib_impl.impl("all_reduce_coalesced_", _all_reduce_coalesced__meta, "Meta") -lib_impl.impl("wait_tensor", _wait_tensor_meta, "Meta") -lib_impl.impl( - "all_gather_into_tensor_out", _all_gather_into_tensor_out_native_meta, "Meta" -) -lib_impl.impl("all_gather_into_tensor", _all_gather_into_tensor_native_meta, "Meta") -lib_impl.impl( - "all_gather_into_tensor_coalesced", - _all_gather_into_tensor_coalesced_native_meta, - "Meta", -) -lib_impl.impl("reduce_scatter_tensor", _reduce_scatter_tensor_native_meta, "Meta") -lib_impl.impl( - "reduce_scatter_tensor_coalesced", - _reduce_scatter_tensor_coalesced_native_meta, - "Meta", -) -lib_impl.impl("all_to_all_single", _all_to_all_single_meta, "Meta") -lib_impl.impl("broadcast", _broadcast_meta, "Meta") -lib_impl.impl("broadcast_", _broadcast__meta, "Meta") - -# mark these ops has side effect so that they won't be removed by DCE -torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor.default) -torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor) - -# Register legacy ops for backward compatibility -# TODO(yifu): remove these in functional collective beta release -legacy_lib = torch.library.Library("c10d_functional", "DEF") -legacy_lib_impl = torch.library.Library("c10d_functional", "IMPL") -ops_defs = [ - "broadcast(Tensor self, int src, str tag, int[] ranks, int group_size) -> Tensor", - "all_reduce(Tensor self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor", - "all_reduce_coalesced(Tensor[] self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor[]", - "wait_tensor(Tensor self) -> Tensor", - "all_gather_into_tensor(Tensor shard, str tag, int[] ranks, int group_size) -> Tensor", - "all_gather_into_tensor_coalesced(Tensor[] input, str tag, int[] ranks, int group_size) -> Tensor[]", - "reduce_scatter_tensor(Tensor input, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor", - "reduce_scatter_tensor_coalesced(Tensor[] inputs, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor[]", - "all_to_all_single(Tensor input, SymInt[]? output_split_sizes, SymInt[]? input_split_sizes, str tag, int[] ranks, int group_size) -> Tensor", # noqa: B950 -] +if not torch._running_with_deploy(): + # Library MUST be defined at module scope or it doesn't work + # Creating a "DEF" Library always crashes torch::deploy so we create our + # Library instances here guarded against running inside it + lib_impl = torch.library.Library("_c10d_functional", "IMPL") + lib_impl.impl("all_reduce", _all_reduce_meta, "Meta") + lib_impl.impl("all_reduce_", _all_reduce__meta, "Meta") + lib_impl.impl("all_reduce_coalesced", _all_reduce_coalesced_meta, "Meta") + lib_impl.impl("all_reduce_coalesced_", _all_reduce_coalesced__meta, "Meta") + lib_impl.impl("wait_tensor", _wait_tensor_meta, "Meta") + lib_impl.impl( + "all_gather_into_tensor_out", _all_gather_into_tensor_out_native_meta, "Meta" + ) + lib_impl.impl("all_gather_into_tensor", _all_gather_into_tensor_native_meta, "Meta") + lib_impl.impl( + "all_gather_into_tensor_coalesced", + _all_gather_into_tensor_coalesced_native_meta, + "Meta", + ) + lib_impl.impl("reduce_scatter_tensor", _reduce_scatter_tensor_native_meta, "Meta") + lib_impl.impl( + "reduce_scatter_tensor_coalesced", + _reduce_scatter_tensor_coalesced_native_meta, + "Meta", + ) + lib_impl.impl("all_to_all_single", _all_to_all_single_meta, "Meta") + lib_impl.impl("broadcast", _broadcast_meta, "Meta") + lib_impl.impl("broadcast_", _broadcast__meta, "Meta") + + # mark these ops has side effect so that they won't be removed by DCE + torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor.default) + torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor) + + # Register legacy ops for backward compatibility + # TODO(yifu): remove these in functional collective beta release + legacy_lib = torch.library.Library("c10d_functional", "DEF") + legacy_lib_impl = torch.library.Library("c10d_functional", "IMPL") + ops_defs = [ + "broadcast(Tensor self, int src, str tag, int[] ranks, int group_size) -> Tensor", + "all_reduce(Tensor self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor", + "all_reduce_coalesced(Tensor[] self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor[]", + "wait_tensor(Tensor self) -> Tensor", + "all_gather_into_tensor(Tensor shard, str tag, int[] ranks, int group_size) -> Tensor", + "all_gather_into_tensor_coalesced(Tensor[] input, str tag, int[] ranks, int group_size) -> Tensor[]", + "reduce_scatter_tensor(Tensor input, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor", + "reduce_scatter_tensor_coalesced(Tensor[] inputs, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor[]", + "all_to_all_single(Tensor input, SymInt[]? output_split_sizes, SymInt[]? input_split_sizes, str tag, int[] ranks, int group_size) -> Tensor", # noqa: B950 + ] + + my_module = sys.modules[__name__] + for op_def in ops_defs: + op_name = op_def[0 : op_def.index("(")] + backend_impl = getattr(fun_col_impl, f"_{op_name}") + legacy_lib.define(op_def, tags=torch.Tag.pt2_compliant_tag) + legacy_lib_impl.impl(op_name, backend_impl, "CompositeImplicitAutograd") -my_module = sys.modules[__name__] -for op_def in ops_defs: - op_name = op_def[0 : op_def.index("(")] - backend_impl = getattr(fun_col_impl, f"_{op_name}") - legacy_lib.define(op_def, tags=torch.Tag.pt2_compliant_tag) - legacy_lib_impl.impl(op_name, backend_impl, "CompositeImplicitAutograd") +else: + warnings.warn( + "PyTorch Distributed functional collectives do not work with torch::deploy." + ) """ diff --git a/torch/distributed/_tools/fake_collectives.py b/torch/distributed/_tools/fake_collectives.py index 3b201b395334b..f6cb23a06b671 100644 --- a/torch/distributed/_tools/fake_collectives.py +++ b/torch/distributed/_tools/fake_collectives.py @@ -63,9 +63,10 @@ def create_fakework(args, return_first_arg=True): # type: ignore[no-untyped-def "recv_any_source_": lambda *args: create_fakework(args, return_first_arg=False), } -lib_impl = torch.library.Library("c10d", "IMPL") # noqa: TOR901 -for op, meta_func in _META_FUNCTIONS.items(): - lib_impl.impl(op, meta_func, "Meta") +if not torch._running_with_deploy(): + lib_impl = torch.library.Library("c10d", "IMPL") # noqa: TOR901 + for op, meta_func in _META_FUNCTIONS.items(): + lib_impl.impl(op, meta_func, "Meta") # List of collective operation functions including functional collectives # Note: The following collectives might be deprecated soon hence not adding them diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_common.py b/torch/distributed/fsdp/_fully_shard/_fsdp_common.py index b599f48d77d1d..fdcf32e22a338 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_common.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_common.py @@ -15,24 +15,32 @@ _compiled_autograd_enabled: bool = False +if torch._running_with_deploy(): -def detect_compiled_autograd(): - assert not torch.compiler.is_compiling(), ( - "`detect_compiled_autograd()` is designed to be called in eager mode" - ) - global _compiled_autograd_enabled - import torch._dynamo.compiled_autograd as ca + def detect_compiled_autograd(): + pass - _compiled_autograd_enabled = ( - ca.compiled_autograd_enabled - or ca.compiled_autograd_enabled_force_eager - or ca.in_compiled_autograd_region - ) + def compiled_autograd_enabled(): + return False +else: -def compiled_autograd_enabled(): - global _compiled_autograd_enabled - return _compiled_autograd_enabled + def detect_compiled_autograd(): + assert not torch.compiler.is_compiling(), ( + "`detect_compiled_autograd()` is designed to be called in eager mode" + ) + global _compiled_autograd_enabled + import torch._dynamo.compiled_autograd as ca + + _compiled_autograd_enabled = ( + ca.compiled_autograd_enabled + or ca.compiled_autograd_enabled_force_eager + or ca.in_compiled_autograd_region + ) + + def compiled_autograd_enabled(): + global _compiled_autograd_enabled + return _compiled_autograd_enabled @dataclass diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_param.py b/torch/distributed/fsdp/_fully_shard/_fsdp_param.py index b7c8f4ea7c78a..7649c32ec1c0e 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_param.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_param.py @@ -140,7 +140,8 @@ def copy__functionalize(tensor, data): torch.ops.fsdp.copy_.default(tensor_inner, data_inner) -torch.fx.node.has_side_effect(torch.ops.fsdp.copy_.default) +if not torch._running_with_deploy(): + torch.fx.node.has_side_effect(torch.ops.fsdp.copy_.default) class ShardedState(Enum): diff --git a/torch/distributed/tensor/_collective_utils.py b/torch/distributed/tensor/_collective_utils.py index 4fce6fea538a6..36316b2f0567a 100644 --- a/torch/distributed/tensor/_collective_utils.py +++ b/torch/distributed/tensor/_collective_utils.py @@ -25,17 +25,26 @@ logger = logging.getLogger(__name__) -@torch.library.register_fake("_dtensor::shard_dim_alltoall") -def _shard_dim_alltoall_meta(input, gather_dim, shard_dim, group_name): - group_size = _get_group_size_by_name(group_name) - stacked_list = [torch.empty_like(input) for _ in range(group_size)] - group = _resolve_process_group(group_name) - group_rank = get_group_rank(group, get_rank()) - - return ( - torch.cat(stacked_list, dim=gather_dim) - .chunk(group_size, dim=shard_dim)[group_rank] - .contiguous() +if not torch._running_with_deploy(): + + @torch.library.register_fake("_dtensor::shard_dim_alltoall") + def _shard_dim_alltoall_meta(input, gather_dim, shard_dim, group_name): + group_size = _get_group_size_by_name(group_name) + stacked_list = [torch.empty_like(input) for _ in range(group_size)] + group = _resolve_process_group(group_name) + group_rank = get_group_rank(group, get_rank()) + + return ( + torch.cat(stacked_list, dim=gather_dim) + .chunk(group_size, dim=shard_dim)[group_rank] + .contiguous() + ) + +else: + import warnings + + warnings.warn( + "PyTorch Distributed functional collectives do not work with torch::deploy." ) diff --git a/torch/library.py b/torch/library.py index 23a7acf1662c5..a30cdb9bb48ac 100644 --- a/torch/library.py +++ b/torch/library.py @@ -102,6 +102,9 @@ def __init__(self, ns, kind, dispatch_key=""): ns, " is a reserved namespace. Please try creating a library with another name.", ) + if torch._running_with_deploy(): + _library.utils.warn_deploy() + return frame = traceback.extract_stack(limit=3)[0] filename, lineno = frame.filename, frame.lineno @@ -153,6 +156,9 @@ def define(self, schema, alias_analysis="", *, tags=()): >>> my_lib = Library("mylib", "DEF") >>> my_lib.define("sum(Tensor self) -> Tensor") """ + if torch._running_with_deploy(): + _library.utils.warn_deploy() + return # This is added because we also want to disallow PURE_FUNCTION alias analysis which is a valid # AliasAnalysis type in C++ @@ -185,6 +191,9 @@ def define(self, schema, alias_analysis="", *, tags=()): def _register_fake(self, op_name, fn, _stacklevel=1, *, allow_override=False): r"""Registers the fake impl for an operator defined in the library.""" + if torch._running_with_deploy(): + _library.utils.warn_deploy() + return source = torch._library.utils.get_source(_stacklevel + 1) frame = sys._getframe(_stacklevel) @@ -228,6 +237,9 @@ def _register_torch_dispatch_rule(self, op_name, torch_dispatch_class, fn): If it is a TorchDispatchMode, we expect fn to have the following signature: (mode, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any """ + if torch._running_with_deploy(): + _library.utils.warn_deploy() + return qualname = f"{self.ns}::{op_name}" entry = torch._library.simple_registry.singleton.find(qualname) @@ -247,6 +259,9 @@ def _impl_with_aoti_compile(self, op_name, dispatch_key=""): >>> my_lib = Library("aten", "IMPL") >>> my_lib._impl_with_aoti_compile("div.Tensor", "CPU") """ + if torch._running_with_deploy(): + _library.utils.warn_deploy() + return if dispatch_key == "": dispatch_key = self.dispatch_key @@ -309,6 +324,9 @@ def impl( >>> return self * (1 / other) >>> my_lib.impl("div.Tensor", div_cpu, "CPU") """ + if torch._running_with_deploy(): + _library.utils.warn_deploy() + return if not callable(fn): raise TypeError( @@ -391,6 +409,9 @@ def fallback(self, fn, dispatch_key="", *, with_keyset=False): >>> # ... >>> my_lib.fallback(fallback_kernel, "Autocast") """ + if torch._running_with_deploy(): + _library.utils.warn_deploy() + return if dispatch_key == "": dispatch_key = self.dispatch_key diff --git a/torch/utils/__init__.py b/torch/utils/__init__.py index 1c3ec15790063..23188bba9b800 100644 --- a/torch/utils/__init__.py +++ b/torch/utils/__init__.py @@ -29,7 +29,13 @@ def set_module(obj, mod): obj.__module__ = mod -cmake_prefix_path = _osp.join(_osp.dirname(_osp.dirname(__file__)), "share", "cmake") +if torch._running_with_deploy(): + # not valid inside torch_deploy interpreter, no paths exists for frozen modules + cmake_prefix_path = None +else: + cmake_prefix_path = _osp.join( + _osp.dirname(_osp.dirname(__file__)), "share", "cmake" + ) def swap_tensors(t1, t2): diff --git a/torch/utils/_import_utils.py b/torch/utils/_import_utils.py index 240f92acacb9d..dc2d7d4f0382c 100644 --- a/torch/utils/_import_utils.py +++ b/torch/utils/_import_utils.py @@ -3,6 +3,8 @@ from types import ModuleType from typing import Optional +import torch + def _check_module_exists(name: str) -> bool: r"""Returns if a top-level module with :attr:`name` exists *without** @@ -20,7 +22,11 @@ def _check_module_exists(name: str) -> bool: @functools.lru_cache def dill_available() -> bool: - return _check_module_exists("dill") + return ( + _check_module_exists("dill") + # dill fails to import under torchdeploy + and not torch._running_with_deploy() + ) @functools.lru_cache diff --git a/torch/utils/collect_env.py b/torch/utils/collect_env.py index c6473220bc00a..9bb80c65076b8 100644 --- a/torch/utils/collect_env.py +++ b/torch/utils/collect_env.py @@ -6,53 +6,49 @@ import datetime import json import locale -import os import re import subprocess import sys -from collections import namedtuple +import os from typing import cast as _cast +from collections import namedtuple try: import torch - TORCH_AVAILABLE = True except (ImportError, NameError, AttributeError, OSError): TORCH_AVAILABLE = False # System Environment Information -SystemEnv = namedtuple( - "SystemEnv", - [ - "torch_version", - "is_debug_build", - "cuda_compiled_version", - "gcc_version", - "clang_version", - "cmake_version", - "os", - "libc_version", - "python_version", - "python_platform", - "is_cuda_available", - "cuda_runtime_version", - "cuda_module_loading", - "nvidia_driver_version", - "nvidia_gpu_models", - "cudnn_version", - "is_xpu_available", - "pip_version", # 'pip' or 'pip3' - "pip_packages", - "conda_packages", - "hip_compiled_version", - "hip_runtime_version", - "miopen_runtime_version", - "caching_allocator_config", - "is_xnnpack_available", - "cpu_info", - ], -) +SystemEnv = namedtuple('SystemEnv', [ + 'torch_version', + 'is_debug_build', + 'cuda_compiled_version', + 'gcc_version', + 'clang_version', + 'cmake_version', + 'os', + 'libc_version', + 'python_version', + 'python_platform', + 'is_cuda_available', + 'cuda_runtime_version', + 'cuda_module_loading', + 'nvidia_driver_version', + 'nvidia_gpu_models', + 'cudnn_version', + 'is_xpu_available', + 'pip_version', # 'pip' or 'pip3' + 'pip_packages', + 'conda_packages', + 'hip_compiled_version', + 'hip_runtime_version', + 'miopen_runtime_version', + 'caching_allocator_config', + 'is_xnnpack_available', + 'cpu_info', +]) COMMON_PATTERNS = [ "torch", @@ -120,13 +116,12 @@ def run(command): """Return (return-code, stdout, stderr).""" shell = True if type(command) is str else False - p = subprocess.Popen( - command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=shell - ) + p = subprocess.Popen(command, stdout=subprocess.PIPE, + stderr=subprocess.PIPE, shell=shell) raw_output, raw_err = p.communicate() rc = p.returncode - if get_platform() == "win32": - enc = "oem" + if get_platform() == 'win32': + enc = 'oem' else: enc = locale.getpreferredencoding() output = raw_output.decode(enc) @@ -152,19 +147,18 @@ def run_and_parse_first_match(run_lambda, command, regex): return None return match.group(1) - def run_and_return_first_line(run_lambda, command): """Run command using run_lambda and returns first line if output is not empty.""" rc, out, _ = run_lambda(command) if rc != 0: return None - return out.split("\n")[0] + return out.split('\n')[0] def get_conda_packages(run_lambda, patterns=None): if patterns is None: patterns = CONDA_PATTERNS + COMMON_PATTERNS + NVIDIA_PATTERNS + ONEAPI_PATTERNS - conda = os.environ.get("CONDA_EXE", "conda") + conda = os.environ.get('CONDA_EXE', 'conda') out = run_and_read_all(run_lambda, "{} list".format(conda)) if out is None: return out @@ -172,40 +166,32 @@ def get_conda_packages(run_lambda, patterns=None): return "\n".join( line for line in out.splitlines() - if not line.startswith("#") and any(name in line for name in patterns) + if not line.startswith("#") + and any(name in line for name in patterns) ) - def get_gcc_version(run_lambda): - return run_and_parse_first_match(run_lambda, "gcc --version", r"gcc (.*)") - + return run_and_parse_first_match(run_lambda, 'gcc --version', r'gcc (.*)') def get_clang_version(run_lambda): - return run_and_parse_first_match( - run_lambda, "clang --version", r"clang version (.*)" - ) + return run_and_parse_first_match(run_lambda, 'clang --version', r'clang version (.*)') def get_cmake_version(run_lambda): - return run_and_parse_first_match(run_lambda, "cmake --version", r"cmake (.*)") + return run_and_parse_first_match(run_lambda, 'cmake --version', r'cmake (.*)') def get_nvidia_driver_version(run_lambda): - if get_platform() == "darwin": - cmd = "kextstat | grep -i cuda" - return run_and_parse_first_match( - run_lambda, cmd, r"com[.]nvidia[.]CUDA [(](.*?)[)]" - ) + if get_platform() == 'darwin': + cmd = 'kextstat | grep -i cuda' + return run_and_parse_first_match(run_lambda, cmd, + r'com[.]nvidia[.]CUDA [(](.*?)[)]') smi = get_nvidia_smi() - return run_and_parse_first_match(run_lambda, smi, r"Driver Version: (.*?) ") + return run_and_parse_first_match(run_lambda, smi, r'Driver Version: (.*?) ') def get_gpu_info(run_lambda): - if get_platform() == "darwin" or ( - TORCH_AVAILABLE - and hasattr(torch.version, "hip") - and torch.version.hip is not None - ): + if get_platform() == 'darwin' or (TORCH_AVAILABLE and hasattr(torch.version, 'hip') and torch.version.hip is not None): if TORCH_AVAILABLE and torch.cuda.is_available(): if torch.version.hip is not None: prop = torch.cuda.get_device_properties(0) @@ -218,42 +204,42 @@ def get_gpu_info(run_lambda): return torch.cuda.get_device_name(None) + gcnArch return None smi = get_nvidia_smi() - uuid_regex = re.compile(r" \(UUID: .+?\)") - rc, out, _ = run_lambda(smi + " -L") + uuid_regex = re.compile(r' \(UUID: .+?\)') + rc, out, _ = run_lambda(smi + ' -L') if rc != 0: return None # Anonymize GPUs by removing their UUID - return re.sub(uuid_regex, "", out) + return re.sub(uuid_regex, '', out) def get_running_cuda_version(run_lambda): - return run_and_parse_first_match(run_lambda, "nvcc --version", r"release .+ V(.*)") + return run_and_parse_first_match(run_lambda, 'nvcc --version', r'release .+ V(.*)') def get_cudnn_version(run_lambda): """Return a list of libcudnn.so; it's hard to tell which one is being used.""" - if get_platform() == "win32": - system_root = os.environ.get("SYSTEMROOT", "C:\\Windows") - cuda_path = os.environ.get("CUDA_PATH", "%CUDA_PATH%") - where_cmd = os.path.join(system_root, "System32", "where") + if get_platform() == 'win32': + system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows') + cuda_path = os.environ.get('CUDA_PATH', "%CUDA_PATH%") + where_cmd = os.path.join(system_root, 'System32', 'where') cudnn_cmd = '{} /R "{}\\bin" cudnn*.dll'.format(where_cmd, cuda_path) - elif get_platform() == "darwin": + elif get_platform() == 'darwin': # CUDA libraries and drivers can be found in /usr/local/cuda/. See # https://docs.nvidia.com/cuda/archive/9.0/cuda-installation-guide-mac-os-x/index.html#installation # https://docs.nvidia.com/deeplearning/cudnn/installation/latest/ # Use CUDNN_LIBRARY when cudnn library is installed elsewhere. - cudnn_cmd = "ls /usr/local/cuda/lib/libcudnn*" + cudnn_cmd = 'ls /usr/local/cuda/lib/libcudnn*' else: cudnn_cmd = 'ldconfig -p | grep libcudnn | rev | cut -d" " -f1 | rev' rc, out, _ = run_lambda(cudnn_cmd) # find will return 1 if there are permission errors or if not found if len(out) == 0 or (rc != 1 and rc != 0): - l = os.environ.get("CUDNN_LIBRARY") + l = os.environ.get('CUDNN_LIBRARY') if l is not None and os.path.isfile(l): return os.path.realpath(l) return None files_set = set() - for fn in out.split("\n"): + for fn in out.split('\n'): fn = os.path.realpath(fn) # eliminate symbolic links if os.path.isfile(fn): files_set.add(fn) @@ -263,20 +249,18 @@ def get_cudnn_version(run_lambda): files = sorted(files_set) if len(files) == 1: return files[0] - result = "\n".join(files) - return "Probably one of the following:\n{}".format(result) + result = '\n'.join(files) + return 'Probably one of the following:\n{}'.format(result) def get_nvidia_smi(): # Note: nvidia-smi is currently available only on Windows and Linux - smi = "nvidia-smi" - if get_platform() == "win32": - system_root = os.environ.get("SYSTEMROOT", "C:\\Windows") - program_files_root = os.environ.get("PROGRAMFILES", "C:\\Program Files") - legacy_path = os.path.join( - program_files_root, "NVIDIA Corporation", "NVSMI", smi - ) - new_path = os.path.join(system_root, "System32", smi) + smi = 'nvidia-smi' + if get_platform() == 'win32': + system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows') + program_files_root = os.environ.get('PROGRAMFILES', 'C:\\Program Files') + legacy_path = os.path.join(program_files_root, 'NVIDIA Corporation', 'NVSMI', smi) + new_path = os.path.join(system_root, 'System32', smi) smis = [new_path, legacy_path] for candidate_smi in smis: if os.path.exists(candidate_smi): @@ -427,9 +411,7 @@ def get_intel_gpu_detected(run_lambda): if device_count == 0: return "N/A" - devices = [ - f"* [{i}] {torch.xpu.get_device_properties(i)}" for i in range(device_count) - ] + devices = [f"* [{i}] {torch.xpu.get_device_properties(i)}" for i in range(device_count)] return "\n".join(devices) @@ -508,12 +490,11 @@ def get_intel_gpu_detected(run_lambda): # ProcessorType=3 # Revision=27142 - def get_cpu_info(run_lambda): - rc, out, err = 0, "", "" - if get_platform() == "linux": - rc, out, err = run_lambda("lscpu") - elif get_platform() == "win32": + rc, out, err = 0, '', '' + if get_platform() == 'linux': + rc, out, err = run_lambda('lscpu') + elif get_platform() == 'win32': rc, out, err = run_lambda( 'powershell.exe "gwmi -Class Win32_Processor | Select-Object -Property Name,Manufacturer,Family,\ Architecture,ProcessorType,DeviceID,CurrentClockSpeed,MaxClockSpeed,L2CacheSize,L2CacheSpeed,Revision\ @@ -533,9 +514,9 @@ def get_cpu_info(run_lambda): lst.append(out) lst.append(str(e)) out = "\n".join(lst) - elif get_platform() == "darwin": + elif get_platform() == 'darwin': rc, out, err = run_lambda("sysctl -n machdep.cpu.brand_string") - cpu_info = "None" + cpu_info = 'None' if rc == 0: cpu_info = out else: @@ -544,20 +525,20 @@ def get_cpu_info(run_lambda): def get_platform(): - if sys.platform.startswith("linux"): - return "linux" - elif sys.platform.startswith("win32"): - return "win32" - elif sys.platform.startswith("cygwin"): - return "cygwin" - elif sys.platform.startswith("darwin"): - return "darwin" + if sys.platform.startswith('linux'): + return 'linux' + elif sys.platform.startswith('win32'): + return 'win32' + elif sys.platform.startswith('cygwin'): + return 'cygwin' + elif sys.platform.startswith('darwin'): + return 'darwin' else: return sys.platform def get_mac_version(run_lambda): - return run_and_parse_first_match(run_lambda, "sw_vers -productVersion", r"(.*)") + return run_and_parse_first_match(run_lambda, 'sw_vers -productVersion', r'(.*)') def get_windows_version(run_lambda): @@ -575,43 +556,39 @@ def get_windows_version(run_lambda): def get_lsb_version(run_lambda): - return run_and_parse_first_match( - run_lambda, "lsb_release -a", r"Description:\t(.*)" - ) + return run_and_parse_first_match(run_lambda, 'lsb_release -a', r'Description:\t(.*)') def check_release_file(run_lambda): - return run_and_parse_first_match( - run_lambda, "cat /etc/*-release", r'PRETTY_NAME="(.*)"' - ) + return run_and_parse_first_match(run_lambda, 'cat /etc/*-release', + r'PRETTY_NAME="(.*)"') def get_os(run_lambda): from platform import machine - platform = get_platform() if platform in ["win32", "cygwin"]: return get_windows_version(run_lambda) - if platform == "darwin": + if platform == 'darwin': version = get_mac_version(run_lambda) if version is None: return None - return "macOS {} ({})".format(version, machine()) + return 'macOS {} ({})'.format(version, machine()) - if platform == "linux": + if platform == 'linux': # Ubuntu/Debian based desc = get_lsb_version(run_lambda) if desc is not None: - return "{} ({})".format(desc, machine()) + return '{} ({})'.format(desc, machine()) # Try reading /etc/*-release desc = check_release_file(run_lambda) if desc is not None: - return "{} ({})".format(desc, machine()) + return '{} ({})'.format(desc, machine()) - return "{} ({})".format(platform, machine()) + return '{} ({})'.format(platform, machine()) # Unknown platform return platform @@ -619,16 +596,14 @@ def get_os(run_lambda): def get_python_platform(): import platform - return platform.platform() def get_libc_version(): import platform - - if get_platform() != "linux": - return "N/A" - return "-".join(platform.libc_ver()) + if get_platform() != 'linux': + return 'N/A' + return '-'.join(platform.libc_ver()) def get_pip_packages(run_lambda, patterns=None): @@ -636,35 +611,35 @@ def get_pip_packages(run_lambda, patterns=None): if patterns is None: patterns = PIP_PATTERNS + COMMON_PATTERNS + NVIDIA_PATTERNS + ONEAPI_PATTERNS - pip_version = "pip3" if sys.version_info.major == 3 else "pip" + pip_version = 'pip3' if sys.version_info.major == 3 else 'pip' - os.environ["PIP_DISABLE_PIP_VERSION_CHECK"] = "1" + os.environ['PIP_DISABLE_PIP_VERSION_CHECK'] = '1' # People generally have pip as `pip` or `pip3` # But here it is invoked as `python -mpip` - out = run_and_read_all( - run_lambda, [sys.executable, "-mpip", "list", "--format=freeze"] - ) + out = run_and_read_all(run_lambda, [sys.executable, '-mpip', 'list', '--format=freeze']) if out is None: return pip_version, out - filtered_out = "\n".join( - line for line in out.splitlines() if any(name in line for name in patterns) + filtered_out = '\n'.join( + line + for line in out.splitlines() + if any(name in line for name in patterns) ) return pip_version, filtered_out def get_cachingallocator_config(): - ca_config = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "") + ca_config = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', '') if not ca_config: - ca_config = os.environ.get("PYTORCH_HIP_ALLOC_CONF", "") + ca_config = os.environ.get('PYTORCH_HIP_ALLOC_CONF', '') return ca_config def get_cuda_module_loading_config(): if TORCH_AVAILABLE and torch.cuda.is_available(): torch.cuda.init() - config = os.environ.get("CUDA_MODULE_LOADING", "") + config = os.environ.get('CUDA_MODULE_LOADING', '') return config else: return "N/A" @@ -673,12 +648,10 @@ def get_cuda_module_loading_config(): def is_xnnpack_available(): if TORCH_AVAILABLE: import torch.backends.xnnpack - return str(torch.backends.xnnpack.enabled) # type: ignore[attr-defined] else: return "N/A" - def get_env_info(): """ Collects environment information to aid in debugging. @@ -705,31 +678,26 @@ def get_env_info(): cuda_version_str = torch.version.cuda xpu_available_str = str(torch.xpu.is_available()) if torch.xpu.is_available(): - xpu_available_str = ( - f"{xpu_available_str}\n" - + f"XPU used to build PyTorch: {torch.version.xpu}\n" - + f"Intel GPU driver version:\n{get_intel_gpu_driver_version(run_lambda)}\n" - + f"Intel GPU models onboard:\n{get_intel_gpu_onboard(run_lambda)}\n" - + f"Intel GPU models detected:\n{get_intel_gpu_detected(run_lambda)}" - ) - if ( - not hasattr(torch.version, "hip") or torch.version.hip is None - ): # cuda version - hip_compiled_version = hip_runtime_version = miopen_runtime_version = "N/A" + xpu_available_str = f'{xpu_available_str}\n' + \ + f'XPU used to build PyTorch: {torch.version.xpu}\n' + \ + f'Intel GPU driver version:\n{get_intel_gpu_driver_version(run_lambda)}\n' + \ + f'Intel GPU models onboard:\n{get_intel_gpu_onboard(run_lambda)}\n' + \ + f'Intel GPU models detected:\n{get_intel_gpu_detected(run_lambda)}' + if not hasattr(torch.version, 'hip') or torch.version.hip is None: # cuda version + hip_compiled_version = hip_runtime_version = miopen_runtime_version = 'N/A' else: # HIP version - def get_version_or_na(cfg, prefix): _lst = [s.rsplit(None, 1)[-1] for s in cfg if prefix in s] - return _lst[0] if _lst else "N/A" + return _lst[0] if _lst else 'N/A' - cfg = torch._C._show_config().split("\n") - hip_runtime_version = get_version_or_na(cfg, "HIP Runtime") - miopen_runtime_version = get_version_or_na(cfg, "MIOpen") - cuda_version_str = "N/A" + cfg = torch._C._show_config().split('\n') + hip_runtime_version = get_version_or_na(cfg, 'HIP Runtime') + miopen_runtime_version = get_version_or_na(cfg, 'MIOpen') + cuda_version_str = 'N/A' hip_compiled_version = torch.version.hip else: - version_str = debug_mode_str = cuda_available_str = cuda_version_str = xpu_available_str = "N/A" # type: ignore[assignment] - hip_compiled_version = hip_runtime_version = miopen_runtime_version = "N/A" + version_str = debug_mode_str = cuda_available_str = cuda_version_str = xpu_available_str = 'N/A' + hip_compiled_version = hip_runtime_version = miopen_runtime_version = 'N/A' sys_version = sys.version.replace("\n", " ") @@ -738,9 +706,7 @@ def get_version_or_na(cfg, prefix): return SystemEnv( torch_version=version_str, is_debug_build=debug_mode_str, - python_version="{} ({}-bit runtime)".format( - sys_version, sys.maxsize.bit_length() + 1 - ), + python_version='{} ({}-bit runtime)'.format(sys_version, sys.maxsize.bit_length() + 1), python_platform=get_python_platform(), is_cuda_available=cuda_available_str, cuda_compiled_version=cuda_version_str, @@ -766,7 +732,6 @@ def get_version_or_na(cfg, prefix): cpu_info=get_cpu_info(run_lambda), ) - env_info_fmt = """ PyTorch version: {torch_version} Is debug build: {is_debug_build} @@ -802,14 +767,14 @@ def get_version_or_na(cfg, prefix): def pretty_str(envinfo): - def replace_nones(dct, replacement="Could not collect"): + def replace_nones(dct, replacement='Could not collect'): for key in dct.keys(): if dct[key] is not None: continue dct[key] = replacement return dct - def replace_bools(dct, true="Yes", false="No"): + def replace_bools(dct, true='Yes', false='No'): for key in dct.keys(): if dct[key] is True: dct[key] = true @@ -817,48 +782,42 @@ def replace_bools(dct, true="Yes", false="No"): dct[key] = false return dct - def prepend(text, tag="[prepend]"): - lines = text.split("\n") + def prepend(text, tag='[prepend]'): + lines = text.split('\n') updated_lines = [tag + line for line in lines] - return "\n".join(updated_lines) + return '\n'.join(updated_lines) - def replace_if_empty(text, replacement="No relevant packages"): + def replace_if_empty(text, replacement='No relevant packages'): if text is not None and len(text) == 0: return replacement return text def maybe_start_on_next_line(string): # If `string` is multiline, prepend a \n to it. - if string is not None and len(string.split("\n")) > 1: - return "\n{}\n".format(string) + if string is not None and len(string.split('\n')) > 1: + return '\n{}\n'.format(string) return string mutable_dict = envinfo._asdict() # If nvidia_gpu_models is multiline, start on the next line - mutable_dict["nvidia_gpu_models"] = maybe_start_on_next_line( - envinfo.nvidia_gpu_models - ) + mutable_dict['nvidia_gpu_models'] = \ + maybe_start_on_next_line(envinfo.nvidia_gpu_models) # If the machine doesn't have CUDA, report some fields as 'No CUDA' dynamic_cuda_fields = [ - "cuda_runtime_version", - "nvidia_gpu_models", - "nvidia_driver_version", + 'cuda_runtime_version', + 'nvidia_gpu_models', + 'nvidia_driver_version', ] - all_cuda_fields = dynamic_cuda_fields + ["cudnn_version"] + all_cuda_fields = dynamic_cuda_fields + ['cudnn_version'] all_dynamic_cuda_fields_missing = all( - mutable_dict[field] is None for field in dynamic_cuda_fields - ) - if ( - TORCH_AVAILABLE - and not torch.cuda.is_available() - and all_dynamic_cuda_fields_missing - ): + mutable_dict[field] is None for field in dynamic_cuda_fields) + if TORCH_AVAILABLE and not torch.cuda.is_available() and all_dynamic_cuda_fields_missing: for field in all_cuda_fields: - mutable_dict[field] = "No CUDA" + mutable_dict[field] = 'No CUDA' if envinfo.cuda_compiled_version is None: - mutable_dict["cuda_compiled_version"] = "None" + mutable_dict['cuda_compiled_version'] = 'None' # Replace True with Yes, False with No mutable_dict = replace_bools(mutable_dict) @@ -867,20 +826,18 @@ def maybe_start_on_next_line(string): mutable_dict = replace_nones(mutable_dict) # If either of these are '', replace with 'No relevant packages' - mutable_dict["pip_packages"] = replace_if_empty(mutable_dict["pip_packages"]) - mutable_dict["conda_packages"] = replace_if_empty(mutable_dict["conda_packages"]) + mutable_dict['pip_packages'] = replace_if_empty(mutable_dict['pip_packages']) + mutable_dict['conda_packages'] = replace_if_empty(mutable_dict['conda_packages']) # Tag conda and pip packages with a prefix # If they were previously None, they'll show up as ie '[conda] Could not collect' - if mutable_dict["pip_packages"]: - mutable_dict["pip_packages"] = prepend( - mutable_dict["pip_packages"], "[{}] ".format(envinfo.pip_version) - ) - if mutable_dict["conda_packages"]: - mutable_dict["conda_packages"] = prepend( - mutable_dict["conda_packages"], "[conda] " - ) - mutable_dict["cpu_info"] = envinfo.cpu_info + if mutable_dict['pip_packages']: + mutable_dict['pip_packages'] = prepend(mutable_dict['pip_packages'], + '[{}] '.format(envinfo.pip_version)) + if mutable_dict['conda_packages']: + mutable_dict['conda_packages'] = prepend(mutable_dict['conda_packages'], + '[conda] ') + mutable_dict['cpu_info'] = envinfo.cpu_info return env_info_fmt.format(**mutable_dict) @@ -904,29 +861,18 @@ def main(): output = get_pretty_env_info() print(output) - if ( - TORCH_AVAILABLE - and hasattr(torch, "utils") - and hasattr(torch.utils, "_crash_handler") - ): + if TORCH_AVAILABLE and hasattr(torch, 'utils') and hasattr(torch.utils, '_crash_handler'): minidump_dir = torch.utils._crash_handler.DEFAULT_MINIDUMP_DIR if sys.platform == "linux" and os.path.exists(minidump_dir): - dumps = [ - os.path.join(minidump_dir, dump) for dump in os.listdir(minidump_dir) - ] + dumps = [os.path.join(minidump_dir, dump) for dump in os.listdir(minidump_dir)] latest = max(dumps, key=os.path.getctime) ctime = os.path.getctime(latest) - creation_time = datetime.datetime.fromtimestamp(ctime).strftime( - "%Y-%m-%d %H:%M:%S" - ) - msg = ( - "\n*** Detected a minidump at {} created on {}, ".format( - latest, creation_time - ) - + "if this is related to your bug please include it when you file a report ***" - ) + creation_time = datetime.datetime.fromtimestamp(ctime).strftime('%Y-%m-%d %H:%M:%S') + msg = "\n*** Detected a minidump at {} created on {}, ".format(latest, creation_time) + \ + "if this is related to your bug please include it when you file a report ***" print(msg, file=sys.stderr) -if __name__ == "__main__": + +if __name__ == '__main__': main() From d293022c477ea3b94a215315793408bcc61440cf Mon Sep 17 00:00:00 2001 From: henrylhtsang Date: Mon, 14 Jul 2025 18:57:15 -0700 Subject: [PATCH 362/457] [cutass backend] memorize parts of cache key to reduce general overhead (#158311) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158311 Approved by: https://github.com/ColinPeppler ghstack dependencies: #156781 --- torch/_inductor/codegen/cuda/cuda_template.py | 35 ++++++++++++------- torch/_inductor/codegen/cuda/gemm_template.py | 21 ++++++++++- 2 files changed, 42 insertions(+), 14 deletions(-) diff --git a/torch/_inductor/codegen/cuda/cuda_template.py b/torch/_inductor/codegen/cuda/cuda_template.py index 2156369d56a58..cc03ccbdda863 100644 --- a/torch/_inductor/codegen/cuda/cuda_template.py +++ b/torch/_inductor/codegen/cuda/cuda_template.py @@ -3,7 +3,7 @@ import hashlib import itertools from dataclasses import dataclass -from typing import Any, Optional, TYPE_CHECKING +from typing import Any, Optional, TYPE_CHECKING, Union from typing_extensions import override from unittest.mock import patch @@ -11,7 +11,6 @@ import torch from torch._inductor import config -from torch._inductor.select_algorithm import create_inputs_key from torch._inductor.utils import clear_on_fresh_cache, Placeholder from torch._logging import getArtifactLogger @@ -80,7 +79,7 @@ def _template_from_string(cls, source: str) -> Any: def supports_epilogue_fusion(op: GemmOperation) -> bool: return False - def make_key(self, op: "GemmOperation") -> str: + def make_key(self, name: str, input_key: str, layout_repr: str) -> str: """ Make a key for the code cache. The idea of the method is to cache everything that matters but doesn't include runtime param values, i.e., @@ -92,26 +91,26 @@ def make_key(self, op: "GemmOperation") -> str: return hashlib.sha256( str( ( - create_inputs_key(self.input_nodes), + input_key, self.input_reorder, # output layout, same as self.output_node.get_layout() - self.layout, + layout_repr, self.get_runtime_arg_info(), - op.configuration_name(), + name, ) ).encode("utf-8") ).hexdigest() - def generate_code_and_args(self, **kwargs) -> tuple[str, tuple[int, ...]]: + def generate_code_and_args( + self, name: str, input_key: str, layout_repr: str, **kwargs + ) -> tuple[str, tuple[int, ...]]: """ Generate code and args with caching. We cache the code even if runtime args are different. """ key: Optional[str] = None if config.cuda.enable_caching_codegen: - op = kwargs.get("op") - assert op is not None, "op is required for caching" - key = self.make_key(op) + key = self.make_key(name=name, input_key=input_key, layout_repr=layout_repr) if key is not None and key in self.code_cache: code, size_args = self.code_cache[key] @@ -160,7 +159,12 @@ def generate_code_and_args(self, **kwargs) -> tuple[str, tuple[int, ...]]: def generate( # type: ignore[override] self, + name: str, description: str, + input_key: str, + layout_repr: str, + input_tensor_meta: Union[TensorMeta, list[TensorMeta]], + output_tensor_meta: Union[TensorMeta, list[TensorMeta]], **kwargs, ) -> CUDATemplateCaller: """ @@ -175,7 +179,12 @@ def generate( # type: ignore[override] Returns: A CUDATemplateCaller object representing the generated CUDA template caller. """ - code, extra_args = self.generate_code_and_args(**kwargs) + code, extra_args = self.generate_code_and_args( + name=name, + input_key=input_key, + layout_repr=layout_repr, + **kwargs, + ) # not caching since kernel name is needed below kernel_hash = hashlib.sha256(code.encode("utf-8")).hexdigest()[:8] @@ -185,8 +194,8 @@ def generate( # type: ignore[override] # create the BenchmarkRequest bmreq = CUDABenchmarkRequest( kernel_name=kernel_name, - input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes), - output_tensor_meta=TensorMeta.from_irnodes(self.output_node), + input_tensor_meta=input_tensor_meta, + output_tensor_meta=output_tensor_meta, extra_args=extra_args, source_code=code, ) diff --git a/torch/_inductor/codegen/cuda/gemm_template.py b/torch/_inductor/codegen/cuda/gemm_template.py index a38b846f7909c..bdecc07d69a51 100644 --- a/torch/_inductor/codegen/cuda/gemm_template.py +++ b/torch/_inductor/codegen/cuda/gemm_template.py @@ -10,6 +10,7 @@ import torch import torch.utils._pytree as pytree +from torch._inductor.autotune_process import TensorMeta from torch._inductor.codegen.cuda.cutlass_cache import maybe_fetch_ops from torch._inductor.runtime.runtime_utils import dynamo_timed from torch._inductor.scheduler import BaseSchedulerNode @@ -562,6 +563,16 @@ def _add_cutlass_gemm_choices( """ ops = self.gen_ops() + + # pre-computation + layout_repr: str = str(layout) + input_tensor_meta: Union[TensorMeta, list[TensorMeta]] = ( + TensorMeta.from_irnodes(self.input_nodes) + ) + output_tensor_meta: Union[TensorMeta, list[TensorMeta]] = ( + TensorMeta.from_irnodes(self.output_node) + ) + with dynamo_timed("CUTLASSGemmTemplate.maybe_append_choice"): for name, op in ops: for ( @@ -569,7 +580,15 @@ def _add_cutlass_gemm_choices( ) in inductor_cuda_config.cutlass_max_profiling_swizzle_options: description = f"{name} swizzle={swizzle}" self.maybe_append_choice( - choices, description=description, op=op, swizzle=swizzle + choices, + op=op, + name=name, + description=description, + input_key=self.cache_key, + layout_repr=layout_repr, + input_tensor_meta=input_tensor_meta, + output_tensor_meta=output_tensor_meta, + swizzle=swizzle, ) if len(ops) == 0: From 67be2f27e17db0214d52d636945399a0c6257d65 Mon Sep 17 00:00:00 2001 From: Catherine Lee Date: Mon, 21 Jul 2025 23:22:34 +0000 Subject: [PATCH 363/457] [CI][lintrunner] Only run on non deleted changed files (#158794) My PR was failing lint because I removed a file, and then lintrunner would try to run on the deleted file and error, so this changes how the changed files are retrieved to only retrieve changed files that have not been removed. I don't think this is possible through `gh pr view`, so instead it uses `gh api` Testing: https://github.com/pytorch/pytorch/pull/158795 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158794 Approved by: https://github.com/seemethere --- .github/workflows/_get-changed-files.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/_get-changed-files.yml b/.github/workflows/_get-changed-files.yml index 2d3b800f0757b..55712b0652702 100644 --- a/.github/workflows/_get-changed-files.yml +++ b/.github/workflows/_get-changed-files.yml @@ -27,7 +27,7 @@ jobs: PR_NUMBER="${{ github.event.number }}" # Use gh CLI to get changed files in the PR with explicit repo - CHANGED_FILES=$(gh pr view "$PR_NUMBER" --repo "${{ github.repository }}" --json files --jq '.files[].path' | tr '\n' ' ' | sed 's/ $//') + CHANGED_FILES=$(gh api repos/${{ github.repository }}/pulls/$PR_NUMBER/files --paginate --jq '.[] | select(.status != "removed") | .filename' | tr '\n' ' ' | sed 's/ $//') if [ -z "$CHANGED_FILES" ]; then echo "No changed files found, setting to '*'" @@ -40,4 +40,4 @@ jobs: else echo "Not in PR context, setting changed files to '*'" echo "changed-files=*" >> "$GITHUB_OUTPUT" - fi \ No newline at end of file + fi From 187c2deb408275f980a8a5a73a522767ddb9bd30 Mon Sep 17 00:00:00 2001 From: zpcore Date: Mon, 21 Jul 2025 23:26:03 +0000 Subject: [PATCH 364/457] Fix clamp(min/max) strategy (#158619) Part of plan https://github.com/pytorch/pytorch/issues/157495. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158619 Approved by: https://github.com/wanchaol --- test/distributed/tensor/test_dtensor_ops.py | 4 ---- torch/distributed/tensor/_ops/_pointwise_ops.py | 6 ++++++ torch/distributed/tensor/_ops/utils.py | 2 ++ 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/test/distributed/tensor/test_dtensor_ops.py b/test/distributed/tensor/test_dtensor_ops.py index ba43335d1ddcb..1e94cb7e359bc 100644 --- a/test/distributed/tensor/test_dtensor_ops.py +++ b/test/distributed/tensor/test_dtensor_ops.py @@ -119,9 +119,6 @@ def wrapped(fn): xfail("cholesky_inverse"), xfail("cholesky_solve"), xfail("chunk"), - xfail("clamp"), - xfail("clamp_max"), - xfail("clamp_min"), xfail("combinations"), xfail("complex"), xfail("constant_pad_nd"), @@ -317,7 +314,6 @@ def wrapped(fn): xfail("nn.functional.multi_head_attention_forward"), xfail("nn.functional.multilabel_margin_loss"), xfail("nn.functional.multilabel_soft_margin_loss"), - xfail("nn.functional.normalize"), xfail("nn.functional.pad", "constant"), xfail("nn.functional.pad", "reflect"), xfail("nn.functional.pad", "replicate"), diff --git a/torch/distributed/tensor/_ops/_pointwise_ops.py b/torch/distributed/tensor/_ops/_pointwise_ops.py index d506226499838..46fc8fbc0d990 100644 --- a/torch/distributed/tensor/_ops/_pointwise_ops.py +++ b/torch/distributed/tensor/_ops/_pointwise_ops.py @@ -134,8 +134,14 @@ aten.ceil.out, aten.ceil_.default, aten.clamp.default, + aten.clamp.Tensor, aten.clamp.out, aten.clamp_.default, + aten.clamp_.Tensor, + aten.clamp_min.default, + aten.clamp_min.Tensor, + aten.clamp_max.default, + aten.clamp_max.Tensor, aten.clip.default, aten.clip.out, aten.clip_.default, diff --git a/torch/distributed/tensor/_ops/utils.py b/torch/distributed/tensor/_ops/utils.py index d1c604d2976dd..8e07d0d6c1f72 100644 --- a/torch/distributed/tensor/_ops/utils.py +++ b/torch/distributed/tensor/_ops/utils.py @@ -134,6 +134,8 @@ def is_tensor_shardable(shape: Sequence[int], spec: DTensorSpec) -> bool: for i, placement in enumerate(spec.placements): if placement.is_shard(): shard_dim = cast(Shard, placement).dim + if shard_dim >= len(shape): + return False shards_map[shard_dim] *= spec.mesh.size(i) for i, dim_size in enumerate(shape): From 08540b13c6a97908b7f4d77e504cc572db9e78f3 Mon Sep 17 00:00:00 2001 From: Raymond Li Date: Mon, 21 Jul 2025 23:34:50 +0000 Subject: [PATCH 365/457] Use cuda error code instead of error text in get_cuda_error_help (#158688) Use cudaError_t and switch through the enum to prevent impact by upstream changes in wording Pull Request resolved: https://github.com/pytorch/pytorch/pull/158688 Approved by: https://github.com/q10, https://github.com/aorenste --- c10/cuda/CUDAException.cpp | 2 +- c10/cuda/CUDAMiscFunctions.cpp | 17 +++++++++++++---- c10/cuda/CUDAMiscFunctions.h | 3 ++- 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/c10/cuda/CUDAException.cpp b/c10/cuda/CUDAException.cpp index 5eb54b2454539..457d35f020bbe 100644 --- a/c10/cuda/CUDAException.cpp +++ b/c10/cuda/CUDAException.cpp @@ -30,7 +30,7 @@ void c10_cuda_check_implementation( check_message.append("CUDA error: "); const char* error_string = cudaGetErrorString(cuda_error); check_message.append(error_string); - check_message.append(c10::cuda::get_cuda_error_help(error_string)); + check_message.append(c10::cuda::get_cuda_error_help(cuda_error)); check_message.append(c10::cuda::get_cuda_check_suffix()); check_message.append("\n"); if (include_device_assertions) { diff --git a/c10/cuda/CUDAMiscFunctions.cpp b/c10/cuda/CUDAMiscFunctions.cpp index 170d53398195f..b1b6170f891e9 100644 --- a/c10/cuda/CUDAMiscFunctions.cpp +++ b/c10/cuda/CUDAMiscFunctions.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include @@ -7,11 +8,19 @@ namespace c10::cuda { // Explain common CUDA errors // NOLINTNEXTLINE(bugprone-exception-escape,-warnings-as-errors) -std::string get_cuda_error_help(const char* error_string) noexcept { +std::string get_cuda_error_help(cudaError_t error) noexcept { std::string help_text; - if (strstr(error_string, "invalid device ordinal")) { - help_text.append( - "\nGPU device may be out of range, do you have enough GPUs?"); + switch (error) { + case cudaErrorInvalidDevice: + help_text.append( + "\nGPU device may be out of range, do you have enough GPUs?"); + break; + default: + help_text.append("\nSearch for `") + .append(cudaGetErrorName(error)) + .append( + "' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information."); + break; } return help_text; } diff --git a/c10/cuda/CUDAMiscFunctions.h b/c10/cuda/CUDAMiscFunctions.h index 26a15d85a61e2..ec1114935457e 100644 --- a/c10/cuda/CUDAMiscFunctions.h +++ b/c10/cuda/CUDAMiscFunctions.h @@ -3,12 +3,13 @@ // CUDAExceptions.h #include +#include #include #include namespace c10::cuda { -C10_CUDA_API std::string get_cuda_error_help(const char*) noexcept; +C10_CUDA_API std::string get_cuda_error_help(cudaError_t) noexcept; C10_CUDA_API const char* get_cuda_check_suffix() noexcept; C10_CUDA_API std::mutex* getFreeMutex(); } // namespace c10::cuda From 2c37acfd891298bd3e1f60fe5c50d3ef8146292d Mon Sep 17 00:00:00 2001 From: Huamin Li Date: Mon, 21 Jul 2025 23:42:40 +0000 Subject: [PATCH 366/457] [AOTI][CPU] Consider bias=None case for fbgemm_linear_fp16_weight (#158535) Test Plan: Rollback Plan: Differential Revision: D78458214 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158535 Approved by: https://github.com/houseroad, https://github.com/henryoier, https://github.com/jingsh --- aten/src/ATen/native/QuantizedLinear.cpp | 12 ++++++++---- aten/src/ATen/native/native_functions.yaml | 2 +- .../native/quantized/cpu/qlinear_dynamic.cpp | 8 ++++---- aten/src/ATen/native/quantized/library.cpp | 2 +- test/inductor/test_aot_inductor.py | 17 +++++++++++++++++ test/inductor/test_aot_inductor_custom_ops.py | 1 + test/quantization/core/test_quantized_op.py | 17 +++++++++++++++-- torch/_inductor/decomposition.py | 2 +- torch/csrc/inductor/aoti_torch/c/shim.h | 2 +- torch/csrc/inductor/aoti_torch/shim_common.cpp | 7 ++++--- 10 files changed, 53 insertions(+), 17 deletions(-) diff --git a/aten/src/ATen/native/QuantizedLinear.cpp b/aten/src/ATen/native/QuantizedLinear.cpp index 037287a06c493..d6f1d462b6b99 100644 --- a/aten/src/ATen/native/QuantizedLinear.cpp +++ b/aten/src/ATen/native/QuantizedLinear.cpp @@ -409,7 +409,7 @@ Tensor fbgemm_pack_gemm_matrix_fp16(const Tensor& weight) { Tensor fbgemm_linear_fp16_weight_fp32_activation( const Tensor& input, const Tensor& packed_weight, - const Tensor& bias) { + const std::optional& bias) { TORCH_WARN_ONCE("fbgemm_linear_fp16_weight_fp32_activation is deprecated " "and will be removed in a future PyTorch release.") @@ -430,7 +430,6 @@ Tensor fbgemm_linear_fp16_weight_fp32_activation( TORCH_CHECK(input.size(input.dim() - 1) == packed_weight_fp16.numRows()) TORCH_CHECK(input.dim() >= 2); - TORCH_CHECK(bias.dim() == 1); // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) const int64_t M = size_to_dim_(input.dim() - 1, input.sizes()); @@ -449,7 +448,12 @@ Tensor fbgemm_linear_fp16_weight_fp32_activation( output.data_ptr()); // Add bias term - output.add_(bias); + c10::MaybeOwned bias_maybe_owned = at::borrow_from_optional_tensor(bias); + const Tensor& bias_ = *bias_maybe_owned; + if (bias_.defined()) { + TORCH_CHECK(bias_.dim() == 1); + output.add_(bias_); + } return output; } @@ -551,7 +555,7 @@ Tensor fbgemm_pack_gemm_matrix_fp16(const Tensor& weight) { Tensor fbgemm_linear_fp16_weight_fp32_activation( const Tensor& input, const Tensor& packed_weight, - const Tensor& bias) { + const std::optional& bias) { TORCH_WARN_ONCE("fbgemm_linear_fp16_weight_fp32_activation is deprecated " "and will be removed in a future PyTorch release.") diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 79b7e07e2284b..e0dc1b616013e 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -3432,7 +3432,7 @@ - func: _wrapped_quantized_linear_prepacked(Tensor input, Tensor input_scale, Tensor input_zero_point, Tensor packed_weight, Tensor output_scale, Tensor output_zero_point, int out_channel) -> Tensor -- func: fbgemm_linear_fp16_weight_fp32_activation(Tensor input, Tensor packed_weight, Tensor bias) -> Tensor +- func: fbgemm_linear_fp16_weight_fp32_activation(Tensor input, Tensor packed_weight, Tensor? bias) -> Tensor - func: fbgemm_linear_fp16_weight(Tensor input, Tensor packed_weight, Tensor bias) -> Tensor diff --git a/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp b/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp index e2d5278d57920..4ed50f6f8735a 100644 --- a/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp +++ b/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp @@ -888,7 +888,7 @@ class QLinearUnpackedDynamicFp16 final { static at::Tensor run( at::Tensor input, const at::Tensor& weight, - const at::Tensor& bias) { + const std::optional& bias) { // We make a strong guarantee that models using these operators will have // the same numerics across different machines. Therefore, we do not provide // a fallback path and rather fail loudly if we cannot run FBGEMM. @@ -908,7 +908,7 @@ class QLinearUnpackedDynamicFp16 final { static at::Tensor meta( at::Tensor input, const at::Tensor& weight, - const at::Tensor& bias) { + const std::optional& bias) { // We make a strong guarantee that models using these operators will have // the same numerics across different machines. Therefore, we do not provide // a fallback path and rather fail loudly if we cannot run FBGEMM. @@ -929,7 +929,7 @@ class QLinearUnpackedDynamicFp16 final { static at::Tensor run( at::Tensor /* input */, const at::Tensor& weight, - const at::Tensor& bias) { + const std::optional& bias) { // We make a strong guarantee that models using these operators will have // the same numerics across different machines. Therefore, we do not provide // a fallback path and rather fail loudly if we cannot run FBGEMM. @@ -940,7 +940,7 @@ class QLinearUnpackedDynamicFp16 final { static at::Tensor meta( at::Tensor /* input */, const at::Tensor& weight, - const at::Tensor& bias) { + const std::optional& bias) { TORCH_CHECK( false, "This PyTorch installation was not built with FBGEMM operators"); } diff --git a/aten/src/ATen/native/quantized/library.cpp b/aten/src/ATen/native/quantized/library.cpp index cb19ec10ce045..550280dbf6d3e 100644 --- a/aten/src/ATen/native/quantized/library.cpp +++ b/aten/src/ATen/native/quantized/library.cpp @@ -142,7 +142,7 @@ TORCH_LIBRARY(quantized, m) { m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_dynamic(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, bool reduce_range=False) -> Tensor Y"), {at::Tag::pt2_compliant_tag}); m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_relu_dynamic(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, bool reduce_range=False) -> Tensor Y"), {at::Tag::pt2_compliant_tag}); m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_dynamic_fp16(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack) -> Tensor Y"), {at::Tag::pt2_compliant_tag}); - m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_dynamic_fp16_unpacked_weight(Tensor X, Tensor weight, Tensor bias) -> Tensor Y"), {at::Tag::pt2_compliant_tag}); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_dynamic_fp16_unpacked_weight(Tensor X, Tensor weight, Tensor? bias) -> Tensor Y"), {at::Tag::pt2_compliant_tag}); m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_relu_dynamic_fp16(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack) -> Tensor Y"), {at::Tag::pt2_compliant_tag}); m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_leaky_relu(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, float Y_scale_i, int Y_zero_point_i, float negative_slope) -> Tensor Y"), {at::Tag::pt2_compliant_tag}); m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_tanh(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, float Y_scale_i, int Y_zero_point_i) -> Tensor Y"), {at::Tag::pt2_compliant_tag}); diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 49226013d81d2..c8281e1b505a0 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -1455,6 +1455,22 @@ def forward(self, x): with config.patch({"aot_inductor.use_runtime_constant_folding": True}): self.check_model(Model(self.device), example_inputs) + @skipIfNoFBGEMM + def test_quantized_linear_bias_none(self): + class Model(torch.nn.Module): + def __init__(self, device): + super().__init__() + self.weight = torch.randn(10, 10, device=device) + + def forward(self, x): + return torch.ops.quantized.linear_dynamic_fp16_unpacked_weight( + x, self.weight, None + ) + + example_inputs = (torch.randn(10, 10, device=self.device),) + with config.patch({"aot_inductor.use_runtime_constant_folding": True}): + self.check_model(Model(self.device), example_inputs) + @skipIfNoFBGEMM def test_quanatized_int8_linear(self): class Model(torch.nn.Module): @@ -6714,6 +6730,7 @@ def fail_gpu(suffixes: tuple[str, ...], is_skip=False): # quantized unsupported for GPU "test_quantized_linear": fail_gpu(("cuda", "xpu")), "test_quanatized_int8_linear": fail_gpu(("cuda", "xpu")), + "test_quantized_linear_bias_none": fail_gpu(("cuda", "xpu")), # No scaled_dot_product_efficient_attention implementation for XPU yet. "test_scaled_dot_product_efficient_attention": fail_gpu(("xpu",)), # No fft implementation for XPU yet. diff --git a/test/inductor/test_aot_inductor_custom_ops.py b/test/inductor/test_aot_inductor_custom_ops.py index fcbaeed297a33..aa3c589b45467 100644 --- a/test/inductor/test_aot_inductor_custom_ops.py +++ b/test/inductor/test_aot_inductor_custom_ops.py @@ -512,6 +512,7 @@ def fail_cuda(is_skip=False): # quantized unsupported for GPU "test_quantized_linear": fail_cuda(), "test_quanatized_int8_linear": fail_cuda(), + "test_quantized_linear_bias_none": fail_cuda(), } diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py index d90d77b547867..4a1cb8f45814f 100644 --- a/test/quantization/core/test_quantized_op.py +++ b/test/quantization/core/test_quantized_op.py @@ -3550,14 +3550,15 @@ def test_wrapped_fbgemm_linear_fp16(self): (2, 4), # batch_size (4, 5), # input_channels (4, 7), # output_channels + (True, False), # bias None or not ) - for batch_size, input_channels, output_channels in options: + for batch_size, input_channels, output_channels, bias_is_none in options: pack_op = torch.ops._quantized.wrapped_fbgemm_pack_gemm_matrix_fp16 linear_op = torch.ops._quantized.wrapped_fbgemm_linear_fp16_weight x = torch.randn(batch_size, input_channels) w = torch.randn(output_channels, input_channels) - bias = torch.randn(output_channels) + bias = torch.randn(output_channels) if not bias_is_none else None w_packed = pack_op(w) out = linear_op(x, w_packed, bias, output_channels) @@ -3591,6 +3592,18 @@ def func(X, W, B): self.assertEqual(ref_out, compiled_out) + def func(X, W): + packed_W = torch.ops._quantized.wrapped_fbgemm_pack_gemm_matrix_fp16(W) + return torch.ops._quantized.wrapped_fbgemm_linear_fp16_weight(X, packed_W, None, W.size(0)) + + ref_out = func(x, w) + + compiled = torch.compile(func) + compiled_out = compiled(x, w) + + self.assertEqual(ref_out, compiled_out) + + """Tests the correctness of the dynamic quantized lstm/gru.""" def _get_rnn_inputs(self, seq_len, num_batches, input_size, hidden_size, num_directions, reduce_range): diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py index 08c3abc9f23f9..b81e6edbb54e2 100644 --- a/torch/_inductor/decomposition.py +++ b/torch/_inductor/decomposition.py @@ -701,7 +701,7 @@ def randint( def linear_dynamic_fp16_unpacked_weight( input: torch.Tensor, weight: torch.Tensor, - bias: torch.Tensor, + bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: packed_weight = torch.ops._quantized.wrapped_fbgemm_pack_gemm_matrix_fp16(weight) return torch.ops._quantized.wrapped_fbgemm_linear_fp16_weight( diff --git a/torch/csrc/inductor/aoti_torch/c/shim.h b/torch/csrc/inductor/aoti_torch/c/shim.h index a155f6bb621f1..9d512ce1f4817 100644 --- a/torch/csrc/inductor/aoti_torch/c/shim.h +++ b/torch/csrc/inductor/aoti_torch/c/shim.h @@ -365,7 +365,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_wrapped_fbgemm_linear_fp16_weight( AtenTensorHandle input, AtenTensorHandle weight, - AtenTensorHandle bias, + AtenTensorHandle bias, // optional argument int64_t out_channel, AtenTensorHandle* out); diff --git a/torch/csrc/inductor/aoti_torch/shim_common.cpp b/torch/csrc/inductor/aoti_torch/shim_common.cpp index dc6e52b0c4db1..a33198fd1ba06 100644 --- a/torch/csrc/inductor/aoti_torch/shim_common.cpp +++ b/torch/csrc/inductor/aoti_torch/shim_common.cpp @@ -981,16 +981,17 @@ AOTITorchError aoti_torch_cpu__wrapped_linear_prepack( AOTITorchError aoti_torch_cpu_wrapped_fbgemm_linear_fp16_weight( AtenTensorHandle input, AtenTensorHandle weight, - AtenTensorHandle bias, + AtenTensorHandle bias, // optional argument int64_t out_channel, AtenTensorHandle* out) { AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ at::Tensor* input_tensor = tensor_handle_to_tensor_pointer(input); at::Tensor* weight_tensor = tensor_handle_to_tensor_pointer(weight); - at::Tensor* bias_tensor = tensor_handle_to_tensor_pointer(bias); + auto optional_bias_tensor = + pointer_to_optional(tensor_handle_to_tensor_pointer(bias)); *out = new_tensor_handle(at::fbgemm_linear_fp16_weight_fp32_activation( - *input_tensor, *weight_tensor, *bias_tensor)); + *input_tensor, *weight_tensor, optional_bias_tensor)); }); } From 9281625a9b5c8f3912626e933bcc2639e7cadd3e Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 22 Jul 2025 00:12:49 +0000 Subject: [PATCH 367/457] Revert "Setup TorchBench in Docker (#158613)" This reverts commit cab28330f8c49cdb66d6a299755dc09c87c14a9d. Reverted https://github.com/pytorch/pytorch/pull/158613 on behalf of https://github.com/ZainRizvi due to Seems to have broken trunk. See [GH job link](https://github.com/pytorch/pytorch/actions/runs/16429779764/job/46430634676) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/b3c868d603e8f7b6661c93cd3d50c9a7b213ad6c) ([comment](https://github.com/pytorch/pytorch/pull/158613#issuecomment-3100023071)) --- .ci/docker/build.sh | 2 +- .../common/install_inductor_benchmark_deps.sh | 28 ++----------------- .ci/docker/requirements-ci.txt | 1 + .ci/docker/ubuntu-rocm/Dockerfile | 3 +- .ci/docker/ubuntu/Dockerfile | 3 +- .ci/pytorch/common_utils.sh | 24 ++++++++++++++++ .ci/pytorch/test.sh | 22 ++++++++++----- .../ci_commit_pins/torchbench.txt | 0 8 files changed, 45 insertions(+), 38 deletions(-) rename {.ci/docker => .github}/ci_commit_pins/torchbench.txt (100%) diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh index d8de423682004..d6cba6659db7a 100755 --- a/.ci/docker/build.sh +++ b/.ci/docker/build.sh @@ -276,7 +276,7 @@ case "$tag" in NINJA_VERSION=1.9.0 TRITON=yes ;; - pytorch-linux-jammy-py3.9-gcc11-inductor-benchmarks) + pytorch-linux-jammy-py3.9-gcc11-inductor-benchmarks) ANACONDA_PYTHON_VERSION=3.9 GCC_VERSION=11 VISION=yes diff --git a/.ci/docker/common/install_inductor_benchmark_deps.sh b/.ci/docker/common/install_inductor_benchmark_deps.sh index 2e0780f889e17..7312dce170db2 100644 --- a/.ci/docker/common/install_inductor_benchmark_deps.sh +++ b/.ci/docker/common/install_inductor_benchmark_deps.sh @@ -15,35 +15,11 @@ function install_timm() { commit=$(get_pinned_commit timm) pip_install "git+https://github.com/huggingface/pytorch-image-models@${commit}" -} - -function install_torchbench() { - local commit - commit=$(get_pinned_commit torchbench) - git clone https://github.com/pytorch/benchmark torchbench - pushd torchbench - git checkout "$commit" - - python install.py --continue_on_fail - - # TODO (huydhn): transformers-4.44.2 added by https://github.com/pytorch/benchmark/pull/2488 - # is regressing speedup metric. This needs to be investigated further - pip install transformers==4.38.1 - - echo "Print all dependencies after TorchBench is installed" - python -mpip freeze - popd + # Clean up + conda_run pip uninstall -y torch torchvision triton } # Pango is needed for weasyprint which is needed for doctr conda_install pango - -# Stable packages are ok here, just to satisfy TorchBench check -pip_install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128 - -install_torchbench install_huggingface install_timm - -# Clean up -conda_run pip uninstall -y torch torchvision torchaudio triton diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt index 944b1fb35b36e..fb773ff324af8 100644 --- a/.ci/docker/requirements-ci.txt +++ b/.ci/docker/requirements-ci.txt @@ -361,6 +361,7 @@ pwlf==2.2.1 #Pinned versions: 2.2.1 #test that import: test_sac_estimator.py + # To build PyTorch itself pyyaml pyzstd diff --git a/.ci/docker/ubuntu-rocm/Dockerfile b/.ci/docker/ubuntu-rocm/Dockerfile index 8f2cc6eef9581..2528da07c69e3 100644 --- a/.ci/docker/ubuntu-rocm/Dockerfile +++ b/.ci/docker/ubuntu-rocm/Dockerfile @@ -98,9 +98,8 @@ COPY ./common/install_inductor_benchmark_deps.sh install_inductor_benchmark_deps COPY ./common/common_utils.sh common_utils.sh COPY ci_commit_pins/huggingface.txt huggingface.txt COPY ci_commit_pins/timm.txt timm.txt -COPY ci_commit_pins/torchbench.txt torchbench.txt RUN if [ -n "${INDUCTOR_BENCHMARKS}" ]; then bash ./install_inductor_benchmark_deps.sh; fi -RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface.txt torchbench.txt +RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface.txt # (optional) Install non-default Ninja version ARG NINJA_VERSION diff --git a/.ci/docker/ubuntu/Dockerfile b/.ci/docker/ubuntu/Dockerfile index 077910cef9f35..27c466dd8d41d 100644 --- a/.ci/docker/ubuntu/Dockerfile +++ b/.ci/docker/ubuntu/Dockerfile @@ -98,9 +98,8 @@ COPY ./common/install_inductor_benchmark_deps.sh install_inductor_benchmark_deps COPY ./common/common_utils.sh common_utils.sh COPY ci_commit_pins/huggingface.txt huggingface.txt COPY ci_commit_pins/timm.txt timm.txt -COPY ci_commit_pins/torchbench.txt torchbench.txt RUN if [ -n "${INDUCTOR_BENCHMARKS}" ]; then bash ./install_inductor_benchmark_deps.sh; fi -RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface.txt torchbench.txt +RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface.txt ARG TRITON ARG TRITON_CPU diff --git a/.ci/pytorch/common_utils.sh b/.ci/pytorch/common_utils.sh index 046f0e1597e65..9075fe5fb56f8 100644 --- a/.ci/pytorch/common_utils.sh +++ b/.ci/pytorch/common_utils.sh @@ -258,6 +258,30 @@ function clone_pytorch_xla() { fi } +function checkout_install_torchbench() { + local commit + commit=$(get_pinned_commit torchbench) + git clone https://github.com/pytorch/benchmark torchbench + pushd torchbench + git checkout "$commit" + + if [ "$1" ]; then + python install.py --continue_on_fail models "$@" + else + # Occasionally the installation may fail on one model but it is ok to continue + # to install and test other models + python install.py --continue_on_fail + fi + + # TODO (huydhn): transformers-4.44.2 added by https://github.com/pytorch/benchmark/pull/2488 + # is regressing speedup metric. This needs to be investigated further + pip install transformers==4.38.1 + + echo "Print all dependencies after TorchBench is installed" + python -mpip freeze + popd +} + function install_torchao() { local commit commit=$(get_pinned_commit torchao) diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 4f28297b5bce8..ad6a48b2528e4 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -1668,11 +1668,13 @@ elif [[ "${TEST_CONFIG}" == *timm* ]]; then elif [[ "${TEST_CONFIG}" == cachebench ]]; then install_torchaudio install_torchvision - PYTHONPATH=/torchbench test_cachebench + checkout_install_torchbench nanogpt BERT_pytorch resnet50 hf_T5 llama moco + PYTHONPATH=$(pwd)/torchbench test_cachebench elif [[ "${TEST_CONFIG}" == verify_cachebench ]]; then install_torchaudio install_torchvision - PYTHONPATH=/torchbench test_verify_cachebench + checkout_install_torchbench nanogpt + PYTHONPATH=$(pwd)/torchbench test_verify_cachebench elif [[ "${TEST_CONFIG}" == *torchbench* ]]; then install_torchaudio install_torchvision @@ -1681,22 +1683,28 @@ elif [[ "${TEST_CONFIG}" == *torchbench* ]]; then # https://github.com/opencv/opencv-python/issues/885 pip_install opencv-python==4.8.0.74 if [[ "${TEST_CONFIG}" == *inductor_torchbench_smoketest_perf* ]]; then - PYTHONPATH=/torchbench test_inductor_torchbench_smoketest_perf + checkout_install_torchbench hf_Bert hf_Albert timm_vision_transformer + PYTHONPATH=$(pwd)/torchbench test_inductor_torchbench_smoketest_perf elif [[ "${TEST_CONFIG}" == *inductor_torchbench_cpu_smoketest_perf* ]]; then - PYTHONPATH=/torchbench test_inductor_torchbench_cpu_smoketest_perf + checkout_install_torchbench timm_vision_transformer phlippe_densenet basic_gnn_edgecnn \ + llama_v2_7b_16h resnet50 timm_efficientnet mobilenet_v3_large timm_resnest \ + functorch_maml_omniglot yolov3 mobilenet_v2 resnext50_32x4d densenet121 mnasnet1_0 + PYTHONPATH=$(pwd)/torchbench test_inductor_torchbench_cpu_smoketest_perf elif [[ "${TEST_CONFIG}" == *torchbench_gcp_smoketest* ]]; then - TORCHBENCHPATH=/torchbench test_torchbench_gcp_smoketest + checkout_install_torchbench + TORCHBENCHPATH=$(pwd)/torchbench test_torchbench_gcp_smoketest else + checkout_install_torchbench # Do this after checkout_install_torchbench to ensure we clobber any # nightlies that torchbench may pull in if [[ "${TEST_CONFIG}" != *cpu* ]]; then install_torchrec_and_fbgemm fi - PYTHONPATH=/torchbench test_dynamo_benchmark torchbench "$id" + PYTHONPATH=$(pwd)/torchbench test_dynamo_benchmark torchbench "$id" fi elif [[ "${TEST_CONFIG}" == *inductor_cpp_wrapper* ]]; then install_torchvision - PYTHONPATH=/torchbench test_inductor_cpp_wrapper_shard "$SHARD_NUMBER" + PYTHONPATH=$(pwd)/torchbench test_inductor_cpp_wrapper_shard "$SHARD_NUMBER" if [[ "$SHARD_NUMBER" -eq "1" ]]; then test_inductor_aoti fi diff --git a/.ci/docker/ci_commit_pins/torchbench.txt b/.github/ci_commit_pins/torchbench.txt similarity index 100% rename from .ci/docker/ci_commit_pins/torchbench.txt rename to .github/ci_commit_pins/torchbench.txt From 350d6af52c76481d0f386208b6b86be93b7ff22d Mon Sep 17 00:00:00 2001 From: "Han, Xu" Date: Tue, 22 Jul 2025 00:23:10 +0000 Subject: [PATCH 368/457] [AOTI] add windows support for get_cpp_compile_command (#158732) add windows support for `get_cpp_compile_command`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158732 Approved by: https://github.com/desertfire --- .../aoti_package/model_package_loader.cpp | 82 ++++++++++++++----- 1 file changed, 60 insertions(+), 22 deletions(-) diff --git a/torch/csrc/inductor/aoti_package/model_package_loader.cpp b/torch/csrc/inductor/aoti_package/model_package_loader.cpp index 629dc8cb2ae80..4018d9b00a75e 100644 --- a/torch/csrc/inductor/aoti_package/model_package_loader.cpp +++ b/torch/csrc/inductor/aoti_package/model_package_loader.cpp @@ -109,6 +109,14 @@ const char* extension_file_ext() { return ".so"; #endif } + +bool _is_windows_os() { +#ifdef _WIN32 + return true; +#else + return false; +#endif +} } // namespace namespace torch::inductor { @@ -143,7 +151,8 @@ std::tuple get_cpp_compile_command( source_args += source + " "; } - std::string file_ext = compile_only ? ".o" : ".so"; + std::string file_ext = + compile_only ? object_file_ext() : extension_file_ext(); std::string target_file = output_dir + filename + file_ext; std::string target_dir = output_dir; if (target_dir.empty()) { @@ -153,32 +162,43 @@ std::tuple get_cpp_compile_command( std::string cflags_args; for (auto& arg : compile_options["cflags"]) { - cflags_args += "-" + arg.get() + " "; + cflags_args += _is_windows_os() ? "/" : "-" + arg.get() + " "; } std::string definitions_args; for (auto& arg : compile_options["definitions"]) { - definitions_args += "-D " + arg.get() + " "; + definitions_args += + _is_windows_os() ? "/D" : "-D " + arg.get() + " "; } std::string include_dirs_args; for (auto& arg : compile_options["include_dirs"]) { - include_dirs_args += "-I" + arg.get() + " "; + include_dirs_args += + _is_windows_os() ? "/I" : "-I" + arg.get() + " "; } std::string ldflags_args; for (auto& arg : compile_options["ldflags"]) { - ldflags_args += "-" + arg.get() + " "; + ldflags_args += _is_windows_os() ? "/" : "-" + arg.get() + " "; } std::string libraries_dirs_args; for (auto& arg : compile_options["libraries_dirs"]) { - libraries_dirs_args += "-L" + arg.get() + " "; + if (_is_windows_os()) { + libraries_dirs_args += + fmt::format("/LIBPATH:\"{}\"", arg.get()) + " "; + } else { + libraries_dirs_args += "-L" + arg.get() + " "; + } } std::string libraries_args; for (auto& arg : compile_options["libraries"]) { - libraries_args += "-l" + arg.get() + " "; + if (_is_windows_os()) { + libraries_args += fmt::format("{}.lib", arg.get()) + " "; + } else { + libraries_args += "-l" + arg.get() + " "; + } } std::string passthrough_parameters_args; @@ -191,21 +211,39 @@ std::tuple get_cpp_compile_command( passthrough_parameters_args += arg_str + " "; } - std::string compile_only_arg = compile_only ? "-c" : ""; - - std::string cmd = normalize_path_separator(fmt::format( - "{} {} {} {} {} {} {} {} {} {} -o {}", - compiler, - source_args, - definitions_args, - cflags_args, - include_dirs_args, - passthrough_parameters_args, - ldflags_args, - libraries_args, - libraries_dirs_args, - compile_only_arg, - target_file)); + std::string compile_only_arg = + compile_only ? (_is_windows_os() ? "/c" : "-c") : ""; + + std::string cmd; + if (_is_windows_os()) { + cmd = normalize_path_separator(fmt::format( + "{} {} {} {} {} {} /LD /Fe{} {} /link {} {} {}", + compiler, + include_dirs_args, + definitions_args, + cflags_args, + source_args, + passthrough_parameters_args, + target_file, + compile_only_arg, + libraries_dirs_args, + libraries_args, + ldflags_args)); + } else { + cmd = normalize_path_separator(fmt::format( + "{} {} {} {} {} {} {} {} {} {} -o {}", + compiler, + source_args, + definitions_args, + cflags_args, + include_dirs_args, + passthrough_parameters_args, + ldflags_args, + libraries_args, + libraries_dirs_args, + compile_only_arg, + target_file)); + } return std::make_tuple(cmd, target_file); } From 63413113332b72d461c1ab5305ac7d439cf89df1 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 22 Jul 2025 01:01:41 +0000 Subject: [PATCH 369/457] Revert "Add unified memory APIs for torch.accelerator (#152932)" This reverts commit 2ad5c25cfc603c3656e6699d6137419dbb009495. Reverted https://github.com/pytorch/pytorch/pull/152932 on behalf of https://github.com/ZainRizvi due to Very sorry but this is still breaking internally. @albanD would you be able to help get this past the finish line? D78496124 has more details on the failure and the workaround might be to do something like what's in D78684669. To validate the fixes internally, you can follow the instructions here to ghimport the changes: https://fburl.com/fixing-ghfirst-reverts ([comment](https://github.com/pytorch/pytorch/pull/138222#issuecomment-3100195370)) --- aten/src/ATen/DeviceAccelerator.h | 22 ---- docs/source/accelerator.md | 23 ---- torch/_C/__init__.pyi.in | 5 - torch/accelerator/__init__.py | 18 --- torch/accelerator/memory.py | 201 ------------------------------ torch/csrc/DeviceAccelerator.cpp | 64 ---------- torch/cuda/memory.py | 4 +- 7 files changed, 2 insertions(+), 335 deletions(-) delete mode 100644 torch/accelerator/memory.py diff --git a/aten/src/ATen/DeviceAccelerator.h b/aten/src/ATen/DeviceAccelerator.h index f23b35047fcc8..f37e492c861fe 100644 --- a/aten/src/ATen/DeviceAccelerator.h +++ b/aten/src/ATen/DeviceAccelerator.h @@ -1,6 +1,5 @@ #pragma once -#include #include #include @@ -73,27 +72,6 @@ TORCH_API c10::DeviceIndex exchangeDevice(c10::DeviceIndex device_index); // original device index that was active before the change. TORCH_API c10::DeviceIndex maybeExchangeDevice(c10::DeviceIndex device_index); -TORCH_API inline void emptyCache() { - const auto device_type = getAccelerator(true).value(); - at::getDeviceAllocator(device_type)->emptyCache(); -} - -TORCH_API inline at::CachingDeviceAllocator::DeviceStats getDeviceStats( - c10::DeviceIndex device_index) { - const auto device_type = getAccelerator(true).value(); - return at::getDeviceAllocator(device_type)->getDeviceStats(device_index); -} - -TORCH_API inline void resetAccumulatedStats(c10::DeviceIndex device_index) { - const auto device_type = getAccelerator(true).value(); - at::getDeviceAllocator(device_type)->resetAccumulatedStats(device_index); -} - -TORCH_API inline void resetPeakStats(c10::DeviceIndex device_index) { - const auto device_type = getAccelerator(true).value(); - at::getDeviceAllocator(device_type)->resetPeakStats(device_index); -} - } // namespace at::accelerator namespace at { diff --git a/docs/source/accelerator.md b/docs/source/accelerator.md index ce593a9acf518..c6f2fb1080400 100644 --- a/docs/source/accelerator.md +++ b/docs/source/accelerator.md @@ -25,26 +25,3 @@ synchronize device_index ``` - -```{eval-rst} -.. automodule:: torch.accelerator.memory -``` -```{eval-rst} -.. currentmodule:: torch.accelerator.memory -``` - -## Memory management -```{eval-rst} -.. autosummary:: - :toctree: generated - :nosignatures: - - empty_cache - max_memory_allocated - max_memory_reserved - memory_allocated - memory_reserved - memory_stats - reset_accumulated_memory_stats - reset_peak_memory_stats -``` diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 7f88b86a7eaf2..dea17d26ef21f 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -2431,11 +2431,6 @@ def _accelerator_getStream(device_index: _int) -> Stream: ... def _accelerator_synchronizeDevice(device_index: _int) -> None: ... def _accelerator_exchangeDevice(device_index: _int) -> _int: ... def _accelerator_maybeExchangeDevice(device_index: _int) -> _int: ... -def _accelerator_isAllocatorInitialized() -> _bool: ... -def _accelerator_emptyCache() -> None: ... -def _accelerator_getDeviceStats(device_index: _int) -> dict[str, Any]: ... -def _accelerator_resetAccumulatedStats(device_index: _int) -> None: ... -def _accelerator_resetPeakStats(device_index: _int) -> None: ... # Defined in torch/csrc/jit/python/python_tracer.cpp class TracingState: diff --git a/torch/accelerator/__init__.py b/torch/accelerator/__init__.py index 4d1a78df1f74c..e9e48f1cf3061 100644 --- a/torch/accelerator/__init__.py +++ b/torch/accelerator/__init__.py @@ -8,16 +8,6 @@ import torch from ._utils import _device_t, _get_device_index -from .memory import ( - empty_cache, - max_memory_allocated, - max_memory_reserved, - memory_allocated, - memory_reserved, - memory_stats, - reset_accumulated_memory_stats, - reset_peak_memory_stats, -) __all__ = [ @@ -25,17 +15,9 @@ "current_device_idx", # deprecated "current_device_index", "current_stream", - "empty_cache", "device_count", "device_index", "is_available", - "max_memory_allocated", - "max_memory_reserved", - "memory_allocated", - "memory_reserved", - "memory_stats", - "reset_accumulated_memory_stats", - "reset_peak_memory_stats", "set_device_idx", # deprecated "set_device_index", "set_stream", diff --git a/torch/accelerator/memory.py b/torch/accelerator/memory.py deleted file mode 100644 index d34a11a3a02e5..0000000000000 --- a/torch/accelerator/memory.py +++ /dev/null @@ -1,201 +0,0 @@ -from collections import OrderedDict -from typing import Any - -import torch - -from ._utils import _device_t, _get_device_index - - -__all__ = [ - "empty_cache", - "max_memory_allocated", - "max_memory_reserved", - "memory_allocated", - "memory_reserved", - "memory_stats", - "reset_accumulated_memory_stats", - "reset_peak_memory_stats", -] - - -def empty_cache() -> None: - r"""Release all unoccupied cached memory currently held by the caching - allocator so that those can be used in other application. - - .. note:: This function is a no-op if the memory allocator for the current - :ref:`accelerator ` has not been initialized. - """ - if not torch._C._accelerator_isAllocatorInitialized(): - return - torch._C._accelerator_emptyCache() - - -def memory_stats(device_index: _device_t = None, /) -> OrderedDict[str, Any]: - r"""Return a dictionary of accelerator device memory allocator statistics for a given device index. - - The return value of this function is a dictionary of statistics, each of - which is a non-negative integer. - - Core statistics: - - - ``"allocated.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: - number of allocation requests received by the memory allocator. - - ``"allocated_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: - amount of allocated memory. - - ``"segment.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: - number of reserved segments from device memory allocation. - - ``"reserved_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: - amount of reserved memory. - - ``"active.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: - number of active memory blocks. - - ``"active_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: - amount of active memory. - - ``"inactive_split.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: - number of inactive, non-releasable memory blocks. - - ``"inactive_split_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: - amount of inactive, non-releasable memory. - - For these core statistics, values are broken down as follows. - - Pool type: - - - ``all``: combined statistics across all memory pools. - - ``large_pool``: statistics for the large allocation pool - (as of June 2025, for size >= 1MB allocations). - - ``small_pool``: statistics for the small allocation pool - (as of June 2025, for size < 1MB allocations). - - Metric type: - - - ``current``: current value of this metric. - - ``peak``: maximum value of this metric. - - ``allocated``: historical total increase in this metric. - - ``freed``: historical total decrease in this metric. - - In addition to the core statistics, we also provide some simple event - counters: - - - ``"num_alloc_retries"``: number of failed device memory allocation calls that - result in a cache flush and retry. - - ``"num_ooms"``: number of out-of-memory errors thrown. - - ``"num_sync_all_streams"``: number of ``synchronize_and_free_events`` calls. - - ``"num_device_alloc"``: number of device memory allocation calls. - - ``"num_device_free"``: number of device memory free calls. - - Args: - device_index (:class:`torch.device`, str, int, optional): the index of the device to target. - If not given, use :func:`torch.accelerator.current_device_index` by default. - If a :class:`torch.device` or str is provided, its type must match the current - :ref:`accelerator` device type. - """ - if not torch._C._accelerator_isAllocatorInitialized(): - return OrderedDict() - device_index = _get_device_index(device_index, optional=True) - stats = torch._C._accelerator_getDeviceStats(device_index) - flat_stats = [] - - def flatten(prefix: str, value: Any) -> None: - if isinstance(value, dict): - for k, v in value.items(): - nested_prefix = f"{prefix}.{k}" if prefix else k - flatten(nested_prefix, v) - else: - flat_stats.append((prefix, value)) - - flatten("", stats) - flat_stats.sort() - return OrderedDict(flat_stats) - - -def memory_allocated(device_index: _device_t = None, /) -> int: - r"""Return the current :ref:`accelerator` device memory occupied by tensors - in bytes for a given device index. - - Args: - device_index (:class:`torch.device`, str, int, optional): the index of the device to target. - If not given, use :func:`torch.accelerator.current_device_index` by default. - If a :class:`torch.device` or str is provided, its type must match the current - :ref:`accelerator` device type. - """ - return memory_stats(device_index).get("allocated_bytes.all.current", 0) - - -def max_memory_allocated(device_index: _device_t = None, /) -> int: - r"""Return the current :ref:`accelerator` maximum device memory occupied by tensors - in bytes for a given device index. - - By default, this returns the peak allocated memory since the beginning of - this program. :func:`~torch.accelerator.reset_peak_memory_stats` can be used to - reset the starting point in tracking this metric. - - Args: - device_index (:class:`torch.device`, str, int, optional): the index of the device to target. - If not given, use :func:`torch.accelerator.current_device_index` by default. - If a :class:`torch.device` or str is provided, its type must match the current - :ref:`accelerator` device type. - """ - return memory_stats(device_index).get("allocated_bytes.all.peak", 0) - - -def memory_reserved(device_index: _device_t = None, /) -> int: - r"""Return the current :ref:`accelerator` device memory managed by the caching allocator - in bytes for a given device index. - - Args: - device_index (:class:`torch.device`, str, int, optional): the index of the device to target. - If not given, use :func:`torch.accelerator.current_device_index` by default. - If a :class:`torch.device` or str is provided, its type must match the current - :ref:`accelerator` device type. - """ - return memory_stats(device_index).get("reserved_bytes.all.current", 0) - - -def max_memory_reserved(device_index: _device_t = None, /) -> int: - r"""Return the current :ref:`accelerator` maximum device memory managed by the caching allocator - in bytes for a given device index. - - By default, this returns the peak cached memory since the beginning of this - program. :func:`~torch.accelerator.reset_peak_memory_stats` can be used to reset - the starting point in tracking this metric. - - Args: - device_index (:class:`torch.device`, str, int, optional): the index of the device to target. - If not given, use :func:`torch.accelerator.current_device_index` by default. - If a :class:`torch.device` or str is provided, its type must match the current - :ref:`accelerator` device type. - """ - return memory_stats(device_index).get("reserved_bytes.all.peak", 0) - - -def reset_accumulated_memory_stats(device_index: _device_t = None, /) -> None: - r"""Reset the "accumulated" (historical) stats tracked by the current :ref:`accelerator` - memory allocator for a given device index. - - Args: - device_index (:class:`torch.device`, str, int, optional): the index of the device to target. - If not given, use :func:`torch.accelerator.current_device_index` by default. - If a :class:`torch.device` or str is provided, its type must match the current - :ref:`accelerator` device type. - - .. note:: This function is a no-op if the memory allocator for the current - :ref:`accelerator ` has not been initialized. - """ - device_index = _get_device_index(device_index, optional=True) - return torch._C._accelerator_resetAccumulatedStats(device_index) - - -def reset_peak_memory_stats(device_index: _device_t = None, /) -> None: - r"""Reset the "peak" stats tracked by the current :ref:`accelerator` - memory allocator for a given device index. - - Args: - device_index (:class:`torch.device`, str, int, optional): the index of the device to target. - If not given, use :func:`torch.accelerator.current_device_index` by default. - If a :class:`torch.device` or str is provided, its type must match the current - :ref:`accelerator` device type. - - .. note:: This function is a no-op if the memory allocator for the current - :ref:`accelerator ` has not been initialized. - """ - device_index = _get_device_index(device_index, optional=True) - return torch._C._accelerator_resetPeakStats(device_index) diff --git a/torch/csrc/DeviceAccelerator.cpp b/torch/csrc/DeviceAccelerator.cpp index dc3da8881a715..37fac325d3167 100644 --- a/torch/csrc/DeviceAccelerator.cpp +++ b/torch/csrc/DeviceAccelerator.cpp @@ -72,70 +72,6 @@ void initModule(PyObject* module) { torch::utils::maybe_initialize_device(device_type); return at::accelerator::maybeExchangeDevice(device_index); }); - - m.def("_accelerator_isAllocatorInitialized", []() { - const auto device_type = at::accelerator::getAccelerator(true).value(); - return at::getDeviceAllocator(device_type)->initialized(); - }); - - m.def("_accelerator_emptyCache", []() { at::accelerator::emptyCache(); }); - - m.def("_accelerator_getDeviceStats", [](c10::DeviceIndex device_index) { - using c10::CachingAllocator::Stat; - using c10::CachingAllocator::StatArray; - using c10::CachingAllocator::StatType; - using c10::CachingDeviceAllocator::DeviceStats; - - const auto stats = at::accelerator::getDeviceStats(device_index); - const auto stat_to_dict = [](const Stat& stat) -> py::dict { - py::dict dict; - dict["current"] = stat.current; - dict["peak"] = stat.peak; - dict["allocated"] = stat.allocated; - dict["freed"] = stat.freed; - return dict; - }; - - const auto stat_array_to_dict = [=](const StatArray& stats) -> py::dict { - const std::array(StatType::NUM_TYPES)> - kStatTypeNames = {"all", "small_pool", "large_pool"}; - py::dict dict; - for (const auto i : c10::irange(kStatTypeNames.size())) { - dict[kStatTypeNames[i]] = stat_to_dict(stats[i]); - } - return dict; - }; - - py::dict result; - result["num_alloc_retries"] = stats.num_alloc_retries; - result["num_ooms"] = stats.num_ooms; - result["max_split_size"] = stats.max_split_size; - result["num_sync_all_streams"] = stats.num_sync_all_streams; - result["num_device_alloc"] = stats.num_device_alloc; - result["num_device_free"] = stats.num_device_free; - result["allocated_bytes"] = stat_array_to_dict(stats.allocated_bytes); - result["reserved_bytes"] = stat_array_to_dict(stats.reserved_bytes); - result["active_bytes"] = stat_array_to_dict(stats.active_bytes); - result["requested_bytes"] = stat_array_to_dict(stats.requested_bytes); - result["allocation"] = stat_array_to_dict(stats.allocation); - result["segment"] = stat_array_to_dict(stats.segment); - result["active"] = stat_array_to_dict(stats.active); - result["inactive_split"] = stat_array_to_dict(stats.inactive_split); - result["inactive_split_bytes"] = - stat_array_to_dict(stats.inactive_split_bytes); - result["oversize_allocations"] = stat_to_dict(stats.oversize_allocations); - result["oversize_segments"] = stat_to_dict(stats.oversize_segments); - return result; - }); - - m.def( - "_accelerator_resetAccumulatedStats", [](c10::DeviceIndex device_index) { - at::accelerator::resetAccumulatedStats(device_index); - }); - - m.def("_accelerator_resetPeakStats", [](c10::DeviceIndex device_index) { - at::accelerator::resetPeakStats(device_index); - }); } } // namespace torch::accelerator diff --git a/torch/cuda/memory.py b/torch/cuda/memory.py index 08bfced67532b..3a2e1bb0f8909 100644 --- a/torch/cuda/memory.py +++ b/torch/cuda/memory.py @@ -255,9 +255,9 @@ def memory_stats(device: "Device" = None) -> dict[str, Any]: - ``all``: combined statistics across all memory pools. - ``large_pool``: statistics for the large allocation pool - (as of June 2025, for size >= 1MB allocations). + (as of October 2019, for size >= 1MB allocations). - ``small_pool``: statistics for the small allocation pool - (as of June 2025, for size < 1MB allocations). + (as of October 2019, for size < 1MB allocations). Metric type: From 95b658427df55f36f638186de9ed4115d4d99941 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 22 Jul 2025 01:01:41 +0000 Subject: [PATCH 370/457] Revert "Add DeviceAllocator as the base device allocator (#138222)" This reverts commit 1179e333237b02ed8fe2ba10cb9a23adf98d7d7a. Reverted https://github.com/pytorch/pytorch/pull/138222 on behalf of https://github.com/ZainRizvi due to Very sorry but this is still breaking internally. @albanD would you be able to help get this past the finish line? D78496124 has more details on the failure and the workaround might be to do something like what's in D78684669. To validate the fixes internally, you can follow the instructions here to ghimport the changes: https://fburl.com/fixing-ghfirst-reverts ([comment](https://github.com/pytorch/pytorch/pull/138222#issuecomment-3100195370)) --- aten/src/ATen/cuda/CUDAGraph.cpp | 1 + aten/src/ATen/cuda/CUDAGraph.h | 1 - .../hip/impl/HIPAllocatorMasqueradingAsCUDA.h | 26 ++------- .../HIPCachingAllocatorMasqueradingAsCUDA.cpp | 7 +-- c10/core/CachingDeviceAllocator.cpp | 10 ---- c10/core/CachingDeviceAllocator.h | 53 ------------------- c10/cuda/CUDACachingAllocator.cpp | 1 - c10/cuda/CUDACachingAllocator.h | 19 +++---- c10/cuda/CUDAGraphsC10Utils.h | 6 +++ c10/xpu/XPUCachingAllocator.cpp | 19 +++---- 10 files changed, 27 insertions(+), 116 deletions(-) delete mode 100644 c10/core/CachingDeviceAllocator.cpp diff --git a/aten/src/ATen/cuda/CUDAGraph.cpp b/aten/src/ATen/cuda/CUDAGraph.cpp index 2800e505a9b76..7fba7c4c7424c 100644 --- a/aten/src/ATen/cuda/CUDAGraph.cpp +++ b/aten/src/ATen/cuda/CUDAGraph.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include diff --git a/aten/src/ATen/cuda/CUDAGraph.h b/aten/src/ATen/cuda/CUDAGraph.h index 4f2aa31dd1c35..c8cae16b624fe 100644 --- a/aten/src/ATen/cuda/CUDAGraph.h +++ b/aten/src/ATen/cuda/CUDAGraph.h @@ -2,7 +2,6 @@ #include #include -#include #include #include #include diff --git a/aten/src/ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h b/aten/src/ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h index c1ecea34db16f..39ab441478e8f 100644 --- a/aten/src/ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h +++ b/aten/src/ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include // Use of c10::hip namespace here makes hipification easier, because @@ -10,10 +10,10 @@ namespace c10::hip { // Takes a valid HIPAllocator (of any sort) and turns it into // an allocator pretending to be a CUDA allocator. See // Note [Masquerading as CUDA] -class HIPAllocatorMasqueradingAsCUDA final : public DeviceAllocator { - DeviceAllocator* allocator_; +class HIPAllocatorMasqueradingAsCUDA final : public Allocator { + Allocator* allocator_; public: - explicit HIPAllocatorMasqueradingAsCUDA(DeviceAllocator* allocator) + explicit HIPAllocatorMasqueradingAsCUDA(Allocator* allocator) : allocator_(allocator) {} DataPtr allocate(size_t size) override { DataPtr r = allocator_->allocate(size); @@ -26,24 +26,6 @@ class HIPAllocatorMasqueradingAsCUDA final : public DeviceAllocator { void copy_data(void* dest, const void* src, std::size_t count) const final { allocator_->copy_data(dest, src, count); } - bool initialized() override { - return allocator_->initialized(); - } - void emptyCache(MempoolId_t mempool_id = {0, 0}) { - allocator_->emptyCache(mempool_id); - } - void recordStream(const DataPtr& ptr, c10::Stream stream) { - allocator_->recordStream(ptr, stream); - } - CachingDeviceAllocator::DeviceStats getDeviceStats(c10::DeviceIndex device) { - return allocator_->getDeviceStats(device); - } - void resetAccumulatedStats(c10::DeviceIndex device) { - allocator_->resetAccumulatedStats(device); - } - void resetPeakStats(c10::DeviceIndex device) { - allocator_->resetPeakStats(device); - } }; } // namespace c10::hip diff --git a/aten/src/ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.cpp b/aten/src/ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.cpp index 19bc0a6b34e54..46f7d247293a1 100644 --- a/aten/src/ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.cpp +++ b/aten/src/ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.cpp @@ -4,9 +4,8 @@ namespace c10 { namespace hip { namespace HIPCachingAllocatorMasqueradingAsCUDA { -static HIPAllocatorMasqueradingAsCUDA allocator(HIPCachingAllocator::get()); - Allocator* get() { + static HIPAllocatorMasqueradingAsCUDA allocator(HIPCachingAllocator::get()); return &allocator; } @@ -14,9 +13,5 @@ void recordStreamMasqueradingAsCUDA(const DataPtr& ptr, HIPStreamMasqueradingAsC HIPCachingAllocator::recordStream(ptr, stream.hip_stream()); } -// Register this HIP allocator as CUDA allocator to enable access through both -// c10::GetAllocator(kCUDA) and c10::getDeviceAllocator(kCUDA) APIs -REGISTER_ALLOCATOR(kCUDA, &allocator) - } // namespace HIPCachingAllocatorMasqueradingAsCUDA }} // namespace c10::hip diff --git a/c10/core/CachingDeviceAllocator.cpp b/c10/core/CachingDeviceAllocator.cpp deleted file mode 100644 index 582efd59cf1b1..0000000000000 --- a/c10/core/CachingDeviceAllocator.cpp +++ /dev/null @@ -1,10 +0,0 @@ -#include - -namespace c10 { - -// Ensures proper DLL export of this pure virtual base class on Windows, -// since it's mainly used in other DLLs outside c10.dll. -DeviceAllocator::DeviceAllocator() = default; -DeviceAllocator::~DeviceAllocator() = default; - -} // namespace c10 diff --git a/c10/core/CachingDeviceAllocator.h b/c10/core/CachingDeviceAllocator.h index 0bec03ae417fa..b23490de693a8 100644 --- a/c10/core/CachingDeviceAllocator.h +++ b/c10/core/CachingDeviceAllocator.h @@ -1,7 +1,6 @@ #pragma once #include -#include namespace c10::CachingDeviceAllocator { @@ -60,55 +59,3 @@ struct DeviceStats { }; } // namespace c10::CachingDeviceAllocator - -namespace c10 { - -using CaptureId_t = unsigned long long; - -// first is set if the instance is created by Graph mode capture_begin. -// second is set if the instance is created by Graph mode graph_pool_handle. -using MempoolId_t = std::pair; - -struct C10_API DeviceAllocator : public c10::Allocator { - DeviceAllocator(); - ~DeviceAllocator() override; - - // Returns true if the allocator has been properly initialized and is ready - // for use - virtual bool initialized() = 0; - - // Releases all cached device memory from the specified memory pool back to - // the system - virtual void emptyCache(MempoolId_t mempool_id = {0, 0}) = 0; - - // Associates a memory allocation with a stream to establish dependency - // tracking. Prevents memory reuse until all operations on the specified - // stream complete - virtual void recordStream(const DataPtr& ptr, c10::Stream stream) = 0; - - // Retrieves comprehensive memory statistics for the specified device, - // including allocation patterns, usage metrics - virtual CachingDeviceAllocator::DeviceStats getDeviceStats( - c10::DeviceIndex device) = 0; - - // Resets cumulative allocation statistics for the specified device to zero - virtual void resetAccumulatedStats(c10::DeviceIndex device) = 0; - - // Resets peak memory usage statistics for the specified device - virtual void resetPeakStats(c10::DeviceIndex device) = 0; -}; - -// This function is used to get the DeviceAllocator for a specific device type -// and keep backward compatibility with c10::GetAllocator. -C10_API inline DeviceAllocator* getDeviceAllocator(const DeviceType& t) { - TORCH_CHECK( - t != DeviceType::CPU, - "getDeviceAllocator is not supported for CPU device type."); - auto* allocator = c10::GetAllocator(t); - auto* device_allocator = dynamic_cast(allocator); - TORCH_INTERNAL_ASSERT( - device_allocator, "Allocator for ", t, " is not a DeviceAllocator."); - return device_allocator; -} - -} // namespace c10 diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp index 91ea6d9d9bd4d..4d58c11c5c9bc 100644 --- a/c10/cuda/CUDACachingAllocator.cpp +++ b/c10/cuda/CUDACachingAllocator.cpp @@ -4179,7 +4179,6 @@ struct BackendStaticInitializer { BackendStaticInitializer() { auto r = parseEnvForBackend(); - at::SetAllocator(kCUDA, r, 0); allocator.store(r); } }; diff --git a/c10/cuda/CUDACachingAllocator.h b/c10/cuda/CUDACachingAllocator.h index 5e412342b17d0..a6fa61110d675 100644 --- a/c10/cuda/CUDACachingAllocator.h +++ b/c10/cuda/CUDACachingAllocator.h @@ -202,24 +202,25 @@ struct ShareableHandle { std::string handle; }; -class CUDAAllocator : public DeviceAllocator { +class CUDAAllocator : public Allocator { public: virtual void* raw_alloc(size_t nbytes) = 0; virtual void* raw_alloc_with_stream(size_t nbytes, cudaStream_t stream) = 0; virtual void raw_delete(void* ptr) = 0; virtual void init(int device_count) = 0; + virtual bool initialized() = 0; virtual double getMemoryFraction(c10::DeviceIndex device) = 0; virtual void setMemoryFraction(double fraction, c10::DeviceIndex device) = 0; + virtual void emptyCache(MempoolId_t mempool_id = {0, 0}) = 0; virtual void enable(bool value) = 0; virtual bool isEnabled() const = 0; virtual void cacheInfo(c10::DeviceIndex device, size_t* largestBlock) = 0; virtual void* getBaseAllocation(void* ptr, size_t* size) = 0; - // Keep for BC only - virtual void recordStream(const DataPtr& ptr, CUDAStream stream) = 0; - void recordStream(const DataPtr& ptr, c10::Stream stream) override { - CUDAStream cuda_stream = CUDAStream(stream); - recordStream(ptr, cuda_stream); - } + virtual void recordStream(const DataPtr&, CUDAStream stream) = 0; + virtual c10::CachingDeviceAllocator::DeviceStats getDeviceStats( + c10::DeviceIndex device) = 0; + virtual void resetAccumulatedStats(c10::DeviceIndex device) = 0; + virtual void resetPeakStats(c10::DeviceIndex device) = 0; virtual SnapshotInfo snapshot(MempoolId_t mempool_id = {0, 0}) = 0; virtual void beginAllocateToPool( c10::DeviceIndex device, @@ -524,10 +525,6 @@ inline void enablePeerAccess( namespace c10::cuda { -// Keep BC only -using c10::CaptureId_t; -using c10::MempoolId_t; - // MemPool represents a pool of memory in a caching allocator. Currently, // it's just the ID of the pool object maintained in the CUDACachingAllocator. // diff --git a/c10/cuda/CUDAGraphsC10Utils.h b/c10/cuda/CUDAGraphsC10Utils.h index 936875fd71d5c..eb29ca8bc9f02 100644 --- a/c10/cuda/CUDAGraphsC10Utils.h +++ b/c10/cuda/CUDAGraphsC10Utils.h @@ -9,6 +9,12 @@ namespace c10::cuda { +using CaptureId_t = unsigned long long; + +// first is set if the instance is created by CUDAGraph::capture_begin. +// second is set if the instance is created by at::cuda::graph_pool_handle. +using MempoolId_t = std::pair; + // RAII guard for "cudaStreamCaptureMode", a thread-local value // that controls the error-checking strictness of a capture. struct C10_CUDA_API CUDAStreamCaptureModeGuard { diff --git a/c10/xpu/XPUCachingAllocator.cpp b/c10/xpu/XPUCachingAllocator.cpp index a5e088515ff55..543b48f081135 100644 --- a/c10/xpu/XPUCachingAllocator.cpp +++ b/c10/xpu/XPUCachingAllocator.cpp @@ -540,7 +540,7 @@ class DeviceCachingAllocator { static void local_raw_delete(void* ptr); -class XPUAllocator : public DeviceAllocator { +class XPUAllocator : public Allocator { private: std::mutex mutex; ska::flat_hash_map allocated_blocks; @@ -576,10 +576,6 @@ class XPUAllocator : public DeviceAllocator { } } - bool initialized() override { - return !device_allocators.empty(); - } - void malloc( void** devPtr, DeviceIndex device, @@ -614,13 +610,13 @@ class XPUAllocator : public DeviceAllocator { } } - void emptyCache(MempoolId_t mempool_id [[maybe_unused]] = {0, 0}) override { + void emptyCache() { for (auto& da : device_allocators) { da->emptyCache(); } } - void recordStream(const DataPtr& ptr, c10::Stream stream) override { + void recordStream(const DataPtr& ptr, XPUStream stream) { if (!ptr.get()) { return; } @@ -630,8 +626,7 @@ class XPUAllocator : public DeviceAllocator { Block* block = get_allocated_block(ptr.get()); TORCH_CHECK(block, "No allocated block can be found."); - c10::xpu::XPUStream xpu_stream{stream}; - device_allocators[block->device]->recordStream(block, xpu_stream); + device_allocators[block->device]->recordStream(block, stream); } DataPtr allocate(size_t size) override { @@ -684,17 +679,17 @@ class XPUAllocator : public DeviceAllocator { ": did you call init?"); } - DeviceStats getDeviceStats(DeviceIndex device) override { + DeviceStats getDeviceStats(DeviceIndex device) { assertValidDevice(device); return device_allocators[device]->getStats(); } - void resetPeakStats(DeviceIndex device) override { + void resetPeakStats(DeviceIndex device) { assertValidDevice(device); device_allocators[device]->resetPeakStats(); } - void resetAccumulatedStats(DeviceIndex device) override { + void resetAccumulatedStats(DeviceIndex device) { assertValidDevice(device); device_allocators[device]->resetAccumulatedStats(); } From abe0c9538a1abcb0528ac2107bd3ac5de628be89 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Mon, 21 Jul 2025 17:06:50 -0700 Subject: [PATCH 371/457] [BE] Fix extra-semi warnings (#158730) And prevent new ones from appearing by removing `-Wno-error=extra-semi` (not sure what was thereason behind adding the warning but not erroring on on it when building with -Werror introduced by https://github.com/pytorch/pytorch/pull/140236 ) 300+ violations of that rule were fixed by running `sed -i -e "s/});/})/" /` against `torch/nativert` Other 3p deps that needs updates: - TensorPipe - LLVM - FBGEMM Pull Request resolved: https://github.com/pytorch/pytorch/pull/158730 Approved by: https://github.com/Skylion007 --- aten/src/ATen/Context.cpp | 2 + aten/src/ATen/cuda/CachingHostAllocator.cpp | 2 +- aten/src/ATen/native/Copy.cpp | 2 + aten/src/ATen/native/EmbeddingBag.cpp | 2 + aten/src/ATen/native/QuantizedLinear.cpp | 2 + .../ao_sparse/quantized/cpu/fbgemm_utils.h | 2 + aten/src/ATen/native/cpu/utils.h | 2 + .../ATen/native/quantized/cpu/fbgemm_utils.h | 2 + .../cuda/flash_attn/flash_api.cpp | 2 + cmake/public/utils.cmake | 2 +- torch/csrc/PyInterpreterHooks.cpp | 20 + .../csrc/distributed/rpc/tensorpipe_agent.cpp | 2 + .../csrc/distributed/rpc/tensorpipe_cuda.cpp | 2 + .../csrc/distributed/rpc/tensorpipe_utils.cpp | 2 + torch/csrc/jit/tensorexpr/llvm_codegen.cpp | 6 +- torch/csrc/jit/tensorexpr/llvm_jit.cpp | 2 + torch/csrc/jit/tensorexpr/llvm_jit.h | 2 + .../GeneratedNativeStaticDispatchKernels.cpp | 72 +-- .../GeneratedStaticDispatchKernels.cpp | 472 +++++++++--------- torch/nativert/kernels/KernelRegistry.cpp | 120 ++--- torch/nativert/kernels/NativeKernels.cpp | 22 +- torch/nativert/kernels/PrimKernelRegistry.cpp | 8 +- torch/nativert/kernels/PrimKernelRegistry.h | 2 +- 23 files changed, 399 insertions(+), 353 deletions(-) create mode 100644 torch/csrc/PyInterpreterHooks.cpp diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp index 8c84779f472d7..ded7743c4d860 100644 --- a/aten/src/ATen/Context.cpp +++ b/aten/src/ATen/Context.cpp @@ -14,7 +14,9 @@ #include #ifdef USE_FBGEMM +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wextra-semi") #include +C10_DIAGNOSTIC_POP() #endif // USE_FBGEMM #if defined(__aarch64__) && !defined(C10_MOBILE) #include diff --git a/aten/src/ATen/cuda/CachingHostAllocator.cpp b/aten/src/ATen/cuda/CachingHostAllocator.cpp index 6a80342e10240..34aa15d0c06cf 100644 --- a/aten/src/ATen/cuda/CachingHostAllocator.cpp +++ b/aten/src/ATen/cuda/CachingHostAllocator.cpp @@ -258,7 +258,7 @@ DECLARE_HOST_ALLOCATOR( CUDACachingHostAllocator, CUDACachingHostAllocatorImpl, raw_local_deleter, - caching_host_allocator); + caching_host_allocator) REGISTER_HOST_ALLOCATOR(at::kCUDA, &caching_host_allocator) diff --git a/aten/src/ATen/native/Copy.cpp b/aten/src/ATen/native/Copy.cpp index 4cd46f3b00285..3d388194ea49d 100644 --- a/aten/src/ATen/native/Copy.cpp +++ b/aten/src/ATen/native/Copy.cpp @@ -36,8 +36,10 @@ #endif #ifdef USE_FBGEMM +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wextra-semi") #include #include +C10_DIAGNOSTIC_POP() #endif namespace { diff --git a/aten/src/ATen/native/EmbeddingBag.cpp b/aten/src/ATen/native/EmbeddingBag.cpp index a38730b3388d9..150970edc5076 100644 --- a/aten/src/ATen/native/EmbeddingBag.cpp +++ b/aten/src/ATen/native/EmbeddingBag.cpp @@ -14,8 +14,10 @@ #include #ifdef USE_FBGEMM +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wextra-semi") #include #include +C10_DIAGNOSTIC_POP() #else #include #endif diff --git a/aten/src/ATen/native/QuantizedLinear.cpp b/aten/src/ATen/native/QuantizedLinear.cpp index d6f1d462b6b99..f4fdd395f013a 100644 --- a/aten/src/ATen/native/QuantizedLinear.cpp +++ b/aten/src/ATen/native/QuantizedLinear.cpp @@ -25,9 +25,11 @@ #include #ifdef USE_FBGEMM +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wextra-semi") #include #include #include +C10_DIAGNOSTIC_POP() #endif // USE_FBGEMM namespace caffe2 { diff --git a/aten/src/ATen/native/ao_sparse/quantized/cpu/fbgemm_utils.h b/aten/src/ATen/native/ao_sparse/quantized/cpu/fbgemm_utils.h index 1d0215fbfc5dc..9a122cd7cf05e 100644 --- a/aten/src/ATen/native/ao_sparse/quantized/cpu/fbgemm_utils.h +++ b/aten/src/ATen/native/ao_sparse/quantized/cpu/fbgemm_utils.h @@ -4,9 +4,11 @@ #include #ifdef USE_FBGEMM +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wextra-semi") #include #include #include +C10_DIAGNOSTIC_POP() namespace ao::sparse { diff --git a/aten/src/ATen/native/cpu/utils.h b/aten/src/ATen/native/cpu/utils.h index e1c7e5c607477..827c69629eb37 100644 --- a/aten/src/ATen/native/cpu/utils.h +++ b/aten/src/ATen/native/cpu/utils.h @@ -6,7 +6,9 @@ #include #ifdef USE_FBGEMM +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wextra-semi") #include +C10_DIAGNOSTIC_POP() #endif namespace at::native { diff --git a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h index e6d86cf03df13..7dc9a93365e3e 100644 --- a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h +++ b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h @@ -7,11 +7,13 @@ #include #ifdef USE_FBGEMM +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wextra-semi") #include C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Winconsistent-missing-destructor-override") #include C10_DIAGNOSTIC_POP() #include +C10_DIAGNOSTIC_POP() // The struct for the packed weight matrix (PackBMatrix) and the corresponding // column offsets used for the fully connect layer, which are both prepared in diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp index 68451ba5ffcc8..a4e37da1a4ae9 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp @@ -32,7 +32,9 @@ #endif +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wextra-semi") #include +C10_DIAGNOSTIC_POP() #include diff --git a/cmake/public/utils.cmake b/cmake/public/utils.cmake index 0f34603cf0231..032db8a8ab5c7 100644 --- a/cmake/public/utils.cmake +++ b/cmake/public/utils.cmake @@ -394,7 +394,7 @@ function(torch_compile_options libname) list(APPEND private_compile_options -Wredundant-move) endif() if(CMAKE_CXX_COMPILER_ID MATCHES "Clang") - list(APPEND private_compile_options -Wextra-semi -Wno-error=extra-semi -Wmove) + list(APPEND private_compile_options -Wextra-semi -Wmove) else() list(APPEND private_compile_options # Considered to be flaky. See the discussion at diff --git a/torch/csrc/PyInterpreterHooks.cpp b/torch/csrc/PyInterpreterHooks.cpp new file mode 100644 index 0000000000000..5e064493fd595 --- /dev/null +++ b/torch/csrc/PyInterpreterHooks.cpp @@ -0,0 +1,20 @@ +#include +#include + +namespace torch::detail { + +PyInterpreterHooks::PyInterpreterHooks(c10::impl::PyInterpreterHooksArgs) {} + +c10::impl::PyInterpreter* PyInterpreterHooks::getPyInterpreter() const { + // Delegate to the existing implementation + return ::getPyInterpreter(); +} + +} // namespace torch::detail + +// Sigh, the registry doesn't support namespaces :( +using c10::impl::PyInterpreterHooksRegistry; +using c10::impl::RegistererPyInterpreterHooksRegistry; +using PyInterpreterHooks = torch::detail::PyInterpreterHooks; +// Register the implementation +REGISTER_PYTHON_HOOKS(PyInterpreterHooks) diff --git a/torch/csrc/distributed/rpc/tensorpipe_agent.cpp b/torch/csrc/distributed/rpc/tensorpipe_agent.cpp index c25e83c07c6db..1907520702503 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_agent.cpp +++ b/torch/csrc/distributed/rpc/tensorpipe_agent.cpp @@ -8,8 +8,10 @@ #include C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wdeprecated") +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wextra-semi") #include C10_DIAGNOSTIC_POP() +C10_DIAGNOSTIC_POP() #include #include diff --git a/torch/csrc/distributed/rpc/tensorpipe_cuda.cpp b/torch/csrc/distributed/rpc/tensorpipe_cuda.cpp index 4c326b6a0e276..03b43184d143b 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_cuda.cpp +++ b/torch/csrc/distributed/rpc/tensorpipe_cuda.cpp @@ -7,10 +7,12 @@ #include #include +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wextra-semi") C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wdeprecated") #include #include C10_DIAGNOSTIC_POP() +C10_DIAGNOSTIC_POP() namespace torch::distributed::rpc { namespace { diff --git a/torch/csrc/distributed/rpc/tensorpipe_utils.cpp b/torch/csrc/distributed/rpc/tensorpipe_utils.cpp index 86308ae6cdf35..f28aefc06dee0 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_utils.cpp +++ b/torch/csrc/distributed/rpc/tensorpipe_utils.cpp @@ -6,8 +6,10 @@ #include C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wdeprecated") +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wextra-semi") #include C10_DIAGNOSTIC_POP() +C10_DIAGNOSTIC_POP() namespace torch::distributed::rpc { namespace { diff --git a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp index d6c5590a71003..918d82579444f 100644 --- a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp +++ b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp @@ -83,7 +83,7 @@ using namespace torch::jit::tensorexpr; C10_DEFINE_bool( torch_jit_llvm_use_fast_intrinsics, false, - "Use fast (but slightly less accurate) implementations of tanh and sigmoid"); + "Use fast (but slightly less accurate) implementations of tanh and sigmoid") namespace torch::jit::tensorexpr { @@ -246,7 +246,7 @@ class LLVMCodeGenImpl : public IRVisitor { std::string kernel_func_name_; #define LLVM_TYPE_DECLARE(_1, Name) llvm::Type* Name##Ty_; - AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, LLVM_TYPE_DECLARE); + AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, LLVM_TYPE_DECLARE) #undef LLVM_TYPE_DECLARE #if LLVM_VERSION_MAJOR >= 15 @@ -1101,7 +1101,7 @@ std::enable_if_t, llvm::Value*> getFromType( void LLVMCodeGenImpl::visit(const Name##ImmPtr& v) { \ value_ = getFromType(Name##Ty_, v->value()); \ } -AT_FORALL_SCALAR_TYPES(IMM_VISIT_DECLARE); +AT_FORALL_SCALAR_TYPES(IMM_VISIT_DECLARE) #undef IMM_VISIT_DECLARE void LLVMCodeGenImpl::visit(const HalfImmPtr& v) { diff --git a/torch/csrc/jit/tensorexpr/llvm_jit.cpp b/torch/csrc/jit/tensorexpr/llvm_jit.cpp index c9a930576cdca..80d919a5674e6 100644 --- a/torch/csrc/jit/tensorexpr/llvm_jit.cpp +++ b/torch/csrc/jit/tensorexpr/llvm_jit.cpp @@ -11,6 +11,7 @@ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wsuggest-override") #include C10_DIAGNOSTIC_POP() +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wextra-semi") #include #include #include @@ -35,6 +36,7 @@ C10_DIAGNOSTIC_POP() #endif #include #include +C10_DIAGNOSTIC_POP() #include diff --git a/torch/csrc/jit/tensorexpr/llvm_jit.h b/torch/csrc/jit/tensorexpr/llvm_jit.h index beadbdd5e537e..19a21329b64a7 100644 --- a/torch/csrc/jit/tensorexpr/llvm_jit.h +++ b/torch/csrc/jit/tensorexpr/llvm_jit.h @@ -9,9 +9,11 @@ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wsuggest-override") #include C10_DIAGNOSTIC_POP() +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wextra-semi") #include #include #include +C10_DIAGNOSTIC_POP() #include #include diff --git a/torch/nativert/kernels/GeneratedNativeStaticDispatchKernels.cpp b/torch/nativert/kernels/GeneratedNativeStaticDispatchKernels.cpp index c33fb81604f6f..e8d7170fdf1cd 100644 --- a/torch/nativert/kernels/GeneratedNativeStaticDispatchKernels.cpp +++ b/torch/nativert/kernels/GeneratedNativeStaticDispatchKernels.cpp @@ -39,7 +39,7 @@ REGISTER_NATIVE_CPU_KERNEL( const auto& self = KernelInput(0).toTensor(); KernelOutput(0) = at::native::view_as_real(self); return; - }); + }) REGISTER_NATIVE_CPU_KERNEL( "torch.ops.aten.view_as_complex.default", @@ -48,31 +48,31 @@ REGISTER_NATIVE_CPU_KERNEL( const auto& self = KernelInput(0).toTensor(); KernelOutput(0) = at::native::view_as_complex(self); return; - }); + }) REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.real.default", aten_real_default, { const auto& self = KernelInput(0).toTensor(); KernelOutput(0) = at::native::real(self); return; -}); +}) REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.imag.default", aten_imag_default, { const auto& self = KernelInput(0).toTensor(); KernelOutput(0) = at::native::imag(self); return; -}); +}) REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten._conj.default", aten__conj_default, { const auto& self = KernelInput(0).toTensor(); KernelOutput(0) = at::native::_conj(self); return; -}); +}) REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.conj.default", aten_conj_default, { const auto& self = KernelInput(0).toTensor(); KernelOutput(0) = at::native::conj(self); return; -}); +}) REGISTER_NATIVE_CPU_KERNEL( "torch.ops.aten.resolve_conj.default", @@ -81,7 +81,7 @@ REGISTER_NATIVE_CPU_KERNEL( const auto& self = KernelInput(0).toTensor(); KernelOutput(0) = at::native::resolve_conj(self); return; - }); + }) REGISTER_NATIVE_CPU_KERNEL( "torch.ops.aten.resolve_neg.default", @@ -90,7 +90,7 @@ REGISTER_NATIVE_CPU_KERNEL( const auto& self = KernelInput(0).toTensor(); KernelOutput(0) = at::native::resolve_neg(self); return; - }); + }) REGISTER_NATIVE_CPU_KERNEL( "torch.ops.aten._neg_view.default", @@ -99,7 +99,7 @@ REGISTER_NATIVE_CPU_KERNEL( const auto& self = KernelInput(0).toTensor(); KernelOutput(0) = at::native::_neg_view(self); return; - }); + }) REGISTER_NATIVE_CPU_KERNEL( "torch.ops.aten.diagonal.default", @@ -111,7 +111,7 @@ REGISTER_NATIVE_CPU_KERNEL( const auto dim2 = KernelInput(3).toInt(); KernelOutput(0) = at::native::diagonal(self, offset, dim1, dim2); return; - }); + }) REGISTER_NATIVE_CPU_KERNEL( "torch.ops.aten.linalg_diagonal.default", @@ -123,7 +123,7 @@ REGISTER_NATIVE_CPU_KERNEL( const auto dim2 = KernelInput(3).toInt(); KernelOutput(0) = at::native::linalg_diagonal(A, offset, dim1, dim2); return; - }); + }) REGISTER_NATIVE_CPU_KERNEL( "torch.ops.aten.expand_as.default", @@ -133,7 +133,7 @@ REGISTER_NATIVE_CPU_KERNEL( const auto& other = KernelInput(1).toTensor(); KernelOutput(0) = at::native::expand_as(self, other); return; - }); + }) REGISTER_NATIVE_CPU_KERNEL( "torch.ops.aten.flatten.using_ints", @@ -144,7 +144,7 @@ REGISTER_NATIVE_CPU_KERNEL( const auto end_dim = KernelInput(2).toInt(); KernelOutput(0) = at::native::flatten(self, start_dim, end_dim); return; - }); + }) REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.movedim.int", aten_movedim_int, { const auto& self = KernelInput(0).toTensor(); @@ -152,7 +152,7 @@ REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.movedim.int", aten_movedim_int, { const auto destination = KernelInput(2).toInt(); KernelOutput(0) = at::native::movedim(self, source, destination); return; -}); +}) REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.moveaxis.int", aten_moveaxis_int, { const auto& self = KernelInput(0).toTensor(); @@ -160,7 +160,7 @@ REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.moveaxis.int", aten_moveaxis_int, { const auto destination = KernelInput(2).toInt(); KernelOutput(0) = at::native::moveaxis(self, source, destination); return; -}); +}) REGISTER_NATIVE_CPU_KERNEL( "torch.ops.aten.numpy_T.default", @@ -169,7 +169,7 @@ REGISTER_NATIVE_CPU_KERNEL( const auto& self = KernelInput(0).toTensor(); KernelOutput(0) = at::native::numpy_T(self); return; - }); + }) REGISTER_NATIVE_CPU_KERNEL( "torch.ops.aten.matrix_H.default", @@ -178,19 +178,19 @@ REGISTER_NATIVE_CPU_KERNEL( const auto& self = KernelInput(0).toTensor(); KernelOutput(0) = at::native::matrix_H(self); return; - }); + }) REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.mT.default", aten_mT_default, { const auto& self = KernelInput(0).toTensor(); KernelOutput(0) = at::native::mT(self); return; -}); +}) REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.mH.default", aten_mH_default, { const auto& self = KernelInput(0).toTensor(); KernelOutput(0) = at::native::mH(self); return; -}); +}) REGISTER_NATIVE_CPU_KERNEL( "torch.ops.aten.adjoint.default", @@ -199,13 +199,13 @@ REGISTER_NATIVE_CPU_KERNEL( const auto& self = KernelInput(0).toTensor(); KernelOutput(0) = at::native::adjoint(self); return; - }); + }) REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.ravel.default", aten_ravel_default, { const auto& self = KernelInput(0).toTensor(); KernelOutput(0) = at::native::ravel(self); return; -}); +}) REGISTER_NATIVE_CPU_KERNEL( "torch.ops.aten.reshape_as.default", @@ -215,7 +215,7 @@ REGISTER_NATIVE_CPU_KERNEL( const auto& other = KernelInput(1).toTensor(); KernelOutput(0) = at::native::reshape_as(self, other); return; - }); + }) REGISTER_NATIVE_CPU_KERNEL( "torch.ops.aten.detach.default", @@ -224,7 +224,7 @@ REGISTER_NATIVE_CPU_KERNEL( const auto& self = KernelInput(0).toTensor(); KernelOutput(0) = at::native::detach(self); return; - }); + }) REGISTER_NATIVE_CPU_KERNEL( "torch.ops.aten.squeeze.default", @@ -233,20 +233,20 @@ REGISTER_NATIVE_CPU_KERNEL( const auto& self = KernelInput(0).toTensor(); KernelOutput(0) = at::native::squeeze(self); return; - }); + }) REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.squeeze.dim", aten_squeeze_dim, { const auto& self = KernelInput(0).toTensor(); const auto dim = KernelInput(1).toInt(); KernelOutput(0) = at::native::squeeze(self, dim); return; -}); +}) REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.t.default", aten_t_default, { const auto& self = KernelInput(0).toTensor(); KernelOutput(0) = at::native::t(self); return; -}); +}) REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.transpose.int", aten_transpose_int, { const auto& self = KernelInput(0).toTensor(); @@ -254,7 +254,7 @@ REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.transpose.int", aten_transpose_int, { const auto dim1 = KernelInput(2).toInt(); KernelOutput(0) = at::native::transpose(self, dim0, dim1); return; -}); +}) REGISTER_NATIVE_CPU_KERNEL( "torch.ops.aten.unsqueeze.default", @@ -264,7 +264,7 @@ REGISTER_NATIVE_CPU_KERNEL( const auto dim = KernelInput(1).toInt(); KernelOutput(0) = at::native::unsqueeze(self, dim); return; - }); + }) REGISTER_NATIVE_CPU_KERNEL( "torch.ops.aten.view_as.default", @@ -274,7 +274,7 @@ REGISTER_NATIVE_CPU_KERNEL( const auto& other = KernelInput(1).toTensor(); KernelOutput(0) = at::native::view_as(self, other); return; - }); + }) REGISTER_NATIVE_CPU_KERNEL( "torch.ops.aten.positive.default", @@ -283,7 +283,7 @@ REGISTER_NATIVE_CPU_KERNEL( const auto& self = KernelInput(0).toTensor(); KernelOutput(0) = at::native::positive(self); return; - }); + }) REGISTER_NATIVE_CPU_KERNEL( "torch.ops.aten._autocast_to_reduced_precision.default", @@ -297,7 +297,7 @@ REGISTER_NATIVE_CPU_KERNEL( KernelOutput(0) = at::native::_autocast_to_reduced_precision( self, cuda_enabled, cpu_enabled, cuda_dtype, cpu_dtype); return; - }); + }) REGISTER_NATIVE_CPU_KERNEL( "torch.ops.aten._autocast_to_full_precision.default", @@ -309,7 +309,7 @@ REGISTER_NATIVE_CPU_KERNEL( KernelOutput(0) = at::native::_autocast_to_full_precision( self, cuda_enabled, cpu_enabled); return; - }); + }) REGISTER_NATIVE_CPU_KERNEL( "torch.ops.aten.swapaxes.default", @@ -320,7 +320,7 @@ REGISTER_NATIVE_CPU_KERNEL( const auto axis1 = KernelInput(2).toInt(); KernelOutput(0) = at::native::swapaxes(self, axis0, axis1); return; - }); + }) REGISTER_NATIVE_CPU_KERNEL( "torch.ops.aten.swapdims.default", @@ -331,7 +331,7 @@ REGISTER_NATIVE_CPU_KERNEL( const auto dim1 = KernelInput(2).toInt(); KernelOutput(0) = at::native::swapdims(self, dim0, dim1); return; - }); + }) REGISTER_NATIVE_CPU_KERNEL( "torch.ops.aten.unfold.default", @@ -343,12 +343,12 @@ REGISTER_NATIVE_CPU_KERNEL( const auto step = KernelInput(3).toInt(); KernelOutput(0) = at::native::unfold(self, dimension, size, step); return; - }); + }) REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.alias.default", aten_alias_default, { const auto& self = KernelInput(0).toTensor(); KernelOutput(0) = at::native::alias(self); return; -}); +}) } // namespace torch::nativert diff --git a/torch/nativert/kernels/GeneratedStaticDispatchKernels.cpp b/torch/nativert/kernels/GeneratedStaticDispatchKernels.cpp index 986eb060cb0fb..f919639f48def 100644 --- a/torch/nativert/kernels/GeneratedStaticDispatchKernels.cpp +++ b/torch/nativert/kernels/GeneratedStaticDispatchKernels.cpp @@ -41,7 +41,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.absolute.default", aten_absolute_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::absolute_out(self, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.angle.default", aten_angle_default, { const auto& self = KernelInput(0).toTensor(); @@ -52,7 +52,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.angle.default", aten_angle_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::angle_out(self, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.sgn.default", aten_sgn_default, { const auto& self = KernelInput(0).toTensor(); @@ -63,7 +63,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.sgn.default", aten_sgn_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::sgn_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.acos.default", aten_acos_default, { const auto& self = KernelInput(0).toTensor(); @@ -74,7 +74,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.acos.default", aten_acos_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::acos_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.arccos.default", aten_arccos_default, { const auto& self = KernelInput(0).toTensor(); @@ -85,7 +85,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.arccos.default", aten_arccos_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::arccos_out(self, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.add.Tensor", aten_add_Tensor, { const auto& self = KernelInput(0).toTensor(); @@ -98,7 +98,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.add.Tensor", aten_add_Tensor, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::add_out(out, self, other, alpha); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.add.Scalar", aten_add_Scalar, { const auto& self = KernelInput(0).toTensor(); @@ -110,7 +110,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.add.Scalar", aten_add_Scalar, { auto& out_t = KernelOutput(0).toTensor(); fastResizeToZero(out_t); at::add_out(out_t, self, other, alpha); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten._add_relu.Tensor", aten__add_relu_Tensor, { const auto& self = KernelInput(0).toTensor(); @@ -123,7 +123,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten._add_relu.Tensor", aten__add_relu_Tensor, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::add_relu_out(self, other, alpha, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.addmv.default", aten_addmv_default, { const auto& self = KernelInput(0).toTensor(); @@ -138,7 +138,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.addmv.default", aten_addmv_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::addmv_out(out, self, mat, vec, beta, alpha); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.addr.default", aten_addr_default, { const auto& self = KernelInput(0).toTensor(); @@ -153,7 +153,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.addr.default", aten_addr_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::addr_out(self, vec1, vec2, beta, alpha, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.all.dim", aten_all_dim, { const auto& self = KernelInput(0).toTensor(); @@ -166,7 +166,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.all.dim", aten_all_dim, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::all_out(out, self, dim, keepdim); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.any.dim", aten_any_dim, { const auto& self = KernelInput(0).toTensor(); @@ -179,7 +179,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.any.dim", aten_any_dim, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::any_out(out, self, dim, keepdim); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.argmax.default", aten_argmax_default, { const auto& self = KernelInput(0).toTensor(); @@ -192,7 +192,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.argmax.default", aten_argmax_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::argmax_out(out, self, dim, keepdim); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.acosh.default", aten_acosh_default, { const auto& self = KernelInput(0).toTensor(); @@ -203,7 +203,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.acosh.default", aten_acosh_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::acosh_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.asinh.default", aten_asinh_default, { const auto& self = KernelInput(0).toTensor(); @@ -214,7 +214,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.asinh.default", aten_asinh_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::asinh_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.arcsinh.default", aten_arcsinh_default, { const auto& self = KernelInput(0).toTensor(); @@ -225,7 +225,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.arcsinh.default", aten_arcsinh_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::arcsinh_out(self, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.atanh.default", aten_atanh_default, { const auto& self = KernelInput(0).toTensor(); @@ -236,7 +236,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.atanh.default", aten_atanh_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::atanh_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.arctanh.default", aten_arctanh_default, { const auto& self = KernelInput(0).toTensor(); @@ -247,7 +247,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.arctanh.default", aten_arctanh_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::arctanh_out(self, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.asin.default", aten_asin_default, { const auto& self = KernelInput(0).toTensor(); @@ -258,7 +258,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.asin.default", aten_asin_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::asin_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.arcsin.default", aten_arcsin_default, { const auto& self = KernelInput(0).toTensor(); @@ -269,7 +269,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.arcsin.default", aten_arcsin_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::arcsin_out(self, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.atan.default", aten_atan_default, { const auto& self = KernelInput(0).toTensor(); @@ -280,7 +280,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.atan.default", aten_atan_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::atan_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.arctan.default", aten_arctan_default, { const auto& self = KernelInput(0).toTensor(); @@ -291,7 +291,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.arctan.default", aten_arctan_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::arctan_out(self, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.baddbmm.default", aten_baddbmm_default, { const auto& self = KernelInput(0).toTensor(); @@ -306,7 +306,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.baddbmm.default", aten_baddbmm_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::baddbmm_out(out, self, batch1, batch2, beta, alpha); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.bitwise_not.default", @@ -320,7 +320,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::bitwise_not_out(out, self); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.copysign.Tensor", aten_copysign_Tensor, { const auto& self = KernelInput(0).toTensor(); @@ -332,7 +332,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.copysign.Tensor", aten_copysign_Tensor, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::copysign_out(out, self, other); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.logical_not.default", @@ -346,7 +346,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::logical_not_out(self, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.logical_xor.default", @@ -361,7 +361,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::logical_xor_out(self, other, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.logical_and.default", @@ -376,7 +376,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::logical_and_out(self, other, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.logical_or.default", @@ -391,7 +391,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::logical_or_out(self, other, out); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.ceil.default", aten_ceil_default, { const auto& self = KernelInput(0).toTensor(); @@ -402,7 +402,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.ceil.default", aten_ceil_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::ceil_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.clamp.default", aten_clamp_default, { const auto& self = KernelInput(0).toTensor(); @@ -415,7 +415,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.clamp.default", aten_clamp_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::clamp_out(out, self, min, max); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.clamp.Tensor", aten_clamp_Tensor, { const auto& self = KernelInput(0).toTensor(); @@ -428,7 +428,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.clamp.Tensor", aten_clamp_Tensor, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::clamp_out(out, self, min, max); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.clamp_max.default", @@ -443,7 +443,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::clamp_max_out(out, self, max); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.clamp_max.Tensor", aten_clamp_max_Tensor, { const auto& self = KernelInput(0).toTensor(); @@ -455,7 +455,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.clamp_max.Tensor", aten_clamp_max_Tensor, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::clamp_max_out(out, self, max); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.clip.default", aten_clip_default, { const auto& self = KernelInput(0).toTensor(); @@ -468,7 +468,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.clip.default", aten_clip_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::clip_out(self, min, max, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.complex.default", aten_complex_default, { const auto& real = KernelInput(0).toTensor(); @@ -480,7 +480,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.complex.default", aten_complex_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::complex_out(real, imag, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.polar.default", aten_polar_default, { const auto& abs = KernelInput(0).toTensor(); @@ -492,7 +492,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.polar.default", aten_polar_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::polar_out(abs, angle, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.cos.default", aten_cos_default, { const auto& self = KernelInput(0).toTensor(); @@ -503,7 +503,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.cos.default", aten_cos_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::cos_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.cosh.default", aten_cosh_default, { const auto& self = KernelInput(0).toTensor(); @@ -514,7 +514,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.cosh.default", aten_cosh_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::cosh_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.cumprod.default", aten_cumprod_default, { const auto& self = KernelInput(0).toTensor(); @@ -527,7 +527,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.cumprod.default", aten_cumprod_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::cumprod_out(out, self, dim, dtype); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.diff.default", aten_diff_default, { const auto& self = KernelInput(0).toTensor(); @@ -542,7 +542,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.diff.default", aten_diff_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::diff_out(self, n, dim, prepend, append, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.div.Tensor", aten_div_Tensor, { const auto& self = KernelInput(0).toTensor(); @@ -554,7 +554,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.div.Tensor", aten_div_Tensor, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::div_out(out, self, other); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.div.Tensor_mode", aten_div_Tensor_mode, { const auto& self = KernelInput(0).toTensor(); @@ -567,7 +567,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.div.Tensor_mode", aten_div_Tensor_mode, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::div_out(out, self, other, rounding_mode); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.divide.Tensor", aten_divide_Tensor, { const auto& self = KernelInput(0).toTensor(); @@ -579,7 +579,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.divide.Tensor", aten_divide_Tensor, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::divide_out(self, other, out); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.true_divide.Tensor", @@ -594,7 +594,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::true_divide_out(self, other, out); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.dot.default", aten_dot_default, { const auto& self = KernelInput(0).toTensor(); @@ -606,7 +606,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.dot.default", aten_dot_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::dot_out(self, tensor, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.vdot.default", aten_vdot_default, { const auto& self = KernelInput(0).toTensor(); @@ -618,7 +618,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.vdot.default", aten_vdot_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::vdot_out(self, other, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.erf.default", aten_erf_default, { const auto& self = KernelInput(0).toTensor(); @@ -629,7 +629,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.erf.default", aten_erf_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::erf_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.erfc.default", aten_erfc_default, { const auto& self = KernelInput(0).toTensor(); @@ -640,7 +640,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.erfc.default", aten_erfc_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::erfc_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.exp.default", aten_exp_default, { const auto& self = KernelInput(0).toTensor(); @@ -651,7 +651,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.exp.default", aten_exp_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::exp_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.exp2.default", aten_exp2_default, { const auto& self = KernelInput(0).toTensor(); @@ -662,7 +662,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.exp2.default", aten_exp2_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::exp2_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.expm1.default", aten_expm1_default, { const auto& self = KernelInput(0).toTensor(); @@ -673,7 +673,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.expm1.default", aten_expm1_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::expm1_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.floor.default", aten_floor_default, { const auto& self = KernelInput(0).toTensor(); @@ -684,7 +684,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.floor.default", aten_floor_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::floor_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.frac.default", aten_frac_default, { const auto& self = KernelInput(0).toTensor(); @@ -695,7 +695,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.frac.default", aten_frac_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::frac_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.gcd.default", aten_gcd_default, { const auto& self = KernelInput(0).toTensor(); @@ -707,7 +707,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.gcd.default", aten_gcd_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::gcd_out(out, self, other); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.lcm.default", aten_lcm_default, { const auto& self = KernelInput(0).toTensor(); @@ -719,7 +719,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.lcm.default", aten_lcm_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::lcm_out(out, self, other); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.index_copy.default", @@ -736,7 +736,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::index_copy_out(out, self, dim, index, source); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.isin.Tensor_Tensor", @@ -754,7 +754,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::isin_out(out, elements, test_elements, assume_unique, invert); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.isin.Tensor_Scalar", @@ -772,7 +772,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::isin_out(out, elements, test_element, assume_unique, invert); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.isin.Scalar_Tensor", @@ -790,7 +790,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::isin_out(out, element, test_elements, assume_unique, invert); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.kron.default", aten_kron_default, { const auto& self = KernelInput(0).toTensor(); @@ -802,7 +802,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.kron.default", aten_kron_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::kron_out(self, other, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.ldexp.Tensor", aten_ldexp_Tensor, { const auto& self = KernelInput(0).toTensor(); @@ -814,7 +814,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.ldexp.Tensor", aten_ldexp_Tensor, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::ldexp_out(self, other, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.log10.default", aten_log10_default, { const auto& self = KernelInput(0).toTensor(); @@ -825,7 +825,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.log10.default", aten_log10_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::log10_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.log1p.default", aten_log1p_default, { const auto& self = KernelInput(0).toTensor(); @@ -836,7 +836,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.log1p.default", aten_log1p_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::log1p_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.log2.default", aten_log2_default, { const auto& self = KernelInput(0).toTensor(); @@ -847,7 +847,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.log2.default", aten_log2_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::log2_out(out, self); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.logaddexp.default", @@ -862,7 +862,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::logaddexp_out(out, self, other); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.logaddexp2.default", @@ -877,7 +877,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::logaddexp2_out(out, self, other); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.xlogy.Tensor", aten_xlogy_Tensor, { const auto& self = KernelInput(0).toTensor(); @@ -889,7 +889,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.xlogy.Tensor", aten_xlogy_Tensor, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::xlogy_out(out, self, other); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten._log_softmax.default", @@ -905,7 +905,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::_log_softmax_out(out, self, dim, half_to_float); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten._logcumsumexp.default", @@ -920,7 +920,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::_logcumsumexp_out_cpu(self, dim, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.logcumsumexp.default", @@ -935,7 +935,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::logcumsumexp_out(self, dim, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.matrix_power.default", @@ -950,7 +950,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::matrix_power_out(self, n, out); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.mm.default", aten_mm_default, { const auto& self = KernelInput(0).toTensor(); @@ -962,7 +962,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.mm.default", aten_mm_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::mm_out(out, self, mat2); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.multiply.Tensor", aten_multiply_Tensor, { const auto& self = KernelInput(0).toTensor(); @@ -974,7 +974,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.multiply.Tensor", aten_multiply_Tensor, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::multiply_out(self, other, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.mv.default", aten_mv_default, { const auto& self = KernelInput(0).toTensor(); @@ -986,7 +986,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.mv.default", aten_mv_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::mv_out(self, vec, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.mvlgamma.default", aten_mvlgamma_default, { const auto& self = KernelInput(0).toTensor(); @@ -998,7 +998,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.mvlgamma.default", aten_mvlgamma_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::mvlgamma_out(self, p, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.rad2deg.default", aten_rad2deg_default, { const auto& self = KernelInput(0).toTensor(); @@ -1009,7 +1009,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.rad2deg.default", aten_rad2deg_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::rad2deg_out(self, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.deg2rad.default", aten_deg2rad_default, { const auto& self = KernelInput(0).toTensor(); @@ -1020,7 +1020,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.deg2rad.default", aten_deg2rad_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::deg2rad_out(self, out); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.reciprocal.default", @@ -1034,7 +1034,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::reciprocal_out(out, self); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.neg.default", aten_neg_default, { const auto& self = KernelInput(0).toTensor(); @@ -1045,7 +1045,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.neg.default", aten_neg_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::neg_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.negative.default", aten_negative_default, { const auto& self = KernelInput(0).toTensor(); @@ -1056,7 +1056,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.negative.default", aten_negative_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::negative_out(self, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.round.default", aten_round_default, { const auto& self = KernelInput(0).toTensor(); @@ -1067,7 +1067,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.round.default", aten_round_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::round_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.round.decimals", aten_round_decimals, { const auto& self = KernelInput(0).toTensor(); @@ -1079,7 +1079,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.round.decimals", aten_round_decimals, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::round_out(out, self, decimals); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.gelu.default", aten_gelu_default, { const auto& self = KernelInput(0).toTensor(); @@ -1091,7 +1091,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.gelu.default", aten_gelu_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::gelu_out(out, self, approximate); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.hardshrink.default", @@ -1106,7 +1106,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::hardshrink_out(out, self, lambd); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.hardshrink_backward.default", @@ -1122,7 +1122,7 @@ REGISTER_CPU_KERNEL( auto& grad_input = KernelOutput(0).toTensor(); fastResizeToZero(grad_input); at::cpu::hardshrink_backward_out(grad_input, grad_out, self, lambd); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.rsqrt.default", aten_rsqrt_default, { const auto& self = KernelInput(0).toTensor(); @@ -1133,7 +1133,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.rsqrt.default", aten_rsqrt_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::rsqrt_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.silu.default", aten_silu_default, { const auto& self = KernelInput(0).toTensor(); @@ -1144,7 +1144,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.silu.default", aten_silu_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::silu_out(out, self); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.silu_backward.default", @@ -1159,7 +1159,7 @@ REGISTER_CPU_KERNEL( auto& grad_input = KernelOutput(0).toTensor(); fastResizeToZero(grad_input); at::cpu::silu_backward_out(grad_input, grad_output, self); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.mish.default", aten_mish_default, { const auto& self = KernelInput(0).toTensor(); @@ -1170,7 +1170,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.mish.default", aten_mish_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::mish_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.sigmoid.default", aten_sigmoid_default, { const auto& self = KernelInput(0).toTensor(); @@ -1181,7 +1181,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.sigmoid.default", aten_sigmoid_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::sigmoid_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.sin.default", aten_sin_default, { const auto& self = KernelInput(0).toTensor(); @@ -1192,7 +1192,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.sin.default", aten_sin_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::sin_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.sinc.default", aten_sinc_default, { const auto& self = KernelInput(0).toTensor(); @@ -1203,7 +1203,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.sinc.default", aten_sinc_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::sinc_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.sinh.default", aten_sinh_default, { const auto& self = KernelInput(0).toTensor(); @@ -1214,7 +1214,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.sinh.default", aten_sinh_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::sinh_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten._softmax.default", aten__softmax_default, { const auto& self = KernelInput(0).toTensor(); @@ -1227,7 +1227,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten._softmax.default", aten__softmax_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::_softmax_out(out, self, dim, half_to_float); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.sqrt.default", aten_sqrt_default, { const auto& self = KernelInput(0).toTensor(); @@ -1238,7 +1238,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.sqrt.default", aten_sqrt_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::sqrt_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.square.default", aten_square_default, { const auto& self = KernelInput(0).toTensor(); @@ -1249,7 +1249,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.square.default", aten_square_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::square_out(self, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.prod.default", aten_prod_default, { const auto& self = KernelInput(0).toTensor(); @@ -1261,7 +1261,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.prod.default", aten_prod_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::prod_out(self, dtype, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.prod.dim_int", aten_prod_dim_int, { const auto& self = KernelInput(0).toTensor(); @@ -1275,7 +1275,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.prod.dim_int", aten_prod_dim_int, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::prod_out(out, self, dim, keepdim, dtype); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.tan.default", aten_tan_default, { const auto& self = KernelInput(0).toTensor(); @@ -1286,7 +1286,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.tan.default", aten_tan_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::tan_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.tanh.default", aten_tanh_default, { const auto& self = KernelInput(0).toTensor(); @@ -1297,7 +1297,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.tanh.default", aten_tanh_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::tanh_out(out, self); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.threshold.default", @@ -1313,7 +1313,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::threshold_out(out, self, threshold, value); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.threshold_backward.default", @@ -1330,7 +1330,7 @@ REGISTER_CPU_KERNEL( auto& grad_input = KernelOutput(0).toTensor(); fastResizeToZero(grad_input); at::cpu::threshold_backward_out(grad_input, grad_output, self, threshold); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.trunc.default", aten_trunc_default, { const auto& self = KernelInput(0).toTensor(); @@ -1341,7 +1341,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.trunc.default", aten_trunc_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::trunc_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.fix.default", aten_fix_default, { const auto& self = KernelInput(0).toTensor(); @@ -1352,7 +1352,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.fix.default", aten_fix_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::fix_out(self, out); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.nuclear_norm.default", @@ -1367,7 +1367,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::nuclear_norm_out(self, keepdim, out); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.subtract.Tensor", aten_subtract_Tensor, { const auto& self = KernelInput(0).toTensor(); @@ -1380,7 +1380,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.subtract.Tensor", aten_subtract_Tensor, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::subtract_out(self, other, alpha, out); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.heaviside.default", @@ -1395,7 +1395,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::heaviside_out(out, self, values); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten._addmm_activation.default", @@ -1416,7 +1416,7 @@ REGISTER_CPU_KERNEL( fastResizeToZero(out); at::cpu::_addmm_activation_out( out, self, mat1, mat2, beta, alpha, use_gelu); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.index_add.default", @@ -1434,7 +1434,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::index_add_out(out, self, dim, index, source, alpha); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.scatter.src", aten_scatter_src, { const auto& self = KernelInput(0).toTensor(); @@ -1448,7 +1448,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.scatter.src", aten_scatter_src, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::scatter_out(out, self, dim, index, src); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.scatter.value", aten_scatter_value, { const auto& self = KernelInput(0).toTensor(); @@ -1462,7 +1462,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.scatter.value", aten_scatter_value, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::scatter_out(out, self, dim, index, value); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.scatter.reduce", aten_scatter_reduce, { const auto& self = KernelInput(0).toTensor(); @@ -1477,7 +1477,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.scatter.reduce", aten_scatter_reduce, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::scatter_out(out, self, dim, index, src, reduce); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.scatter.value_reduce", @@ -1495,7 +1495,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::scatter_out(out, self, dim, index, value, reduce); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.scatter_add.default", @@ -1512,7 +1512,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::scatter_add_out(out, self, dim, index, src); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.scatter_reduce.two", @@ -1533,7 +1533,7 @@ REGISTER_CPU_KERNEL( fastResizeToZero(out); at::cpu::scatter_reduce_out( out, self, dim, index, src, reduce, include_self); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.eq.Scalar", aten_eq_Scalar, { const auto& self = KernelInput(0).toTensor(); @@ -1545,7 +1545,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.eq.Scalar", aten_eq_Scalar, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::eq_out(out, self, other); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.eq.Tensor", aten_eq_Tensor, { const auto& self = KernelInput(0).toTensor(); @@ -1557,7 +1557,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.eq.Tensor", aten_eq_Tensor, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::eq_out(out, self, other); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.bitwise_and.Tensor", @@ -1572,7 +1572,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::bitwise_and_out(out, self, other); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.bitwise_or.Tensor", @@ -1587,7 +1587,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::bitwise_or_out(out, self, other); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.bitwise_xor.Tensor", @@ -1602,7 +1602,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::bitwise_xor_out(out, self, other); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.bitwise_left_shift.Tensor", @@ -1617,7 +1617,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::bitwise_left_shift_out(out, self, other); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.bitwise_right_shift.Tensor", @@ -1632,7 +1632,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::bitwise_right_shift_out(out, self, other); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.tril.default", aten_tril_default, { const auto& self = KernelInput(0).toTensor(); @@ -1644,7 +1644,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.tril.default", aten_tril_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::tril_out(out, self, diagonal); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.triu.default", aten_triu_default, { const auto& self = KernelInput(0).toTensor(); @@ -1656,7 +1656,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.triu.default", aten_triu_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::triu_out(out, self, diagonal); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.digamma.default", aten_digamma_default, { const auto& self = KernelInput(0).toTensor(); @@ -1667,7 +1667,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.digamma.default", aten_digamma_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::digamma_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.lerp.Scalar", aten_lerp_Scalar, { const auto& self = KernelInput(0).toTensor(); @@ -1680,7 +1680,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.lerp.Scalar", aten_lerp_Scalar, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::lerp_out(out, self, end, weight); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.lerp.Tensor", aten_lerp_Tensor, { const auto& self = KernelInput(0).toTensor(); @@ -1693,7 +1693,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.lerp.Tensor", aten_lerp_Tensor, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::lerp_out(out, self, end, weight); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.addbmm.default", aten_addbmm_default, { const auto& self = KernelInput(0).toTensor(); @@ -1708,7 +1708,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.addbmm.default", aten_addbmm_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::addbmm_out(self, batch1, batch2, beta, alpha, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.cross.default", aten_cross_default, { const auto& self = KernelInput(0).toTensor(); @@ -1721,7 +1721,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.cross.default", aten_cross_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::cross_out(self, other, dim, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.ne.Scalar", aten_ne_Scalar, { const auto& self = KernelInput(0).toTensor(); @@ -1733,7 +1733,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.ne.Scalar", aten_ne_Scalar, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::ne_out(out, self, other); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.ne.Tensor", aten_ne_Tensor, { const auto& self = KernelInput(0).toTensor(); @@ -1745,7 +1745,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.ne.Tensor", aten_ne_Tensor, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::ne_out(out, self, other); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.ge.Scalar", aten_ge_Scalar, { const auto& self = KernelInput(0).toTensor(); @@ -1757,7 +1757,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.ge.Scalar", aten_ge_Scalar, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::ge_out(out, self, other); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.ge.Tensor", aten_ge_Tensor, { const auto& self = KernelInput(0).toTensor(); @@ -1769,7 +1769,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.ge.Tensor", aten_ge_Tensor, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::ge_out(out, self, other); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.le.Scalar", aten_le_Scalar, { const auto& self = KernelInput(0).toTensor(); @@ -1781,7 +1781,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.le.Scalar", aten_le_Scalar, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::le_out(out, self, other); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.le.Tensor", aten_le_Tensor, { const auto& self = KernelInput(0).toTensor(); @@ -1793,7 +1793,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.le.Tensor", aten_le_Tensor, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::le_out(out, self, other); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.gt.Scalar", aten_gt_Scalar, { const auto& self = KernelInput(0).toTensor(); @@ -1805,7 +1805,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.gt.Scalar", aten_gt_Scalar, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::gt_out(out, self, other); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.gt.Tensor", aten_gt_Tensor, { const auto& self = KernelInput(0).toTensor(); @@ -1817,7 +1817,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.gt.Tensor", aten_gt_Tensor, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::gt_out(out, self, other); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.lt.Scalar", aten_lt_Scalar, { const auto& self = KernelInput(0).toTensor(); @@ -1829,7 +1829,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.lt.Scalar", aten_lt_Scalar, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::lt_out(out, self, other); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.lt.Tensor", aten_lt_Tensor, { const auto& self = KernelInput(0).toTensor(); @@ -1841,7 +1841,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.lt.Tensor", aten_lt_Tensor, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::lt_out(out, self, other); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.take.default", aten_take_default, { const auto& self = KernelInput(0).toTensor(); @@ -1853,7 +1853,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.take.default", aten_take_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::take_out(self, index, out); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.take_along_dim.default", @@ -1869,7 +1869,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::take_along_dim_out(self, indices, dim, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.masked_select.default", @@ -1884,7 +1884,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::masked_select_out_cpu(self, mask, out); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.gather.default", aten_gather_default, { const auto& self = KernelInput(0).toTensor(); @@ -1898,7 +1898,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.gather.default", aten_gather_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::gather_out(out, self, dim, index, sparse_grad); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.addcmul.default", aten_addcmul_default, { const auto& self = KernelInput(0).toTensor(); @@ -1912,7 +1912,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.addcmul.default", aten_addcmul_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::addcmul_out(out, self, tensor1, tensor2, value); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.addcdiv.default", aten_addcdiv_default, { const auto& self = KernelInput(0).toTensor(); @@ -1926,7 +1926,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.addcdiv.default", aten_addcdiv_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::addcdiv_out(out, self, tensor1, tensor2, value); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.linalg_solve_triangular.default", @@ -1946,7 +1946,7 @@ REGISTER_CPU_KERNEL( fastResizeToZero(out); at::native::linalg_solve_triangular_out( self, B, upper, left, unitriangular, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.cholesky_solve.default", @@ -1962,7 +1962,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::cholesky_solve_out(self, input2, upper, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.cholesky_inverse.default", @@ -1977,7 +1977,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::cholesky_inverse_out(self, upper, out); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.orgqr.default", aten_orgqr_default, { const auto& self = KernelInput(0).toTensor(); @@ -1989,7 +1989,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.orgqr.default", aten_orgqr_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::orgqr_out(self, input2, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.ormqr.default", aten_ormqr_default, { const auto& self = KernelInput(0).toTensor(); @@ -2004,7 +2004,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.ormqr.default", aten_ormqr_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::ormqr_out(self, input2, input3, left, transpose, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.lgamma.default", aten_lgamma_default, { const auto& self = KernelInput(0).toTensor(); @@ -2015,7 +2015,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.lgamma.default", aten_lgamma_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::lgamma_out(out, self); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.polygamma.default", @@ -2030,7 +2030,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::polygamma_out(out, n, self); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.erfinv.default", aten_erfinv_default, { const auto& self = KernelInput(0).toTensor(); @@ -2041,7 +2041,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.erfinv.default", aten_erfinv_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::erfinv_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.i0.default", aten_i0_default, { const auto& self = KernelInput(0).toTensor(); @@ -2052,7 +2052,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.i0.default", aten_i0_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::i0_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.signbit.default", aten_signbit_default, { const auto& self = KernelInput(0).toTensor(); @@ -2063,7 +2063,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.signbit.default", aten_signbit_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::signbit_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.atan2.default", aten_atan2_default, { const auto& self = KernelInput(0).toTensor(); @@ -2075,7 +2075,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.atan2.default", aten_atan2_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::atan2_out(out, self, other); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.arctan2.default", aten_arctan2_default, { const auto& self = KernelInput(0).toTensor(); @@ -2087,7 +2087,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.arctan2.default", aten_arctan2_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::arctan2_out(self, other, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.histc.default", aten_histc_default, { const auto& self = KernelInput(0).toTensor(); @@ -2101,7 +2101,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.histc.default", aten_histc_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::histogram_histc_out(self, bins, min, max, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.fmod.Tensor", aten_fmod_Tensor, { const auto& self = KernelInput(0).toTensor(); @@ -2113,7 +2113,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.fmod.Tensor", aten_fmod_Tensor, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::fmod_out(out, self, other); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.hypot.default", aten_hypot_default, { const auto& self = KernelInput(0).toTensor(); @@ -2125,7 +2125,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.hypot.default", aten_hypot_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::hypot_out(out, self, other); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.igamma.default", aten_igamma_default, { const auto& self = KernelInput(0).toTensor(); @@ -2137,7 +2137,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.igamma.default", aten_igamma_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::igamma_out(out, self, other); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.igammac.default", aten_igammac_default, { const auto& self = KernelInput(0).toTensor(); @@ -2149,7 +2149,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.igammac.default", aten_igammac_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::igammac_out(out, self, other); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.nextafter.default", @@ -2164,7 +2164,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::nextafter_out(out, self, other); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.fmin.default", aten_fmin_default, { const auto& self = KernelInput(0).toTensor(); @@ -2176,7 +2176,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.fmin.default", aten_fmin_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::fmin_out(out, self, other); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.fmax.default", aten_fmax_default, { const auto& self = KernelInput(0).toTensor(); @@ -2188,7 +2188,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.fmax.default", aten_fmax_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::fmax_out(out, self, other); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.maximum.default", aten_maximum_default, { const auto& self = KernelInput(0).toTensor(); @@ -2200,7 +2200,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.maximum.default", aten_maximum_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::maximum_out(out, self, other); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.minimum.default", aten_minimum_default, { const auto& self = KernelInput(0).toTensor(); @@ -2212,7 +2212,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.minimum.default", aten_minimum_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::minimum_out(out, self, other); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.min.other", aten_min_other, { const auto& self = KernelInput(0).toTensor(); @@ -2224,7 +2224,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.min.other", aten_min_other, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::min_out(self, other, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.quantile.default", aten_quantile_default, { const auto& self = KernelInput(0).toTensor(); @@ -2240,7 +2240,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.quantile.default", aten_quantile_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::quantile_out(self, q, dim, keepdim, interpolation, out); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.nanquantile.default", @@ -2259,7 +2259,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::nanquantile_out(self, q, dim, keepdim, interpolation, out); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.msort.default", aten_msort_default, { const auto& self = KernelInput(0).toTensor(); @@ -2270,7 +2270,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.msort.default", aten_msort_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::msort_out(self, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.all.default", aten_all_default, { const auto& self = KernelInput(0).toTensor(); @@ -2281,7 +2281,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.all.default", aten_all_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::all_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.any.default", aten_any_default, { const auto& self = KernelInput(0).toTensor(); @@ -2292,7 +2292,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.any.default", aten_any_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::any_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.renorm.default", aten_renorm_default, { const auto& self = KernelInput(0).toTensor(); @@ -2306,7 +2306,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.renorm.default", aten_renorm_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::renorm_out(out, self, p, dim, maxnorm); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten._convert_indices_from_coo_to_csr.default", @@ -2323,7 +2323,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::_convert_indices_from_coo_to_csr_out(out, self, size, out_int32); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten._convert_indices_from_csr_to_coo.default", @@ -2342,7 +2342,7 @@ REGISTER_CPU_KERNEL( fastResizeToZero(out); at::cpu::_convert_indices_from_csr_to_coo_out( out, crow_indices, col_indices, out_int32, transpose); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.mse_loss.default", aten_mse_loss_default, { const auto& self = KernelInput(0).toTensor(); @@ -2355,7 +2355,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.mse_loss.default", aten_mse_loss_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::mse_loss_out(out, self, target, reduction); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.multi_margin_loss.default", @@ -2376,7 +2376,7 @@ REGISTER_CPU_KERNEL( fastResizeToZero(out); at::native::multi_margin_loss_cpu_out( self, target, p, margin, weight, reduction, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.multilabel_margin_loss.default", @@ -2393,7 +2393,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::multilabel_margin_loss_out(self, target, reduction, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.soft_margin_loss.default", @@ -2409,7 +2409,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::soft_margin_loss_out(self, target, reduction, out); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.elu.default", aten_elu_default, { const auto& self = KernelInput(0).toTensor(); @@ -2423,7 +2423,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.elu.default", aten_elu_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::elu_out(out, self, alpha, scale, input_scale); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.elu_backward.default", @@ -2450,7 +2450,7 @@ REGISTER_CPU_KERNEL( input_scale, is_result, self_or_result); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.glu.default", aten_glu_default, { const auto& self = KernelInput(0).toTensor(); @@ -2462,7 +2462,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.glu.default", aten_glu_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::glu_out(out, self, dim); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.hardsigmoid.default", @@ -2476,7 +2476,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::hardsigmoid_out(out, self); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.hardsigmoid_backward.default", @@ -2491,7 +2491,7 @@ REGISTER_CPU_KERNEL( auto& grad_input = KernelOutput(0).toTensor(); fastResizeToZero(grad_input); at::cpu::hardsigmoid_backward_out(grad_input, grad_output, self); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.hardtanh.default", aten_hardtanh_default, { const auto& self = KernelInput(0).toTensor(); @@ -2504,7 +2504,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.hardtanh.default", aten_hardtanh_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::hardtanh_out(self, min_val, max_val, out); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.hardswish.default", @@ -2518,7 +2518,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::hardswish_out(self, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.leaky_relu_backward.default", @@ -2537,7 +2537,7 @@ REGISTER_CPU_KERNEL( fastResizeToZero(grad_input); at::cpu::leaky_relu_backward_out( grad_input, grad_output, self, negative_slope, self_is_result); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.log_sigmoid.default", @@ -2551,7 +2551,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::log_sigmoid_out(self, out); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.softplus.default", aten_softplus_default, { const auto& self = KernelInput(0).toTensor(); @@ -2564,7 +2564,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.softplus.default", aten_softplus_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::softplus_out(out, self, beta, threshold); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.softplus_backward.default", @@ -2583,7 +2583,7 @@ REGISTER_CPU_KERNEL( fastResizeToZero(grad_input); at::cpu::softplus_backward_out( grad_input, grad_output, self, beta, threshold); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.softshrink.default", @@ -2598,7 +2598,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::softshrink_out(out, self, lambd); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.softshrink_backward.default", @@ -2615,7 +2615,7 @@ REGISTER_CPU_KERNEL( auto& grad_input = KernelOutput(0).toTensor(); fastResizeToZero(grad_input); at::cpu::softshrink_backward_out(grad_input, grad_output, self, lambd); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.isposinf.default", aten_isposinf_default, { const auto& self = KernelInput(0).toTensor(); @@ -2626,7 +2626,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.isposinf.default", aten_isposinf_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::isposinf_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.isneginf.default", aten_isneginf_default, { const auto& self = KernelInput(0).toTensor(); @@ -2637,7 +2637,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.isneginf.default", aten_isneginf_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::isneginf_out(out, self); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.special_entr.default", @@ -2651,7 +2651,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::special_entr_out(out, self); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.special_ndtri.default", @@ -2665,7 +2665,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::special_ndtri_out(out, self); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.special_log_ndtr.default", @@ -2679,7 +2679,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::special_log_ndtr_out(out, self); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.special_expm1.default", @@ -2693,7 +2693,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::special_expm1_out(self, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.special_exp2.default", @@ -2707,7 +2707,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::special_exp2_out(self, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.special_psi.default", @@ -2721,7 +2721,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::special_psi_out(self, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.special_digamma.default", @@ -2735,7 +2735,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::special_digamma_out(self, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.special_gammaln.default", @@ -2749,7 +2749,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::special_gammaln_out(self, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.special_erf.default", @@ -2763,7 +2763,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::special_erf_out(self, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.special_erfc.default", @@ -2777,7 +2777,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::special_erfc_out(self, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.special_erfcx.default", @@ -2791,7 +2791,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::special_erfcx_out(out, self); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.special_erfinv.default", @@ -2805,7 +2805,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::special_erfinv_out(self, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.special_ndtr.default", @@ -2819,7 +2819,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::special_ndtr_out(self, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.special_xlog1py.default", @@ -2834,7 +2834,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::special_xlog1py_out(out, self, other); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.special_xlogy.default", @@ -2849,7 +2849,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::special_xlogy_out(self, other, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.special_zeta.default", @@ -2864,7 +2864,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::special_zeta_out(out, self, other); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.special_i0.default", @@ -2878,7 +2878,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::special_i0_out(self, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.special_i0e.default", @@ -2892,7 +2892,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::special_i0e_out(out, self); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.special_i1.default", @@ -2906,7 +2906,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::special_i1_out(out, self); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.special_i1e.default", @@ -2920,7 +2920,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::special_i1e_out(out, self); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.special_polygamma.default", @@ -2935,7 +2935,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::special_polygamma_out(n, self, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.special_expit.default", @@ -2949,7 +2949,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::special_expit_out(self, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.special_sinc.default", @@ -2963,7 +2963,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::special_sinc_out(self, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.special_round.default", @@ -2978,7 +2978,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::special_round_out(self, decimals, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.special_log1p.default", @@ -2992,7 +2992,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::special_log1p_out(self, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.special_gammainc.default", @@ -3007,7 +3007,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::special_gammainc_out(self, other, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.special_gammaincc.default", @@ -3022,7 +3022,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::special_gammaincc_out(self, other, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.special_multigammaln.default", @@ -3037,7 +3037,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::special_multigammaln_out(self, p, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.linalg_cross.default", @@ -3053,7 +3053,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::linalg_cross_out(out, self, other, dim); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.linalg_det.default", @@ -3067,7 +3067,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::linalg_det_out(A, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.linalg_matmul.default", @@ -3082,7 +3082,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::linalg_matmul_out(self, other, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.linalg_eigvals.default", @@ -3096,7 +3096,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::linalg_eigvals_out(self, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.linalg_inv.default", @@ -3110,7 +3110,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::linalg_inv_out(A, out); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.inverse.default", aten_inverse_default, { const auto& self = KernelInput(0).toTensor(); @@ -3121,7 +3121,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.inverse.default", aten_inverse_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::inverse_out(self, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.inner.default", aten_inner_default, { const auto& self = KernelInput(0).toTensor(); @@ -3133,7 +3133,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.inner.default", aten_inner_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::inner_out(self, other, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.outer.default", aten_outer_default, { const auto& self = KernelInput(0).toTensor(); @@ -3145,7 +3145,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.outer.default", aten_outer_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::outer_out(self, vec2, out); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.linalg_cond.default", @@ -3160,7 +3160,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::linalg_cond_out(self, p, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.linalg_solve.default", @@ -3176,7 +3176,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::linalg_solve_out(A, B, left, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.linalg_tensorinv.default", @@ -3191,7 +3191,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::linalg_tensorinv_out(self, ind, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.linalg_matrix_power.default", @@ -3206,6 +3206,6 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::linalg_matrix_power_out(self, n, out); - }); + }) } // namespace torch::nativert diff --git a/torch/nativert/kernels/KernelRegistry.cpp b/torch/nativert/kernels/KernelRegistry.cpp index 2632b7886804c..77da29528d45b 100644 --- a/torch/nativert/kernels/KernelRegistry.cpp +++ b/torch/nativert/kernels/KernelRegistry.cpp @@ -245,7 +245,7 @@ C10_DEFINE_REGISTRY( StaticallyDispatchedCPUKernelRegistry, OpKernel, const Node*, - c10::Device); + c10::Device) namespace { @@ -294,7 +294,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.remainder.Tensor", aten_remainder_Tensor, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::remainder_out(out, self, KernelInput(1).toTensor()); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.remainder.Scalar", aten_remainder_Scalar, { const auto& self = KernelInput(0).toTensor(); @@ -305,7 +305,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.remainder.Scalar", aten_remainder_Scalar, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::remainder_out(self, KernelInput(1).toScalar(), out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.matmul.default", aten_matmul, { const auto& in0_t = KernelInput(0).toTensor(); @@ -318,7 +318,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.matmul.default", aten_matmul, { auto& out_t = KernelOutput(0).toTensor(); fastResizeToZero(out_t); at::native::matmul_out(in0_t, in1_t, out_t); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.bmm.default", aten_bmm, { const auto& in0_t = KernelInput(0).toTensor(); @@ -329,7 +329,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.bmm.default", aten_bmm, { auto& out_t = KernelOutput(0).toTensor(); fastResizeToZero(out_t); at::cpu::bmm_out(out_t, in0_t, in1_t); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.abs.default", aten_abs, { const auto& in0_t = KernelInput(0).toTensor(); @@ -340,7 +340,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.abs.default", aten_abs, { auto& out_t = KernelOutput(0).toTensor(); fastResizeToZero(out_t); at::native::abs_out(in0_t, out_t); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.mul.Tensor", aten_mul, { const auto& in0_t = KernelInput(0).toTensor(); @@ -352,7 +352,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.mul.Tensor", aten_mul, { auto& out_t = KernelOutput(0).toTensor(); fastResizeToZero(out_t); at::cpu::mul_out(out_t, in0_t, in1_t); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.mul.Scalar", aten_mul_Scalar, { const auto& in0_t = KernelInput(0).toTensor(); @@ -364,7 +364,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.mul.Scalar", aten_mul_Scalar, { auto& out_t = KernelOutput(0).toTensor(); fastResizeToZero(out_t); KernelOutput(0) = at::native::mul_out(out_t, in0_t, in1_t); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.nan_to_num.default", aten_nan_to_num, { const auto& in0_t = KernelInput(0).toTensor(); @@ -378,7 +378,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.nan_to_num.default", aten_nan_to_num, { auto& out_t = KernelOutput(0).toTensor(); fastResizeToZero(out_t); at::native::nan_to_num_out(in0_t, in1_d, in2_d, in3_d, out_t); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.leaky_relu.default", aten_leaky_relu, { const auto& in0_t = KernelInput(0).toTensor(); @@ -389,7 +389,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.leaky_relu.default", aten_leaky_relu, { } auto& out_t = KernelOutput(0).toTensor(); at::cpu::leaky_relu_out(out_t, in0_t, in1_s); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.relu.default", aten_relu, { const auto& in0_t = KernelInput(0).toTensor(); @@ -399,7 +399,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.relu.default", aten_relu, { auto& out_t = KernelOutput(0).toTensor(); fastResizeToZero(out_t); at::cpu::threshold_out(out_t, in0_t, 0, 0); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.clone.default", aten_clone, { const auto& src = KernelInput(0).toTensor(); @@ -433,7 +433,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.clone.default", aten_clone, { at::native::resize_impl_cpu_( out_t.unsafeGetTensorImpl(), src.sizes(), src.strides()); at::native::copy_(out_t, src, false); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.index.Tensor", aten_index, { const auto& in0_t = KernelInput(0).toTensor(); @@ -446,7 +446,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.index.Tensor", aten_index, { auto& out_t = KernelOutput(0).toTensor(); fastResizeToZero(out_t); at::cpu::index_out(out_t, in0_t, in1_l); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.index_select.default", aten_index_select, { const auto& self = KernelInput(0).toTensor(); @@ -459,7 +459,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.index_select.default", aten_index_select, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::index_select_out_cpu_(self, dim, index, out); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.pow.Tensor_Tensor", @@ -474,7 +474,7 @@ REGISTER_CPU_KERNEL( fastResizeToZero(out_t); at::cpu::pow_out( out_t, KernelInput(0).toTensor(), KernelInput(1).toTensor()); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.pow.Scalar", aten_pow_Scalar, { if (KernelOutput(0).isNone()) { @@ -491,7 +491,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.pow.Scalar", aten_pow_Scalar, { auto& out_t = KernelOutput(0).toTensor(); fastResizeToZero(out_t); at::cpu::pow_out(out_t, KernelInput(0).toScalar(), KernelInput(1).toTensor()); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.pow.Tensor_Scalar", @@ -512,7 +512,7 @@ REGISTER_CPU_KERNEL( fastResizeToZero(out_t); at::cpu::pow_out( out_t, KernelInput(0).toTensor(), KernelInput(1).toScalar()); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.sum.default", aten_sum_default, { // if (n->inputs().size() != 2 && n->inputs().size() != 4) { @@ -529,7 +529,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.sum.default", aten_sum_default, { fastResizeToZero(out); at::cpu::sum_out(out, self, dim, keepdim, dtype); } -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.sum.dim_IntList", aten_sum_dim_IntList, { // if (n->inputs().size() != 2 && n->inputs().size() != 4) { @@ -546,7 +546,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.sum.dim_IntList", aten_sum_dim_IntList, { fastResizeToZero(out); at::cpu::sum_out(out, self, dim, keepdim, dtype); } -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.mean.dim", aten_mean_dim, { const auto& self = KernelInput(0).toTensor(); @@ -560,7 +560,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.mean.dim", aten_mean_dim, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::mean_out(out, self, dim, keepdim, dtype); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.mean.default", aten_mean_default, { const auto& self = KernelInput(0).toTensor(); @@ -572,7 +572,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.mean.default", aten_mean_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::mean_out(out, self, /*dim=*/{}, /*keepdim=*/false, dtype); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.max.other", aten_max_other, { const auto& self = KernelInput(0).toTensor(); @@ -584,7 +584,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.max.other", aten_max_other, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::max_out(self, other, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.max.default", aten_max_default, { const auto& self = KernelInput(0).toTensor(); @@ -594,7 +594,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.max.default", aten_max_default, { auto& value = KernelOutput(0).toTensor(); fastResizeToZero(value); at::cpu::amax_out(value, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.sign.Tensor", aten_sign_Tensor, { const auto& in0_t = KernelInput(0).toTensor(); @@ -605,7 +605,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.sign.Tensor", aten_sign_Tensor, { auto& out_t = KernelOutput(0).toTensor(); fastResizeToZero(out_t); at::cpu::sign_out(out_t, in0_t); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.log.default", aten_log, { const auto& in0_t = KernelInput(0).toTensor(); @@ -616,7 +616,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.log.default", aten_log, { auto& out_t = KernelOutput(0).toTensor(); fastResizeToZero(out_t); at::cpu::log_out(out_t, in0_t); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.sub.Tensor", aten_sub_Tensor, { const auto& in0_t = KernelInput(0).toTensor(); @@ -629,7 +629,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.sub.Tensor", aten_sub_Tensor, { auto& out_t = KernelOutput(0).toTensor(); fastResizeToZero(out_t); at::cpu::sub_out(out_t, in0_t, in1_t, alpha); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.sub.Scalar", aten_sub, { const auto& in0_t = KernelInput(0).toTensor(); @@ -643,7 +643,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.sub.Scalar", aten_sub, { auto& out_t = KernelOutput(0).toTensor(); fastResizeToZero(out_t); at::cpu::sub_out(out_t, in0_t, in1_t, alpha); -}); +}) // TODO: support clamp_min.Tensor(Tensor self, Tensor min) -> Tensor // Missing Test Coverage @@ -660,7 +660,7 @@ REGISTER_CPU_KERNEL( auto& out_t = KernelOutput(0).toTensor(); fastResizeToZero(out_t); at::cpu::clamp_min_out(out_t, in0_t, in1_s); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.argmin.default", aten_argmin, { const auto& in0_t = KernelInput(0).toTensor(); @@ -677,7 +677,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.argmin.default", aten_argmin, { return; } at::cpu::argmin_out(out_t, in0_t, dim, keepdim); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.softmax.int", aten_softmax_int, { const auto& in_t = KernelInput(0).toTensor(); @@ -692,7 +692,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.softmax.int", aten_softmax_int, { auto half_to_float = in_t.scalar_type() == at::ScalarType::Half && dtype == at::ScalarType::Float; at::cpu::_softmax_out(out_t, in_t, dim, half_to_float); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.norm.ScalarOpt_dtype", @@ -712,7 +712,7 @@ REGISTER_CPU_KERNEL( false, KernelInput(2).toScalarType(), out_t); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.full.default", aten_full, { const auto& size = KernelInput(0).toDimVector(); @@ -728,7 +728,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.full.default", aten_full, { } KernelOutput(0) = at::native::full_out(size, fill_value, KernelOutput(0).toTensor()); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.ones.default", aten_ones, { const auto size = KernelInput(0).toDimVector(); @@ -743,7 +743,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.ones.default", aten_ones, { auto& out_t = KernelOutput(0).toTensor(); fastResizeToZero(out_t); at::native::ones_out(size, out_t); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.ones_like.default", aten_ones_like, { const auto& self = KernelInput(0).toTensor(); @@ -760,7 +760,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.ones_like.default", aten_ones_like, { auto& out_t = KernelOutput(0).toTensor(); fastResizeToZero(out_t); at::native::ones_out(self.sizes(), out_t); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.zeros.default", aten_zeros, { const auto size = KernelInput(0).toDimVector(); @@ -774,7 +774,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.zeros.default", aten_zeros, { auto& out_t = KernelOutput(0).toTensor(); fastResizeToZero(out_t); at::compositeexplicitautograd::zeros_out(out_t, size); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.linalg_norm.default", @@ -798,7 +798,7 @@ REGISTER_CPU_KERNEL( keepdim, dtype, out); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.linalg_norm.ord_str", aten_linalg_norm, { const auto& self = KernelInput(0).toTensor(); @@ -814,7 +814,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.linalg_norm.ord_str", aten_linalg_norm, { fastResizeToZero(out); at::native::linalg_norm_out( self, KernelInput(1).toStringRef(), dim, keepdim, dtype, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.cat.default", aten_cat, { const auto inputs = KernelInput(0).toTensorVector(); @@ -827,7 +827,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.cat.default", aten_cat, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::cat_outf(inputs, dim, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.cumsum.default", aten_cumsum, { const auto& self = KernelInput(0).toTensor(); @@ -840,7 +840,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.cumsum.default", aten_cumsum, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::cumsum_out(out, self, dim, dtype); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.nonzero.default", aten_nonzero, { const auto& self = KernelInput(0).toTensor(); @@ -851,7 +851,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.nonzero.default", aten_nonzero, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::nonzero_out_cpu(self, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.addmm.default", aten_addmm, { const auto& in0_t = KernelInput(0).toTensor(); @@ -866,7 +866,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.addmm.default", aten_addmm, { auto& out_t = KernelOutput(0).toTensor(); fastResizeToZero(out_t); at::cpu::addmm_out(out_t, in0_t, in1_t, in2_t, in3_s, in4_s); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.narrow_copy.default", aten_narrow_copy, { const auto& self = KernelInput(0).toTensor(); // self @@ -888,7 +888,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.narrow_copy.default", aten_narrow_copy, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::narrow_copy_dense_cpu_out(self, dim, start, length, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.repeat.default", aten_repeat, { const auto& self = KernelInput(0).toTensor(); @@ -900,7 +900,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.repeat.default", aten_repeat, { } at::Tensor& out = KernelOutput(0).toTensor(); at::native::repeat_out(out, self, repeats); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.max.dim", aten_max_dim, { const auto& self = KernelInput(0).toTensor(); @@ -920,7 +920,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.max.dim", aten_max_dim, { fastResizeToZero(values); fastResizeToZero(indices); at::cpu::max_out(values, indices, self, dim, keepdim); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.layer_norm.default", aten_layer_norm, { // ignore KernelInput(5): `bool cudnn_enable=True` @@ -956,7 +956,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.layer_norm.default", aten_layer_norm, { } at::Tensor& out = KernelOutput(0).toTensor(); at::native::layer_norm_cpu_out(out, *X, *gamma, *beta, eps, M, N); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.norm.ScalarOpt_dim_dtype", @@ -978,7 +978,7 @@ REGISTER_CPU_KERNEL( KernelInput(3).toBool(), // keepdim KernelInput(4).toScalarType(), // dtype out_t); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.norm.ScalarOpt_dim", @@ -999,7 +999,7 @@ REGISTER_CPU_KERNEL( KernelInput(2).toDimVector(), // dim KernelInput(3).toBool(), // keepdim out_t); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.full_like.default", aten_full_like, { const auto in1_s = KernelInput(1).toScalar(); @@ -1017,7 +1017,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.full_like.default", aten_full_like, { auto& out_t = KernelOutput(0).toTensor(); at::native::resize_(out_t, in0_t.sizes(), std::nullopt); at::native::fill_out(out_t, in1_s); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.linear.default", aten_linear, { const auto& in0_t = KernelInput(0).toTensor(); @@ -1031,7 +1031,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.linear.default", aten_linear, { auto& out_t = KernelOutput(0).toTensor(); fastResizeToZero(out_t); at::native::linear_out(out_t, in0_t, in1_t, in2_t); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.where.self", aten_where, { const auto& cond = KernelInput(0).toTensor(); @@ -1044,7 +1044,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.where.self", aten_where, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::where_self_out(cond, self, other, out); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.quantized.embedding_bag_byte_rowwise_offsets.default", @@ -1074,7 +1074,7 @@ REGISTER_CPU_KERNEL( per_sample_weights, compressed_indices_mapping, include_last_offset); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.quantized.embedding_bag_4bit_rowwise_offsets.default", @@ -1104,7 +1104,7 @@ REGISTER_CPU_KERNEL( per_sample_weights, compressed_indices_mapping, include_last_offset); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.quantized.linear_dynamic_fp16.default", @@ -1121,7 +1121,7 @@ REGISTER_CPU_KERNEL( KernelInput(1).toCustomClass()->apply_dynamic_out( in_0, out_0, /* reduce_range= */ false); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.quantized.linear_relu_dynamic_fp16.default", @@ -1140,7 +1140,7 @@ REGISTER_CPU_KERNEL( .toCustomClass() ->apply_dynamic_out(in_0, out_0, /* reduce_range= */ false) .relu_(); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.quantized.linear.default", @@ -1165,7 +1165,7 @@ REGISTER_CPU_KERNEL( auto& out_tensor = KernelOutput(0).toTensor(); fastResizeToZero(out_tensor); w_prepack->apply_out(in_0, output_scale, output_zero_point, out_tensor); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.logit.default", aten_logit, { const auto& in0_t = KernelInput(0).toTensor(); @@ -1176,7 +1176,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.logit.default", aten_logit, { auto& out_t = KernelOutput(0).toTensor(); fastResizeToZero(out_t); at::native::logit_out(in0_t, in1_d, out_t); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.slice_scatter.default", @@ -1194,7 +1194,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::slice_scatter_out(out, self, src, dim, start, end, step); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.quantized.embedding_bag_byte_unpack.default", @@ -1210,7 +1210,7 @@ REGISTER_CPU_KERNEL( auto& out_tensor = KernelOutput(0).toTensor(); fastResizeToZero(out_tensor); at::native::qembeddingbag_byte_unpack_out(out_tensor, weight); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.quantized.embedding_bag_byte_prepack.default", @@ -1224,7 +1224,7 @@ REGISTER_CPU_KERNEL( auto& out_tensor = KernelOutput(0).toTensor(); fastResizeToZero(out_tensor); at::native::qembeddingbag_byte_prepack_out(out_tensor, weight); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.stack.default", aten_stack, { const auto& inputs = KernelInput(0).toTensorVector(); @@ -1236,7 +1236,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.stack.default", aten_stack, { auto& out_t = KernelOutput(0).toTensor(); fastResizeToZero(out_t); at::native::_stack_out_cpu(inputs, dim, out_t); -}); +}) class OpKernel_aten__to_copy : public C10Kernel { public: diff --git a/torch/nativert/kernels/NativeKernels.cpp b/torch/nativert/kernels/NativeKernels.cpp index 1f847863070ac..7acd82102266b 100644 --- a/torch/nativert/kernels/NativeKernels.cpp +++ b/torch/nativert/kernels/NativeKernels.cpp @@ -13,7 +13,7 @@ REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.slice.Tensor", aten_slice_Tensor, { const auto& end = KernelInput(3).toOptional(); const auto& step = KernelInput(4).toInt(); KernelOutput(0) = at::native::slice(self, dim, start, end, step); -}); +}) REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.sym_size.int", aten_sym_size_int, { const auto& self = KernelInput(0).toTensor(); @@ -21,39 +21,39 @@ REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.sym_size.int", aten_sym_size_int, { auto& out = KernelOutput(0); TORCH_CHECK(dim >= 0 && dim < self.dim(), "Invalid dimension"); out = self.sym_size(dim); -}); +}) REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.reshape.default", aten_reshape, { const auto& self = KernelInput(0).toTensor(); const auto& shape = KernelInput(1).toIntVector(); KernelOutput(0) = at::native::reshape(self, shape); -}); +}) REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.view.default", aten_view, { const auto& self = KernelInput(0).toTensor(); const auto& size = KernelInput(1).toIntVector(); KernelOutput(0) = at::native::view(self, size); -}); +}) REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.permute.default", aten_permute, { const auto& self = KernelInput(0).toTensor(); const auto& dims = KernelInput(1).toDimVector(); KernelOutput(0) = at::native::permute(self, dims); -}); +}) REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.select.int", aten_select, { const auto& self = KernelInput(0).toTensor(); const auto dim = KernelInput(1).toInt(); const auto index = KernelInput(2).toInt(); KernelOutput(0) = at::native::select(self, dim, index); -}); +}) REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.split.Tensor", aten_split_Tensor, { const auto& self = KernelInput(0).toTensor(); const auto split_size = KernelInput(1).toInt(); const auto dim = KernelInput(2).toInt(); KernelOutput(0) = at::native::split(self, split_size, dim); -}); +}) REGISTER_NATIVE_CPU_KERNEL( "torch.ops.aten.split_with_sizes.default", @@ -64,7 +64,7 @@ REGISTER_NATIVE_CPU_KERNEL( const auto dim = KernelInput(2).toInt(); KernelOutput(0) = at::native::split_with_sizes(self, split_sizes.vec(), dim); - }); + }) REGISTER_NATIVE_CPU_KERNEL( "torch.ops.aten.tensor_split.sections", @@ -75,12 +75,12 @@ REGISTER_NATIVE_CPU_KERNEL( const auto dim = KernelInput(2).toInt(); KernelOutput(0) = at::native::tensor_split_sections_symint(self, sections, dim); - }); + }) REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.item.default", aten_item, { const auto& self = KernelInput(0).toTensor(); KernelOutput(0) = at::native::item(self); -}); +}) REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.narrow.default", aten_narrow, { const auto& self = KernelInput(0).toTensor(); @@ -108,6 +108,6 @@ REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.narrow.default", aten_narrow, { cur_size, ")."); KernelOutput(0) = at::native::slice(self, dim, start, start + length, 1); -}); +}) } // namespace torch::nativert diff --git a/torch/nativert/kernels/PrimKernelRegistry.cpp b/torch/nativert/kernels/PrimKernelRegistry.cpp index 80421bae77597..b9071c8ecc4e4 100644 --- a/torch/nativert/kernels/PrimKernelRegistry.cpp +++ b/torch/nativert/kernels/PrimKernelRegistry.cpp @@ -10,7 +10,7 @@ namespace torch::nativert { -C10_DEFINE_REGISTRY(PrimKernelRegistry, OpKernel, const Node*); +C10_DEFINE_REGISTRY(PrimKernelRegistry, OpKernel, const Node*) namespace { @@ -65,11 +65,11 @@ REGISTER_PRIM_KERNEL("prim.ListUnpack", prim_listunpack, { for (const auto& [i, ivalue] : c10::enumerate(inputListRef)) { KernelOutput(i) = ivalue; } -}); +}) // Noop for input and output -REGISTER_PRIM_KERNEL("prim.Input", prim_input, {}); -REGISTER_PRIM_KERNEL("prim.Output", prim_output, {}); +REGISTER_PRIM_KERNEL("prim.Input", prim_input, {}) +REGISTER_PRIM_KERNEL("prim.Output", prim_output, {}) namespace { diff --git a/torch/nativert/kernels/PrimKernelRegistry.h b/torch/nativert/kernels/PrimKernelRegistry.h index 89e9c29e7dcb5..f050ff79b86fe 100644 --- a/torch/nativert/kernels/PrimKernelRegistry.h +++ b/torch/nativert/kernels/PrimKernelRegistry.h @@ -21,7 +21,7 @@ TORCH_DECLARE_REGISTRY(PrimKernelRegistry, OpKernel, const Node*); __VA_ARGS__; \ } \ }; \ - C10_REGISTER_TYPED_CLASS(PrimKernelRegistry, name, OpKernel_##id); + C10_REGISTER_TYPED_CLASS(PrimKernelRegistry, name, OpKernel_##id) inline bool checkResizedDataPtr(at::Tensor& t) { auto const prev_data_ptr = t.data_ptr(); From 1a6b21c59f08e3c7ae2e22a866828e2fff21db68 Mon Sep 17 00:00:00 2001 From: Xu Han Date: Tue, 22 Jul 2025 01:54:44 +0000 Subject: [PATCH 372/457] [AOTI] fix load_pt2 split wrong model name on Windows (#158711) fix load_pt2 split wrong model name on Windows. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158711 Approved by: https://github.com/jansel --- torch/export/pt2_archive/_package.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/torch/export/pt2_archive/_package.py b/torch/export/pt2_archive/_package.py index f14087250d526..83cae4836d9a5 100644 --- a/torch/export/pt2_archive/_package.py +++ b/torch/export/pt2_archive/_package.py @@ -562,6 +562,8 @@ def load_pt2( A ``PT2ArchiveContents`` object which contains all the objects in the PT2. """ + from torch._inductor.cpp_builder import normalize_path_separator + if not ( (isinstance(f, (io.IOBase, IO)) and f.readable() and f.seekable()) or (isinstance(f, (str, os.PathLike)) and os.fspath(f).endswith(".pt2")) @@ -600,6 +602,9 @@ def load_pt2( file_end = file[ len(AOTINDUCTOR_DIR) : ] # remove data/aotinductor/ prefix + file_end = normalize_path_separator( + file_end + ) # Win32 need normalize path before split. model_name = file_end.split("/")[ 0 ] # split "model_name/...cpp" into "model_name" From eac777c4f46b381106f2f2b78fe05b506f8c558c Mon Sep 17 00:00:00 2001 From: PaulZhang12 Date: Mon, 21 Jul 2025 11:26:01 -0700 Subject: [PATCH 373/457] [Inductor] Expose decomposeK knobs as envvars (#158745) Fix up decomposeK autotuning, by removing condition to return more than `k_splits_limit` and setting default to 10 instead of 5. Allow `k_splits_limit` to be configurable to the user via `TORCHINDUCTOR_NUM_DECOMPOSE_K_SPLITS` and also allow user to configure threshold in which to use decompose_k via `TORCHINDUCTOR_DECOMPOSE_K_THRESHOLD` Pull Request resolved: https://github.com/pytorch/pytorch/pull/158745 Approved by: https://github.com/eellison --- test/inductor/test_max_autotune.py | 55 +++++++++++++++++++++++------- torch/_inductor/config.py | 14 ++++++-- torch/_inductor/utils.py | 33 ++++++++---------- 3 files changed, 67 insertions(+), 35 deletions(-) diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py index 6245b89f4eca3..a04017459fc69 100644 --- a/test/inductor/test_max_autotune.py +++ b/test/inductor/test_max_autotune.py @@ -50,7 +50,12 @@ aten = torch.ops.aten from torch._inductor.mock_cache import global_stats, PatchCaches, Stats from torch._inductor.test_case import run_tests, TestCase -from torch._inductor.utils import fresh_cache, run_and_get_code +from torch._inductor.utils import ( + fresh_cache, + get_k_splits, + run_and_get_code, + use_decompose_k_choice, +) from torch._inductor.virtualized import V from torch.fx.experimental.proxy_tensor import make_fx from torch.testing import FileCheck @@ -1498,6 +1503,7 @@ def misses(): self.assertEqual(hits(), 4) self.assertEqual(misses(), 4) + @fresh_cache() @skipIfXpu @unittest.skipIf( config.cpp_wrapper, "decompose_k not supported for cpp_wrapper yet" @@ -1506,19 +1512,42 @@ def misses(): max_autotune=True, max_autotune_gemm_backends="TRITON", autotune_fallback_to_aten=False, - disable_decompose_k=True, ) - def test_max_autotune_disable_decompose_K(self): - M, N, K = (32, 32, 32768) - - a = torch.randn(M, K, dtype=torch.float16, device="cuda", requires_grad=True) - b = torch.randn(K, N, dtype=torch.float16, device="cuda", requires_grad=True) - - compiled_func = torch.compile(lambda a, b: a @ b) - out, code = run_and_get_code(compiled_func, a, b) - - for codegen in code: - FileCheck().check_not("decompose_k").run(codegen) + @parametrize("num_decompose_k_splits", (0, 5, 20)) + @parametrize("decompose_k_threshold", (8, 16)) + def test_max_autotune_decompose_k_envvars( + self, num_decompose_k_splits, decompose_k_threshold + ): + shapes = [(32, 32, 32768), (32, 32, 256)] + for M, N, K in shapes: + get_k_splits.cache_clear() + use_decompose_k_choice.cache_clear() + a = torch.randn(M, K, dtype=torch.float16, device="cuda") + b = torch.randn(K, N, dtype=torch.float16, device="cuda") + + with config.patch( + { + "triton.num_decompose_k_splits": num_decompose_k_splits, + "triton.decompose_k_threshold": decompose_k_threshold, + } + ): + compiled_func = torch.compile(lambda a, b: a @ b) + _, code = run_and_get_code(compiled_func, a, b) + + decompose_count = 0 + for codegen in code: + if "benchmark_decompose_k_mm" in codegen: + decompose_count += 1 + + if ( + K // M < decompose_k_threshold + or K // N < decompose_k_threshold + or num_decompose_k_splits == 0 + ): + self.assertEqual(decompose_count, 0) + else: + self.assertTrue(decompose_count > 0) + self.assertTrue(decompose_count <= num_decompose_k_splits) @skipIfXpu @unittest.skipIf( diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 2ce07c6293233..e22f5459be3f2 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -425,9 +425,6 @@ def prologue_fusion_enabled() -> bool: # enable slow autotuning passes to select gemm algorithms max_autotune_gemm = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE_GEMM") == "1" -# disable decomposek autotune choice for gemm -disable_decompose_k = os.environ.get("TORCHINDUCTOR_DISABLE_DECOMPOSE_K") == "1" - # Modifies the number of autotuning choices displayed, set to None for all autotune_num_choices_displayed: Optional[int] = 10 @@ -1345,6 +1342,17 @@ class triton: # Note: it may also need to be used with config.compile_threads = 1 disallow_failing_autotune_kernels_TESTING_ONLY = False + # specify number of splits to autotune on for decompose_k. 0 disables decompose_k + num_decompose_k_splits = int( + os.environ.get("TORCHINDUCTOR_NUM_DECOMPOSE_K_SPLITS", "10") + ) + + # specify minimum ratio of K to M AND N in order to autotune on decompose_k. 0 enables + # it as an autotuning choice for all matmuls + decompose_k_threshold = int( + os.environ.get("TORCHINDUCTOR_DECOMPOSE_K_THRESHOLD", "32") + ) + class aot_inductor: """ diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 9bec3fd764bf3..bbacd23612643 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -1664,20 +1664,15 @@ def _use_cutlass_for_op(op_name: str) -> bool: return op_name.upper() in [x.strip() for x in enabled_ops.split(",")] -decompose_k_threshold = 32 - -# To limit compile time -k_splits_limit = 5 - -# Hand-tuned -default_k_splits = [16, 32, 64, 128, 256] - _IntLike: TypeAlias = Union[int, sympy.Expr] +@functools.cache def use_decompose_k_choice(m: _IntLike, n: _IntLike, k: _IntLike) -> bool: from torch._inductor.virtualized import V + decompose_k_threshold = config.triton.decompose_k_threshold + return ( not torch.version.hip and V.graph.sizevars.statically_known_true( @@ -1688,15 +1683,21 @@ def use_decompose_k_choice(m: _IntLike, n: _IntLike, k: _IntLike) -> bool: ) and not V.graph.aot_mode # TODO: Support AOTI for decomposeK and not V.graph.cpp_wrapper - and not config.disable_decompose_k ) @functools.cache def get_k_splits(m: _IntLike, n: _IntLike, k: _IntLike) -> list[int]: + # To limit compile time + k_splits_limit = config.triton.num_decompose_k_splits + + # Hand-tuned + default_k_splits = [16, 32, 64, 128, 256] # If k is a sympy expression, we can't do any splitting if isinstance(k, sympy.Expr) and not k.is_number: return default_k_splits + elif k_splits_limit == 0: + return [] if (isinstance(m, sympy.Expr) and not m.is_number) or ( isinstance(n, sympy.Expr) and not n.is_number @@ -1736,15 +1737,10 @@ def get_k_splits(m: _IntLike, n: _IntLike, k: _IntLike) -> list[int]: if config.max_autotune_gemm_search_space == "EXHAUSTIVE": return pow_of_2_divisors + mul_of_32_divisors + rest_of_splits - # If the # of power of 2 divisors are greater than k_splits_limit, return all - # This should be ok for compile time, all perfect squares between 128 and min(k / m, k / n) - # should never be a massive amount - if len(pow_of_2_divisors) >= k_splits_limit: - return pow_of_2_divisors - else: - best_splits = pow_of_2_divisors + mul_of_32_divisors + rest_of_splits - # Otherwise, conform results to k_splits_limit - return best_splits[:k_splits_limit] + + best_splits = pow_of_2_divisors + mul_of_32_divisors + rest_of_splits + # Otherwise, conform results to k_splits_limit + return best_splits[:k_splits_limit] @functools.cache @@ -2019,7 +2015,6 @@ def call(self, *args: Any, **kwargs: Any) -> None: self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen() ) # Skip all the actual compiling. - nonlocal save_output_code save_output_code(wrapper_code.value) if kernel_code: save_output_code(kernel_code.value) From aee8a2e98589886ee80767bcbd10c03d13fb19ec Mon Sep 17 00:00:00 2001 From: FFFrog Date: Wed, 16 Jul 2025 23:07:59 +0800 Subject: [PATCH 374/457] Remove duplicated installation for python dependencies. (#158339) As the title stated. The `Common` Section have installed the python dependencies https://github.com/pytorch/pytorch/blob/1b389025ba0cc640e07991314bfba8b6ca385bd2/README.md?plain=1#L247 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158339 Approved by: https://github.com/ezyang --- README.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/README.md b/README.md index 24323032100d1..62e3b9ea49373 100644 --- a/README.md +++ b/README.md @@ -294,14 +294,12 @@ Install PyTorch ```bash export CMAKE_PREFIX_PATH="${CONDA_PREFIX:-'$(dirname $(which conda))/../'}:${CMAKE_PREFIX_PATH}" -python -m pip install -r requirements-build.txt python -m pip install --no-build-isolation -v -e . ``` **On macOS** ```bash -python -m pip install -r requirements-build.txt python -m pip install --no-build-isolation -v -e . ``` From 3639d29ea178c7c0e8be7ac55d4753772f428bc3 Mon Sep 17 00:00:00 2001 From: cyy Date: Tue, 22 Jul 2025 02:49:06 +0000 Subject: [PATCH 375/457] Fix warnings of unused-variable (#158627) Fixes ``` /var/lib/jenkins/workspace/test/cpp/tensorexpr/test_kernel.cpp:42:22: error: unused variable 'verification_pattern' [-Werror,-Wunused-variable] ``` and also extra semicolons. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158627 Approved by: https://github.com/albanD --- test/cpp/tensorexpr/test_kernel.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/cpp/tensorexpr/test_kernel.cpp b/test/cpp/tensorexpr/test_kernel.cpp index f9cd82ff95d04..dc67928b111a0 100644 --- a/test/cpp/tensorexpr/test_kernel.cpp +++ b/test/cpp/tensorexpr/test_kernel.cpp @@ -38,12 +38,12 @@ TEST_F(Kernel, ParallelExternalCallBuf) { %4 : Float(1000, 5000, strides=[5000, 1], device=cpu) = aten::matmul(%3, %2) return (%4))IR"; auto graph = std::make_shared(); - torch::jit::parseIR(graph_string, &*graph); + torch::jit::parseIR(graph_string, graph.get()); +#ifdef TORCH_ENABLE_LLVM const std::string& verification_pattern = R"IR( # CHECK: for (int64_t i = 0ll; i < 5000ll; i++) /* parallel */{)IR"; -#ifdef TORCH_ENABLE_LLVM TensorExprKernel k(graph); StmtPtr s = k.getCodeGenStmt(); std::ostringstream oss; From a155f742adc5f5bf169ff9ac8bf2e98a22ddcacb Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Tue, 22 Jul 2025 03:07:22 +0000 Subject: [PATCH 376/457] [benchmark] allow default mode for compile (#158792) Allow default mode for compile when users cannot run "max-autotune-no-cudagraphs" due to compilation time. Overall, "default" mode is slower than "[max-autotune-no-cudagraphs](https://github.com/pytorch/pytorch/pull/158536)" depending on input shapes. CrossEntropyBackward_bench CrossEntropyForward_bench LayerNormBackward_bench LayerNormForward_bench RMSNormBackward_bench RMSNormForward_bench SoftmaxBackward_bench SoftmaxForward_bench Pull Request resolved: https://github.com/pytorch/pytorch/pull/158792 Approved by: https://github.com/zou3519 --- benchmarks/dynamo/genai_layers/benchmark.py | 25 +++++++--- benchmarks/dynamo/genai_layers/kernels.py | 52 +++++++++++---------- benchmarks/dynamo/genai_layers/utils.py | 6 ++- 3 files changed, 51 insertions(+), 32 deletions(-) diff --git a/benchmarks/dynamo/genai_layers/benchmark.py b/benchmarks/dynamo/genai_layers/benchmark.py index 0378629670524..70349ee444098 100644 --- a/benchmarks/dynamo/genai_layers/benchmark.py +++ b/benchmarks/dynamo/genai_layers/benchmark.py @@ -56,7 +56,11 @@ def list_benchmarks(): print(f"Available benchmarks: {list(BENCHMARK_REGISTRY.keys())}") -def run_benchmark(benchmark_name: str, should_visualize: bool = False): +def run_benchmark( + benchmark_name: str, + should_visualize: bool = False, + compile_mode: str = "max-autotune-no-cudagraphs", +): """Run a specific benchmark.""" if benchmark_name not in BENCHMARK_REGISTRY: print(f"Error: Unknown benchmark '{benchmark_name}'") @@ -64,10 +68,11 @@ def run_benchmark(benchmark_name: str, should_visualize: bool = False): return False print(f"Running benchmark: {benchmark_name}") + print(f"Torch compile mode: {compile_mode}") print("=" * 60) benchmark_class = BENCHMARK_REGISTRY[benchmark_name] - benchmark = benchmark_class() + benchmark = benchmark_class(compile_mode) benchmark.benchmark() if should_visualize: benchmark.visualize() @@ -75,14 +80,15 @@ def run_benchmark(benchmark_name: str, should_visualize: bool = False): return True -def run_all_benchmarks(should_visualize: bool = False): +def run_all_benchmarks(should_visualize: bool = False, compile_mode: str = "default"): """Run all available benchmarks.""" print("Running all benchmarks...") + print(f"Torch compile mode: {compile_mode}") print("=" * 60) for name, cls in BENCHMARK_REGISTRY.items(): print(f"\n{'=' * 20} {name.upper()} {'=' * 20}") - benchmark = cls() + benchmark = cls(compile_mode) benchmark.benchmark() if should_visualize: benchmark.visualize() @@ -124,6 +130,13 @@ def main(): help="Visualize results after running benchmarks", ) + parser.add_argument( + "--compile-mode", + choices=["default", "max-autotune-no-cudagraphs"], + default="max-autotune-no-cudagraphs", + help="Torch compile mode to use (default: default)", + ) + args = parser.parse_args() # Handle list option @@ -133,7 +146,7 @@ def main(): # Handle all option if args.all: - run_all_benchmarks(args.visualize) + run_all_benchmarks(args.visualize, args.compile_mode) return # Handle specific benchmarks @@ -144,7 +157,7 @@ def main(): sys.exit(1) for benchmark_name in args.benchmarks: - run_benchmark(benchmark_name, args.visualize) + run_benchmark(benchmark_name, args.visualize, args.compile_mode) print() # Add spacing between benchmarks diff --git a/benchmarks/dynamo/genai_layers/kernels.py b/benchmarks/dynamo/genai_layers/kernels.py index 30a5f21eaef81..ee79f02761ed8 100644 --- a/benchmarks/dynamo/genai_layers/kernels.py +++ b/benchmarks/dynamo/genai_layers/kernels.py @@ -9,8 +9,8 @@ class CrossEntropyForward(BenchmarkKernel): - def __init__(self): - super().__init__() + def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"): + super().__init__(compile_mode) self.available_backends = ["eager", "compiled", "quack", "liger"] def get_shapes(self) -> tuple[tuple[int, ...], ...]: @@ -52,7 +52,8 @@ def compiled(self, args, kwargs=None) -> Any: # More discussion: https://github.com/pytorch/pytorch/issues/158455 compiled_cross_entropy = torch.compile( lambda x, target: F.cross_entropy(x, target, reduction="none"), - mode="max-autotune-no-cudagraphs", + mode=self.compile_mode, + fullgraph=True, ) return lambda: compiled_cross_entropy(x, target) @@ -105,8 +106,8 @@ def check_accuracy(self, args, kwargs) -> None: class CrossEntropyBackward(BenchmarkKernel): - def __init__(self): - super().__init__() + def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"): + super().__init__(compile_mode) self.available_backends = ["eager", "compiled", "quack", "liger"] def get_shapes(self) -> tuple[tuple[int, ...], ...]: @@ -149,7 +150,8 @@ def compiled(self, args, kwargs=None) -> Any: compiled_cross_entropy = torch.compile( lambda x, target: F.cross_entropy(x, target, reduction="none"), - mode="max-autotune-no-cudagraphs", + mode=self.compile_mode, + fullgraph=True, ) loss = compiled_cross_entropy(x, target) return lambda: torch.autograd.grad( @@ -192,8 +194,8 @@ def benchmark(self): class SoftmaxForward(BenchmarkKernel): - def __init__(self): - super().__init__() + def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"): + super().__init__(compile_mode) self.available_backends = ["eager", "compiled", "quack", "liger"] def get_shapes(self) -> tuple[tuple[int, ...], ...]: @@ -229,7 +231,7 @@ def compiled(self, args, kwargs=None) -> Any: torch._dynamo.mark_dynamic(x, 0) compiled_softmax = torch.compile( - lambda x: F.softmax(x, dim=-1), mode="max-autotune-no-cudagraphs" + lambda x: F.softmax(x, dim=-1), mode=self.compile_mode, fullgraph=True ) return lambda: compiled_softmax(x) @@ -257,8 +259,8 @@ def benchmark(self): class SoftmaxBackward(BenchmarkKernel): - def __init__(self): - super().__init__() + def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"): + super().__init__(compile_mode) self.available_backends = ["eager", "compiled", "quack", "liger"] def get_shapes(self) -> tuple[tuple[int, ...], ...]: @@ -292,7 +294,7 @@ def compiled(self, args, kwargs=None) -> Any: assert kwargs is None x, dy = args compiled_softmax = torch.compile( - lambda x: F.softmax(x, dim=-1), mode="max-autotune-no-cudagraphs" + lambda x: F.softmax(x, dim=-1), mode=self.compile_mode, fullgraph=True ) y = compiled_softmax(x) return lambda: torch.autograd.grad(y, x, grad_outputs=dy, retain_graph=True) @@ -327,8 +329,8 @@ def benchmark(self): class RMSNormForward(BenchmarkKernel): - def __init__(self): - super().__init__() + def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"): + super().__init__(compile_mode) self.available_backends = ["eager", "compiled", "quack", "liger"] def get_shapes(self) -> tuple[tuple[int, ...], ...]: @@ -372,7 +374,7 @@ def compiled(self, args, kwargs=None) -> Any: torch._dynamo.mark_dynamic(x, 0) compiled_rms_norm = torch.compile( - self.rms_norm_ref, mode="max-autotune-no-cudagraphs" + self.rms_norm_ref, mode=self.compile_mode, fullgraph=True ) return lambda: compiled_rms_norm(x, w) @@ -402,8 +404,8 @@ def benchmark(self): class RMSNormBackward(BenchmarkKernel): - def __init__(self): - super().__init__() + def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"): + super().__init__(compile_mode) self.available_backends = ["eager", "compiled", "quack", "liger"] def get_shapes(self) -> tuple[tuple[int, ...], ...]: @@ -445,7 +447,9 @@ def eager(self, args, kwargs=None) -> Any: def compiled(self, args, kwargs=None) -> Any: assert kwargs is None x, w, dy = args - y = torch.compile(self.rms_norm_ref, mode="max-autotune-no-cudagraphs")(x, w) + y = torch.compile(self.rms_norm_ref, mode=self.compile_mode, fullgraph=True)( + x, w + ) return lambda: torch.autograd.grad( y, [x, w], grad_outputs=dy, retain_graph=True ) @@ -485,8 +489,8 @@ def benchmark(self): class LayerNormForward(BenchmarkKernel): - def __init__(self): - super().__init__() + def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"): + super().__init__(compile_mode) self.available_backends = ["eager", "compiled", "quack", "liger"] def get_shapes(self) -> tuple[tuple[int, ...], ...]: @@ -526,7 +530,7 @@ def compiled(self, args, kwargs=None) -> Any: torch._dynamo.mark_dynamic(x, 0) compiled_layernorm = torch.compile( - self.layernorm_ref, mode="max-autotune-no-cudagraphs" + self.layernorm_ref, mode=self.compile_mode, fullgraph=True ) return lambda: compiled_layernorm(x, w, eps=1e-6) @@ -559,8 +563,8 @@ def benchmark(self): class LayerNormBackward(BenchmarkKernel): - def __init__(self): - super().__init__() + def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"): + super().__init__(compile_mode) self.available_backends = ["eager", "compiled", "liger"] def get_shapes(self) -> tuple[tuple[int, ...], ...]: @@ -603,7 +607,7 @@ def compiled(self, args, kwargs=None) -> Any: assert kwargs is None x, w, dy = args compiled_layernorm = torch.compile( - self.layernorm_ref, mode="max-autotune-no-cudagraphs" + self.layernorm_ref, mode=self.compile_mode, fullgraph=True ) y = compiled_layernorm(x, w) return lambda: torch.autograd.grad( diff --git a/benchmarks/dynamo/genai_layers/utils.py b/benchmarks/dynamo/genai_layers/utils.py index 9d3f97c0da749..e11995ee0b5f5 100644 --- a/benchmarks/dynamo/genai_layers/utils.py +++ b/benchmarks/dynamo/genai_layers/utils.py @@ -13,7 +13,8 @@ def benchmark_kernel_in_milliseconds(func: Callable, *args, **kwargs) -> float: # warmup for _ in range(5): func(*args, **kwargs) - return benchmarker.benchmark_gpu(lambda: func(*args, **kwargs)) + with torch.compiler.set_stance("fail_on_recompile"): + return benchmarker.benchmark_gpu(lambda: func(*args, **kwargs)) @dataclass @@ -41,9 +42,10 @@ def __str__(self): class BenchmarkKernel: - def __init__(self): + def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"): self.name = self.__class__.__name__ self.available_backends: list[str] = [] + self.compile_mode: str = compile_mode # mapping from backend to list of performance results self.profiling_results: defaultdict[str, list[Performance]] = defaultdict(list) From 21c97bd565be29ebdea6c690caf2be22f458698f Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Tue, 22 Jul 2025 03:49:13 +0000 Subject: [PATCH 377/457] [reland] Transfer "stack_trace" in post_grad passes (#158752) Summary: We transfer stack trace in post_grad passes. We shouldn't add "stack_trace" to _COPY_META_FIELDS because _COPY_META_FIELDS is used in proxy.py where stack_trace is explicitly set. Since the stack_trace is being used by more and more debugging tools, we should also start testing it more rigorously. This PR start by adding a first test for testing that stack trace is preserved through post_grad_passes. Test Plan: ``` buck run mode/dev-nosan fbcode//caffe2/test/inductor:provenance_tracing -- -r test_pattern_matcher_transfer_meta buck run mode/dev-nosan fbcode//caffe2/test/inductor:auto_functionalize -- --rcaffe2/test/inductor:auto_functionalize_old ``` Rollback Plan: Differential Revision: D78669729 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158752 Approved by: https://github.com/jingsh --- test/higher_order_ops/test_invoke_subgraph.py | 8 ++- test/inductor/test_auto_functionalize.py | 16 +++++- test/inductor/test_provenance_tracing.py | 56 +++++++++++++++++++ torch/_inductor/pattern_matcher.py | 2 + 4 files changed, 78 insertions(+), 4 deletions(-) diff --git a/test/higher_order_ops/test_invoke_subgraph.py b/test/higher_order_ops/test_invoke_subgraph.py index 72daebb5f4f3f..c800eb78f905a 100644 --- a/test/higher_order_ops/test_invoke_subgraph.py +++ b/test/higher_order_ops/test_invoke_subgraph.py @@ -21,6 +21,7 @@ normalize_gm, ) from torch._higher_order_ops.schema import find_hop_schema +from torch._inductor import config as inductor_config from torch._inductor.pattern_matcher import ( CallFunctionVarArgs, PatternMatcherPass, @@ -619,6 +620,7 @@ def fn(x, y): self.assertEqual(ref, res) res.sum().backward() + @inductor_config.patch("fx_graph_cache", False) def test_dropout_checks_joint_graph(self): # `dropout` tests that joint graph passes (not just partitioner) is ran # on the hop graphs. Inductor rng functionalization happens in the joint @@ -675,9 +677,9 @@ def forward(self, primals_0: "f32[8]"): sin: "f32[8]" = torch.ops.aten.sin.default(primals_0) inductor_seeds_default: "i64[1]" = torch.ops.prims.inductor_seeds.default(1, device(type='cpu')) + inductor_lookup_seed_default: "i64[]" = torch.ops.prims.inductor_lookup_seed.default(inductor_seeds_default, 0); inductor_seeds_default = None inductor_random_default: "f32[8]" = torch.ops.prims.inductor_random.default([8], inductor_lookup_seed_default, 'rand'); inductor_lookup_seed_default = None - gt: "b8[8]" = torch.ops.aten.gt.Scalar(inductor_random_default, 0.5); inductor_random_default = None mul: "f32[8]" = torch.ops.aten.mul.Tensor(gt, sin); sin = None mul_1: "f32[8]" = torch.ops.aten.mul.Tensor(mul, 2.0); mul = None @@ -690,6 +692,7 @@ def forward(self, primals_0: "f32[8]"): """, ) + @inductor_config.patch("fx_graph_cache", False) def test_dropout_checks_joint_graph_inference(self): # Checks that joint graph results in inductor seeds for just the inference graph @nested_compile_region @@ -719,9 +722,9 @@ def forward(self, arg0_1: "f32[8]"): class repeated_subgraph0(torch.nn.Module): def forward(self, arg0_1: "f32[8]"): inductor_seeds_default: "i64[1]" = torch.ops.prims.inductor_seeds.default(1, device(type='cpu')) + inductor_lookup_seed_default: "i64[]" = torch.ops.prims.inductor_lookup_seed.default(inductor_seeds_default, 0); inductor_seeds_default = None inductor_random_default: "f32[8]" = torch.ops.prims.inductor_random.default([8], inductor_lookup_seed_default, 'rand'); inductor_lookup_seed_default = None - gt: "b8[8]" = torch.ops.aten.gt.Scalar(inductor_random_default, 0.5); inductor_random_default = None sin: "f32[8]" = torch.ops.aten.sin.default(arg0_1); arg0_1 = None mul: "f32[8]" = torch.ops.aten.mul.Tensor(gt, sin); gt = sin = None @@ -917,6 +920,7 @@ def forward(self, a: "f32[8]", l_y_: "f32[8]"): """, ) + @inductor_config.patch("fx_graph_cache", False) def test_view_to_reshape(self): @nested_compile_region def gn(x): diff --git a/test/inductor/test_auto_functionalize.py b/test/inductor/test_auto_functionalize.py index c91dde52780ac..65df4912a41cb 100644 --- a/test/inductor/test_auto_functionalize.py +++ b/test/inductor/test_auto_functionalize.py @@ -185,9 +185,15 @@ def f(x, y, z, n): post_grad_graphs, """\ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"): - # No stacktrace found for following nodes + # Custom comment for test foo_default = torch.ops.mylib.foo.default(arg2_1, [arg3_1, arg4_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg4_1 = arg1_1 = arg0_1 = foo_default = None return ()""", # noqa: B950 + ignore_comments=True, + ) + + # stack trace should be in post_grad_graph + self.assertTrue( + "code: torch.ops.mylib.foo(x, y, z, 2, n)" in post_grad_graphs, ) eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) @@ -328,10 +334,16 @@ def f(x, y, z, n): post_grad_graphs, """\ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu"): - # No stacktrace found for following nodes + # Custom comment for test foo_default = torch.ops.mylib.foo.default(None, [arg2_1, arg3_1], arg1_1, 2, arg0_1); \ arg2_1 = arg3_1 = arg1_1 = arg0_1 = foo_default = None return ()""", + ignore_comments=True, + ) + + # stack trace should be in post_grad_graph + self.assertTrue( + "code: torch.ops.mylib.foo(x, y, z, 2, n)" in post_grad_graphs, ) eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) diff --git a/test/inductor/test_provenance_tracing.py b/test/inductor/test_provenance_tracing.py index 5dee7a4114049..2dd9ca44eb687 100644 --- a/test/inductor/test_provenance_tracing.py +++ b/test/inductor/test_provenance_tracing.py @@ -9,12 +9,15 @@ from pathlib import Path import torch +from torch._dynamo.utils import detect_fake_mode from torch._inductor import config from torch._inductor.debug import ( create_mapping_pre_post_grad_nodes, create_node_mapping_kernel_to_post_grad, ) +from torch._inductor.fx_passes.post_grad import post_grad_passes from torch._inductor.test_case import run_tests, TestCase +from torch._inductor.virtualized import V from torch.testing._internal.inductor_utils import HAS_GPU from torch.testing._internal.triton_utils import requires_cuda @@ -427,5 +430,58 @@ def test_create_node_mapping(self): ) +class TestProvenanceTracingNodeMeta(TestCase): + def get_node_with_target(self, gm, target): + """ + Return first node in gm with target + """ + return next(iter([node for node in gm.graph.nodes if node.target == target])) + + @requires_cuda # test only works for cuda pattern matcher + def test_pattern_matcher_transfer_meta(self): + """ + Test that stack trace is transfered when node is decomposed in post_grad_passes + """ + + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc1 = torch.nn.Linear(10, 16) + self.relu = torch.nn.ReLU() + self.sigmoid = torch.nn.Sigmoid() + + def forward(self, x): + x = self.fc1(x) + x = self.relu(x) + x = self.sigmoid(x) + return x * 3 + + x = torch.randn(8, 10).to("cuda") + example_inputs = (x,) + model = Model().to("cuda") + + # mimic the before_post_grad graph + ep = torch.export.export(model, example_inputs).run_decompositions() + gm = ep.module() + + # Set fake mode for V + fake_inputs = [ + node.meta.get("val") for node in gm.graph.nodes if node.op == "placeholder" + ] + fake_mode = detect_fake_mode(fake_inputs) + V.set_fake_mode(fake_mode) + + addmm_node = self.get_node_with_target(gm, torch.ops.aten.addmm.default) + stack_trace = addmm_node.meta["stack_trace"] + + post_grad_passes(gm, True) # for this test is_inference doesn't matter + + mm_node = self.get_node_with_target(gm, torch.ops.aten.mm.default) + add_node = self.get_node_with_target(gm, torch.ops.aten.add.Tensor) + + self.assertEqual(add_node.meta["stack_trace"], stack_trace) + self.assertEqual(mm_node.meta["stack_trace"], stack_trace) + + if __name__ == "__main__": run_tests() diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py index 78aa947ea7f6a..772ddcced96f6 100644 --- a/torch/_inductor/pattern_matcher.py +++ b/torch/_inductor/pattern_matcher.py @@ -143,6 +143,8 @@ def _transfer_meta( for k, v in old_node.meta.items() if k in torch.fx.proxy._COPY_META_FIELDS ) + if "stack_trace" in old_node.meta: + new_meta["stack_trace"] = old_node.meta["stack_trace"] class Match: From d984143a74e5e726e2be35f6531582aab45bcf4c Mon Sep 17 00:00:00 2001 From: henrylhtsang Date: Mon, 21 Jul 2025 14:38:22 -0700 Subject: [PATCH 378/457] [ci][cutlass backend] Add ci for cutlass backend tests (#156626) redo of https://github.com/pytorch/pytorch/pull/156136 Differential Revision: [D77327309](https://our.internmc.facebook.com/intern/diff/D77327309) I want to try land the full version first. If the ci is taking too long, we can revert back to only testing for a few names. ``` -k 'test_max_autotune_cutlass_backend_regular_mm and not test_max_autotune_cutlass_backend_regular_mm_streamk' ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/156626 Approved by: https://github.com/huydhn, https://github.com/mlazos --- .ci/pytorch/test.sh | 8 +++ .github/pytorch-probot.yml | 1 + .github/workflows/h100-cutlass-backend.yml | 58 +++++++++++++++++++ torch/_inductor/codegen/cuda/cutlass_utils.py | 2 +- torch/_inductor/config.py | 10 ++-- torch/_inductor/utils.py | 5 +- 6 files changed, 76 insertions(+), 8 deletions(-) create mode 100644 .github/workflows/h100-cutlass-backend.yml diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index ad6a48b2528e4..b16557061d11f 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -345,6 +345,12 @@ test_h100_symm_mem() { assert_git_not_dirty } +test_h100_cutlass_backend() { + # cutlass backend tests for H100 + TORCHINDUCTOR_CUTLASS_DIR=$(realpath "./third_party/cutlass") python test/run_test.py --include inductor/test_cutlass_backend -k "not addmm" $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running + TORCHINDUCTOR_CUTLASS_DIR=$(realpath "./third_party/cutlass") python test/run_test.py --include inductor/test_cutlass_evt $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running +} + test_lazy_tensor_meta_reference_disabled() { export TORCH_DISABLE_FUNCTIONALIZATION_META_REFERENCE=1 echo "Testing lazy tensor operations without meta reference" @@ -1769,6 +1775,8 @@ elif [[ "${TEST_CONFIG}" == h100_distributed ]]; then test_h100_distributed elif [[ "${TEST_CONFIG}" == "h100-symm-mem" ]]; then test_h100_symm_mem +elif [[ "${TEST_CONFIG}" == h100_cutlass_backend ]]; then + test_h100_cutlass_backend else install_torchvision install_monkeytype diff --git a/.github/pytorch-probot.yml b/.github/pytorch-probot.yml index ac8cb3df0ffcd..5288aca852931 100644 --- a/.github/pytorch-probot.yml +++ b/.github/pytorch-probot.yml @@ -32,6 +32,7 @@ ciflow_push_tags: - ciflow/h100 - ciflow/h100-distributed - ciflow/h100-symm-mem +- ciflow/h100-cutlass-backend retryable_workflows: - pull - trunk diff --git a/.github/workflows/h100-cutlass-backend.yml b/.github/workflows/h100-cutlass-backend.yml new file mode 100644 index 0000000000000..82dc2ae2a3944 --- /dev/null +++ b/.github/workflows/h100-cutlass-backend.yml @@ -0,0 +1,58 @@ +name: Limited CI for CUTLASS backend on H100 + +on: + pull_request: + paths: + - .github/workflows/h100-cutlass-backend.yml + workflow_dispatch: + schedule: + - cron: 22 9 * * * # every 24 hours about 2:22am PDT + push: + tags: + - ciflow/h100-cutlass-backend/* + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} + cancel-in-progress: true + +permissions: + id-token: write + contents: read + +jobs: + + get-label-type: + if: github.repository_owner == 'pytorch' + name: get-label-type + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + + linux-jammy-cuda12_8-py3_10-gcc11-sm90-build-cutlass-backend: + name: linux-jammy-cuda12.8-py3.10-gcc11-sm90-cutlass-backend + uses: ./.github/workflows/_linux-build.yml + needs: get-label-type + with: + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90-cutlass-backend + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 + cuda-arch-list: '9.0' + test-matrix: | + { include: [ + { config: "h100_cutlass_backend", shard: 1, num_shards: 1, runner: "linux.aws.h100", owners: ["oncall:pt2"] }, + ]} + secrets: inherit + + linux-jammy-cuda12_8-py3_10-gcc11-sm90-test: + name: linux-jammy-cuda12.8-py3.10-gcc11-sm90-cutlass-backend + uses: ./.github/workflows/_linux-test.yml + needs: + - linux-jammy-cuda12_8-py3_10-gcc11-sm90-build-cutlass-backend + with: + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90-cutlass-backend + docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm90-build-cutlass-backend.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm90-build-cutlass-backend.outputs.test-matrix }} + secrets: inherit diff --git a/torch/_inductor/codegen/cuda/cutlass_utils.py b/torch/_inductor/codegen/cuda/cutlass_utils.py index eb479e477ea20..7ca33ea779cc7 100644 --- a/torch/_inductor/codegen/cuda/cutlass_utils.py +++ b/torch/_inductor/codegen/cuda/cutlass_utils.py @@ -128,7 +128,7 @@ def path_join(path0, path1): if tmp_cutlass_full_path not in sys.path: def link_and_append(dst_link, src_path, parent_dir): - if os.path.exists(dst_link): + if os.path.lexists(dst_link): assert os.path.islink(dst_link), ( f"{dst_link} is not a symlink. Try to remove {dst_link} manually and try again." ) diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index e22f5459be3f2..0fb3237dac32b 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -1517,11 +1517,11 @@ class cuda: # Path to the CUTLASS repo root directory. # The default path only works under PyTorch local development environment. - cutlass_dir = os.environ.get( - "TORCHINDUCTOR_CUTLASS_DIR", - os.path.abspath( - os.path.join(os.path.dirname(torch.__file__), "../third_party/cutlass/") - ), + cutlass_dir = os.path.realpath( + os.environ.get( + "TORCHINDUCTOR_CUTLASS_DIR", + os.path.join(os.path.dirname(torch.__file__), "../third_party/cutlass/"), + ) ) # Configures the maximum number of CUTLASS configs to profile in max_autotune. diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index bbacd23612643..22c533a5a03c4 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -1649,8 +1649,9 @@ def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool: if not try_import_cutlass(): log.warning( "Failed to import CUTLASS lib. Please check whether " - "_inductor.config.cuda.cutlass_dir is set correctly. " - "Skipping CUTLASS backend for now." + "_inductor.config.cuda.cutlass_dir %s is set correctly. " + "Skipping CUTLASS backend for now.", + config.cuda.cutlass_dir, ) return False return res From 3a67bf9c620e8958a1677e68779be08eb34dafa3 Mon Sep 17 00:00:00 2001 From: "Junjie Wang (PyTorch)" Date: Tue, 22 Jul 2025 06:04:56 +0000 Subject: [PATCH 379/457] [PGNCCLx] Bring split and merge for PGNCCLx (#158790) Summary: We added group split in D78300794 and remote_group_merge in D78450094. We first want to upstream this change to PGNCCLx as well so that NCCLx can use this new API and we can continue our c10d clean up in https://github.com/pytorch/pytorch/pull/158488. Test Plan: CI ``` buck test -c hpc_comms.use_ncclx=stable comms/ncclx/pg/tests:test_c10d_ncclx -- test_group_split_and_merge ``` Rollback Plan: Differential Revision: D78521060 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158790 Approved by: https://github.com/d4l3k --- torch/csrc/distributed/c10d/Backend.hpp | 4 ++-- torch/csrc/distributed/c10d/ProcessGroupGloo.cpp | 4 ++-- torch/csrc/distributed/c10d/ProcessGroupGloo.hpp | 4 ++-- torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp | 4 ++-- torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp | 4 ++-- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/torch/csrc/distributed/c10d/Backend.hpp b/torch/csrc/distributed/c10d/Backend.hpp index 29a6fddd87907..76ffdd38d264e 100644 --- a/torch/csrc/distributed/c10d/Backend.hpp +++ b/torch/csrc/distributed/c10d/Backend.hpp @@ -392,7 +392,7 @@ class TORCH_API Backend : public torch::CustomClassHolder { virtual c10::intrusive_ptr split( const std::vector& ranks, - const c10::intrusive_ptr opts) { + const c10::intrusive_ptr& opts) { TORCH_CHECK( false, "Backend ", @@ -402,7 +402,7 @@ class TORCH_API Backend : public torch::CustomClassHolder { virtual c10::intrusive_ptr merge( const c10::intrusive_ptr& store, - const c10::intrusive_ptr opts, + const c10::intrusive_ptr& opts, const int& rank, const int& size) { TORCH_CHECK( diff --git a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp index 045e46f9129c9..895915dcc8403 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp @@ -699,7 +699,7 @@ const std::vector& ProcessGroupGloo::groupRanks() const { c10::intrusive_ptr ProcessGroupGloo::split( const std::vector& ranks, - const c10::intrusive_ptr opts) { + const c10::intrusive_ptr& opts) { auto it = std::find(ranks.begin(), ranks.end(), rank_); int groupRank; if (it == ranks.end()) { @@ -728,7 +728,7 @@ c10::intrusive_ptr ProcessGroupGloo::split( c10::intrusive_ptr ProcessGroupGloo::merge( const c10::intrusive_ptr& store, - const c10::intrusive_ptr opts, + const c10::intrusive_ptr& opts, const int& rank, const int& size) { auto glooOpts = c10::dynamic_intrusive_pointer_cast(opts); diff --git a/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp b/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp index 655679489adb5..fd3fd779229d2 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp @@ -310,11 +310,11 @@ class TORCH_API ProcessGroupGloo : public Backend { c10::intrusive_ptr split( const std::vector& ranks, - const c10::intrusive_ptr opts) override; + const c10::intrusive_ptr& opts) override; c10::intrusive_ptr merge( const c10::intrusive_ptr& store, - const c10::intrusive_ptr opts, + const c10::intrusive_ptr& opts, const int& rank, const int& size) override; diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index c0c98326690be..ba335dff8c5fd 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -1256,7 +1256,7 @@ void ProcessGroupNCCL::enableCollectivesTiming() { c10::intrusive_ptr ProcessGroupNCCL::split( const std::vector& ranks, - const c10::intrusive_ptr opts) { + const c10::intrusive_ptr& opts) { auto deviceIdx = guessDeviceId(); TORCH_CHECK( deviceIdx >= 0, @@ -1295,7 +1295,7 @@ c10::intrusive_ptr ProcessGroupNCCL::split( c10::intrusive_ptr ProcessGroupNCCL::merge( const c10::intrusive_ptr& store, - const c10::intrusive_ptr opts, + const c10::intrusive_ptr& opts, const int& rank, const int& size) { auto ncclOpts = c10::dynamic_intrusive_pointer_cast(opts); diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index 9d72207a4b79a..dd35afc155f35 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -961,11 +961,11 @@ class TORCH_API ProcessGroupNCCL : public Backend { c10::intrusive_ptr split( const std::vector& ranks, - const c10::intrusive_ptr opts) override; + const c10::intrusive_ptr& opts) override; c10::intrusive_ptr merge( const c10::intrusive_ptr& store, - const c10::intrusive_ptr opts, + const c10::intrusive_ptr& opts, const int& rank, const int& size) override; From 392fa75411a1f293e891395f005615b257c03eda Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Tue, 22 Jul 2025 06:10:35 +0000 Subject: [PATCH 380/457] Change from import trace to import config (#158796) Summary: for this particular instance, we're doing from torch._inductor.config import trace ...trace.provenance_tracking... but for all other call sites, we're doing from torch._inductor import config ... config.trace.provenance_tracking.... Test Plan: CI Rollback Plan: Differential Revision: D78699876 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158796 Approved by: https://github.com/c00w --- torch/fx/passes/graph_transform_observer.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/torch/fx/passes/graph_transform_observer.py b/torch/fx/passes/graph_transform_observer.py index 29afd600b5fd9..17929bb63787e 100644 --- a/torch/fx/passes/graph_transform_observer.py +++ b/torch/fx/passes/graph_transform_observer.py @@ -31,18 +31,20 @@ def __init__( """ log_url is inferred to be torch._inductor.config.trace.log_url_for_graph_xform unless otherwise specified """ - from torch._inductor.config import trace + from torch._inductor import config as inductor_config self.gm = gm self.passname = passname self.subsystem = subsystem if log_url is None: - log_url = trace.log_url_for_graph_xform + log_url = inductor_config.trace.log_url_for_graph_xform self.log_url = log_url - self.active = trace.provenance_tracking or self.log_url is not None + self.active = ( + self.log_url is not None or inductor_config.trace.provenance_tracking + ) if self.active: self.erased_nodes: set[str] = set() From 91b69deeb0f67a4b4ad7b36cdbb7d5f805b375a0 Mon Sep 17 00:00:00 2001 From: "Nichols A. Romero" Date: Tue, 22 Jul 2025 08:04:55 +0000 Subject: [PATCH 381/457] [ROCm][CI] update fbgemm_gpu hash used by inductor tests (#158602) fbgemm_gpu build started failing with asmjit errors. Moving to latest tip of fbgemm for inductor tests resolves the build failures. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158602 Approved by: https://github.com/jeffdaily Co-authored-by: Jeff Daily --- .ci/pytorch/common_utils.sh | 27 +++++++++++++++++++++++++- .github/ci_commit_pins/fbgemm_rocm.txt | 2 +- 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/.ci/pytorch/common_utils.sh b/.ci/pytorch/common_utils.sh index 9075fe5fb56f8..e9c7741947cf1 100644 --- a/.ci/pytorch/common_utils.sh +++ b/.ci/pytorch/common_utils.sh @@ -204,8 +204,32 @@ function install_torchrec_and_fbgemm() { pip_build_and_install "git+https://github.com/pytorch/torchrec.git@${torchrec_commit}" dist/torchrec pip_uninstall fbgemm-gpu-nightly + # Set ROCM_HOME isn't available, use ROCM_PATH if set or /opt/rocm + ROCM_HOME="${ROCM_HOME:-${ROCM_PATH:-/opt/rocm}}" + + # Find rocm_version.h header file for ROCm version extract + rocm_version_h="${ROCM_HOME}/include/rocm-core/rocm_version.h" + if [ ! -f "$rocm_version_h" ]; then + rocm_version_h="${ROCM_HOME}/include/rocm_version.h" + fi + + # Error out if rocm_version.h not found + if [ ! -f "$rocm_version_h" ]; then + echo "Error: rocm_version.h not found in expected locations." >&2 + exit 1 + fi + + # Extract major, minor and patch ROCm version numbers + MAJOR_VERSION=$(grep 'ROCM_VERSION_MAJOR' "$rocm_version_h" | awk '{print $3}') + MINOR_VERSION=$(grep 'ROCM_VERSION_MINOR' "$rocm_version_h" | awk '{print $3}') + PATCH_VERSION=$(grep 'ROCM_VERSION_PATCH' "$rocm_version_h" | awk '{print $3}') + ROCM_INT=$((MAJOR_VERSION * 10000 + MINOR_VERSION * 100 + PATCH_VERSION)) + echo "ROCm version: $ROCM_INT" + export BUILD_ROCM_VERSION="$MAJOR_VERSION.$MINOR_VERSION" + pip_install tabulate # needed for newer fbgemm pip_install patchelf # needed for rocm fbgemm + pushd /tmp local wheel_dir=dist/fbgemm_gpu local found_whl=0 @@ -223,7 +247,7 @@ function install_torchrec_and_fbgemm() { pushd fbgemm/fbgemm_gpu git checkout "${fbgemm_commit}" python setup.py bdist_wheel \ - --package_variant=rocm \ + --build-variant=rocm \ -DHIP_ROOT_DIR="${ROCM_PATH}" \ -DCMAKE_C_FLAGS="-DTORCH_USE_HIP_DSA" \ -DCMAKE_CXX_FLAGS="-DTORCH_USE_HIP_DSA" @@ -240,6 +264,7 @@ function install_torchrec_and_fbgemm() { done rm -rf fbgemm + popd else pip_build_and_install "git+https://github.com/pytorch/torchrec.git@${torchrec_commit}" dist/torchrec pip_build_and_install "git+https://github.com/pytorch/FBGEMM.git@${fbgemm_commit}#subdirectory=fbgemm_gpu" dist/fbgemm_gpu diff --git a/.github/ci_commit_pins/fbgemm_rocm.txt b/.github/ci_commit_pins/fbgemm_rocm.txt index fa11e10ca6b8e..db140a31f3fa4 100644 --- a/.github/ci_commit_pins/fbgemm_rocm.txt +++ b/.github/ci_commit_pins/fbgemm_rocm.txt @@ -1 +1 @@ -5fb5024118e9bb9decf96c2b0b1a8f0010bf56be +7f1de94a4c2d14f59ad4ca84538c36084ea6b2c8 From 0142d5f4e26b1644de58bb8741e4baa04803d67e Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 22 Jul 2025 08:33:08 +0000 Subject: [PATCH 382/457] Revert "Remove is_arvr_mode() from xnnpack.buck.bzl (#158682)" This reverts commit f09a484b8164aaadd57a79354f0ccf47733f365e. Reverted https://github.com/pytorch/pytorch/pull/158682 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/158682#issuecomment-3101648365)) --- third_party/xnnpack.buck.bzl | 779 +++++++++++++---------------------- 1 file changed, 294 insertions(+), 485 deletions(-) diff --git a/third_party/xnnpack.buck.bzl b/third_party/xnnpack.buck.bzl index 0f50efc032591..231384bd859ab 100644 --- a/third_party/xnnpack.buck.bzl +++ b/third_party/xnnpack.buck.bzl @@ -1,4 +1,5 @@ load("//tools/build_defs:fb_xplat_cxx_library.bzl", "fb_xplat_cxx_library") +load("//tools/build_defs:fbsource_utils.bzl", "is_arvr_mode") load("//tools/build_defs:glob_defs.bzl", "subdir_glob") load("//tools/build_defs:platform_defs.bzl", "ANDROID", "APPLE", "APPLETVOS", "CXX", "IOS", "MACOSX", "WINDOWS") load( @@ -141,12 +142,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_sse", srcs = select({ "DEFAULT": [], - "ovr_config//build_mode:arvr_mode": select({ - "DEFAULT": [], - "ovr_config//cpu:x86_32": prod_srcs_for_arch_wrapper("sse"), - "ovr_config//cpu:x86_64": prod_srcs_for_arch_wrapper("sse"), - }), - }), + "ovr_config//cpu:x86_32": prod_srcs_for_arch_wrapper("sse"), + "ovr_config//cpu:x86_64": prod_srcs_for_arch_wrapper("sse"), + }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -162,15 +160,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = select({ - "DEFAULT": [ - ( - "x86|x86_64|platform009|platform010", - prod_srcs_for_arch_wrapper("sse"), - ), - ], - "ovr_config//build_mode:arvr_mode": [], - }), + platform_srcs = ([ + ( + "x86|x86_64|platform009|platform010", + prod_srcs_for_arch_wrapper("sse"), + ), + ] if not is_arvr_mode() else []), fbandroid_link_whole = True, preferred_linkage = "static", preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, @@ -211,12 +206,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_sse2", srcs = select({ "DEFAULT": [], - "ovr_config//build_mode:arvr_mode": select({ - "DEFAULT": [], - "ovr_config//cpu:x86_32": prod_srcs_for_arch_wrapper("sse2"), - "ovr_config//cpu:x86_64": prod_srcs_for_arch_wrapper("sse2"), - }), - }), + "ovr_config//cpu:x86_32": prod_srcs_for_arch_wrapper("sse2"), + "ovr_config//cpu:x86_64": prod_srcs_for_arch_wrapper("sse2"), + }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -232,15 +224,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = select({ - "DEFAULT": [ - ( - "x86|x86_64|platform009|platform010", - prod_srcs_for_arch_wrapper("sse2"), - ), - ], - "ovr_config//build_mode:arvr_mode": [], - }), + platform_srcs = ([ + ( + "x86|x86_64|platform009|platform010", + prod_srcs_for_arch_wrapper("sse2"), + ), + ] if not is_arvr_mode() else []), fbandroid_link_whole = True, preferred_linkage = "static", preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, @@ -281,12 +270,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_ssse3", srcs = select({ "DEFAULT": [], - "ovr_config//build_mode:arvr_mode": select({ - "DEFAULT": [], - "ovr_config//cpu:x86_32": prod_srcs_for_arch_wrapper("ssse3"), - "ovr_config//cpu:x86_64": prod_srcs_for_arch_wrapper("ssse3"), - }), - }), + "ovr_config//cpu:x86_32": prod_srcs_for_arch_wrapper("ssse3"), + "ovr_config//cpu:x86_64": prod_srcs_for_arch_wrapper("ssse3"), + }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -302,15 +288,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = select({ - "DEFAULT": [ - ( - "x86|x86_64|platform009|platform010", - prod_srcs_for_arch_wrapper("ssse3"), - ), - ], - "ovr_config//build_mode:arvr_mode": [], - }), + platform_srcs = ([ + ( + "x86|x86_64|platform009|platform010", + prod_srcs_for_arch_wrapper("ssse3"), + ), + ] if not is_arvr_mode() else []), fbandroid_link_whole = True, preferred_linkage = "static", preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, @@ -351,12 +334,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_sse41", srcs = select({ "DEFAULT": [], - "ovr_config//build_mode:arvr_mode": select({ - "DEFAULT": [], - "ovr_config//cpu:x86_32": prod_srcs_for_arch_wrapper("sse41"), - "ovr_config//cpu:x86_64": prod_srcs_for_arch_wrapper("sse41"), - }), - }), + "ovr_config//cpu:x86_32": prod_srcs_for_arch_wrapper("sse41"), + "ovr_config//cpu:x86_64": prod_srcs_for_arch_wrapper("sse41"), + }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -372,15 +352,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = select({ - "DEFAULT": [ - ( - "x86|x86_64|platform009|platform010", - prod_srcs_for_arch_wrapper("sse41"), - ), - ], - "ovr_config//build_mode:arvr_mode": [], - }), + platform_srcs = ([ + ( + "x86|x86_64|platform009|platform010", + prod_srcs_for_arch_wrapper("sse41"), + ), + ] if not is_arvr_mode() else []), fbandroid_link_whole = True, preferred_linkage = "static", preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, @@ -421,12 +398,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_avx", srcs = select({ "DEFAULT": [], - "ovr_config//build_mode:arvr_mode": select({ - "DEFAULT": [], - "ovr_config//cpu:x86_32": prod_srcs_for_arch_wrapper("avx"), - "ovr_config//cpu:x86_64": prod_srcs_for_arch_wrapper("avx"), - }), - }), + "ovr_config//cpu:x86_32": prod_srcs_for_arch_wrapper("avx"), + "ovr_config//cpu:x86_64": prod_srcs_for_arch_wrapper("avx"), + }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -450,15 +424,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = select({ - "DEFAULT": [ - ( - "x86|x86_64|platform009|platform010", - prod_srcs_for_arch_wrapper("avx"), - ), - ], - "ovr_config//build_mode:arvr_mode": [], - }), + platform_srcs = ([ + ( + "x86|x86_64|platform009|platform010", + prod_srcs_for_arch_wrapper("avx"), + ), + ] if not is_arvr_mode() else []), fbandroid_link_whole = True, preferred_linkage = "static", preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, @@ -500,12 +471,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_avx512vnnigfni", srcs = select({ "DEFAULT": [], - "ovr_config//build_mode:arvr_mode": select({ - "DEFAULT": [], - "ovr_config//cpu:x86_32": prod_srcs_for_arch_wrapper("avx512vnnigfni"), - "ovr_config//cpu:x86_64": prod_srcs_for_arch_wrapper("avx512vnnigfni"), - }), - }), + "ovr_config//cpu:x86_32": prod_srcs_for_arch_wrapper("avx512vnnigfni"), + "ovr_config//cpu:x86_64": prod_srcs_for_arch_wrapper("avx512vnnigfni"), + }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -545,15 +513,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = select({ - "DEFAULT": [ - ( - "x86|x86_64|platform009|platform010", - prod_srcs_for_arch_wrapper("avx512vnnigfni"), - ), - ], - "ovr_config//build_mode:arvr_mode": [], - }), + platform_srcs = ([ + ( + "x86|x86_64|platform009|platform010", + prod_srcs_for_arch_wrapper("avx512vnnigfni"), + ), + ] if not is_arvr_mode() else []), preferred_linkage = "static", preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], @@ -598,12 +563,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_avx512vnni", srcs = select({ "DEFAULT": [], - "ovr_config//build_mode:arvr_mode": select({ - "DEFAULT": [], - "ovr_config//cpu:x86_32": prod_srcs_for_arch_wrapper("avx512vnni"), - "ovr_config//cpu:x86_64": prod_srcs_for_arch_wrapper("avx512vnni"), - }), - }), + "ovr_config//cpu:x86_32": prod_srcs_for_arch_wrapper("avx512vnni"), + "ovr_config//cpu:x86_64": prod_srcs_for_arch_wrapper("avx512vnni"), + }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -642,15 +604,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = select({ - "DEFAULT": [ - ( - "x86|x86_64|platform009|platform010", - prod_srcs_for_arch_wrapper("avx512vnni"), - ), - ], - "ovr_config//build_mode:arvr_mode": [], - }), + platform_srcs = ([ + ( + "x86|x86_64|platform009|platform010", + prod_srcs_for_arch_wrapper("avx512vnni"), + ), + ] if not is_arvr_mode() else []), preferred_linkage = "static", preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, exported_preprocessor_flags = [ @@ -698,10 +657,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F fb_xplat_cxx_library( name = "ukernels_avxvnni", - srcs = select({ - "DEFAULT": [], - "ovr_config//build_mode:arvr_mode": prod_srcs_for_arch_wrapper("avxvnni"), - }), + srcs = prod_srcs_for_arch_wrapper("avxvnni") if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -723,15 +679,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = select({ - "DEFAULT": [ - ( - "x86|x86_64|platform009|platform010", - prod_srcs_for_arch_wrapper("avxvnni"), - ), - ], - "ovr_config//build_mode:arvr_mode": [], - }), + platform_srcs = ([ + ( + "x86|x86_64|platform009|platform010", + prod_srcs_for_arch_wrapper("avxvnni"), + ), + ] if not is_arvr_mode() else []), preferred_linkage = "static", preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], @@ -771,12 +724,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_f16c", srcs = select({ "DEFAULT": [], - "ovr_config//build_mode:arvr_mode": select({ - "DEFAULT": [], - "ovr_config//cpu:x86_32": prod_srcs_for_arch_wrapper("f16c"), - "ovr_config//cpu:x86_64": prod_srcs_for_arch_wrapper("f16c"), - }), - }), + "ovr_config//cpu:x86_32": prod_srcs_for_arch_wrapper("f16c"), + "ovr_config//cpu:x86_64": prod_srcs_for_arch_wrapper("f16c"), + }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -800,15 +750,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = select({ - "DEFAULT": [ - ( - "x86|x86_64|platform009|platform010", - prod_srcs_for_arch_wrapper("f16c"), - ), - ], - "ovr_config//build_mode:arvr_mode": [], - }), + platform_srcs = ([ + ( + "x86|x86_64|platform009|platform010", + prod_srcs_for_arch_wrapper("f16c"), + ), + ] if not is_arvr_mode() else []), platforms = (APPLE, ANDROID, CXX, WINDOWS), fbandroid_link_whole = True, preferred_linkage = "static", @@ -852,12 +799,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_fma3", srcs = select({ "DEFAULT": [], - "ovr_config//build_mode:arvr_mode": select({ - "DEFAULT": [], - "ovr_config//cpu:x86_32": prod_srcs_for_arch_wrapper("fma3"), - "ovr_config//cpu:x86_64": prod_srcs_for_arch_wrapper("fma3"), - }), - }), + "ovr_config//cpu:x86_32": prod_srcs_for_arch_wrapper("fma3"), + "ovr_config//cpu:x86_64": prod_srcs_for_arch_wrapper("fma3"), + }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -884,15 +828,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = select({ - "DEFAULT": [ - ( - "x86|x86_64|platform009|platform010", - prod_srcs_for_arch_wrapper("fma3"), - ), - ], - "ovr_config//build_mode:arvr_mode": [], - }), + platform_srcs = ([ + ( + "x86|x86_64|platform009|platform010", + prod_srcs_for_arch_wrapper("fma3"), + ), + ] if not is_arvr_mode() else []), fbandroid_link_whole = True, preferred_linkage = "static", preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, @@ -948,12 +889,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_avx2", srcs = select({ "DEFAULT": [], - "ovr_config//build_mode:arvr_mode": select({ - "DEFAULT": [], - "ovr_config//cpu:x86_32": prod_srcs_for_arch_wrapper("avx2"), - "ovr_config//cpu:x86_64": prod_srcs_for_arch_wrapper("avx2"), - }), - }), + "ovr_config//cpu:x86_32": prod_srcs_for_arch_wrapper("avx2"), + "ovr_config//cpu:x86_64": prod_srcs_for_arch_wrapper("avx2"), + }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -983,15 +921,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = select({ - "DEFAULT": [ - ( - "x86|x86_64|platform009|platform010", - prod_srcs_for_arch_wrapper("avx2"), - ), - ], - "ovr_config//build_mode:arvr_mode": [], - }), + platform_srcs = ([ + ( + "x86|x86_64|platform009|platform010", + prod_srcs_for_arch_wrapper("avx2"), + ), + ] if not is_arvr_mode() else []), fbandroid_link_whole = True, preferred_linkage = "static", preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, @@ -1054,12 +989,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_avx512", srcs = select({ "DEFAULT": [], - "ovr_config//build_mode:arvr_mode": select({ - "DEFAULT": [], - "ovr_config//cpu:x86_32": prod_srcs_for_arch_wrapper("avx512f"), - "ovr_config//cpu:x86_64": prod_srcs_for_arch_wrapper("avx512f"), - }), - }), + "ovr_config//cpu:x86_32": prod_srcs_for_arch_wrapper("avx512f"), + "ovr_config//cpu:x86_64": prod_srcs_for_arch_wrapper("avx512f"), + }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -1083,15 +1015,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = select({ - "DEFAULT": [ - ( - "x86|x86_64|platform009|platform010", - prod_srcs_for_arch_wrapper("avx512f"), - ), - ], - "ovr_config//build_mode:arvr_mode": [], - }), + platform_srcs = ([ + ( + "x86|x86_64|platform009|platform010", + prod_srcs_for_arch_wrapper("avx512f"), + ), + ] if not is_arvr_mode() else []), fbandroid_link_whole = True, preferred_linkage = "static", preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, @@ -1105,12 +1034,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_avx512vbmi", srcs = select({ "DEFAULT": [], - "ovr_config//build_mode:arvr_mode": select({ - "DEFAULT": [], - "ovr_config//cpu:x86_32": prod_srcs_for_arch_wrapper("avx512vbmi"), - "ovr_config//cpu:x86_64": prod_srcs_for_arch_wrapper("avx512vbmi"), - }), - }), + "ovr_config//cpu:x86_32": prod_srcs_for_arch_wrapper("avx512vbmi"), + "ovr_config//cpu:x86_64": prod_srcs_for_arch_wrapper("avx512vbmi"), + }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -1149,15 +1075,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = select({ - "DEFAULT": [ - ( - "x86|x86_64|platform009|platform010", - prod_srcs_for_arch_wrapper("avx512vbmi"), - ), - ], - "ovr_config//build_mode:arvr_mode": [], - }), + platform_srcs = ([ + ( + "x86|x86_64|platform009|platform010", + prod_srcs_for_arch_wrapper("avx512vbmi"), + ), + ] if not is_arvr_mode() else []), fbandroid_link_whole = True, preferred_linkage = "static", preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, @@ -1213,12 +1136,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_avx512skx", srcs = select({ "DEFAULT": [], - "ovr_config//build_mode:arvr_mode": select({ - "DEFAULT": [], - "ovr_config//cpu:x86_32": prod_srcs_for_arch_wrapper("avx512skx"), - "ovr_config//cpu:x86_64": prod_srcs_for_arch_wrapper("avx512skx"), - }), - }), + "ovr_config//cpu:x86_32": prod_srcs_for_arch_wrapper("avx512skx"), + "ovr_config//cpu:x86_64": prod_srcs_for_arch_wrapper("avx512skx"), + }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -1254,15 +1174,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = select({ - "DEFAULT": [ - ( - "x86|x86_64|platform009|platform010", - prod_srcs_for_arch_wrapper("avx512skx"), - ), - ], - "ovr_config//build_mode:arvr_mode": [], - }), + platform_srcs = ([ + ( + "x86|x86_64|platform009|platform010", + prod_srcs_for_arch_wrapper("avx512skx"), + ), + ] if not is_arvr_mode() else []), fbandroid_link_whole = True, preferred_linkage = "static", preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, @@ -1338,11 +1255,8 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_armsimd32", srcs = select({ "DEFAULT": [], - "ovr_config//build_mode:arvr_mode": select({ - "DEFAULT": [], - "ovr_config//cpu:arm32": prod_srcs_for_arch_wrapper("armsimd32"), - }), - }), + "ovr_config//cpu:arm32": prod_srcs_for_arch_wrapper("armsimd32"), + }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -1363,15 +1277,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = select({ - "DEFAULT": [ - ( - "(aarch32|arm32|armv7)", - prod_srcs_for_arch_wrapper("armsimd32"), - ), - ], - "ovr_config//build_mode:arvr_mode": [], - }), + platform_srcs = [ + ( + "(aarch32|arm32|armv7)", + prod_srcs_for_arch_wrapper("armsimd32"), + ), + ] if not is_arvr_mode() else [], fbandroid_link_whole = True, preferred_linkage = "static", preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, @@ -1385,12 +1296,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_neon", srcs = select({ "DEFAULT": [], - "ovr_config//build_mode:arvr_mode": select({ - "DEFAULT": [], - "ovr_config//cpu:arm32": prod_srcs_for_arch_wrapper("neon"), - "ovr_config//cpu:arm64": prod_srcs_for_arch_wrapper("neon"), - }), - }), + "ovr_config//cpu:arm32": prod_srcs_for_arch_wrapper("neon"), + "ovr_config//cpu:arm64": prod_srcs_for_arch_wrapper("neon"), + }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -1415,19 +1323,16 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = select({ - "DEFAULT": [ - ( - "(aarch32|arm32|armv7)", - prod_srcs_for_arch_wrapper("neon"), - ), - ( - "(aarch64|arm64)", - prod_srcs_for_arch_wrapper("neon"), - ), - ], - "ovr_config//build_mode:arvr_mode": [], - }), + platform_srcs = [ + ( + "(aarch32|arm32|armv7)", + prod_srcs_for_arch_wrapper("neon"), + ), + ( + "(aarch64|arm64)", + prod_srcs_for_arch_wrapper("neon"), + ), + ] if not is_arvr_mode() else [], fbandroid_link_whole = True, preferred_linkage = "static", preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, @@ -1441,26 +1346,20 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_neon_aarch64", srcs = select({ "DEFAULT": [], - "ovr_config//build_mode:arvr_mode": select({ - "DEFAULT": [], - "ovr_config//cpu:arm64": prod_srcs_for_arch_wrapper("neon_aarch64"), - }), - }), + "ovr_config//cpu:arm64": prod_srcs_for_arch_wrapper("neon_aarch64"), + }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), compiler_flags = [ "-O2", ], - platform_srcs = select({ - "DEFAULT": [ - ( - "(aarch64|arm64)", - prod_srcs_for_arch_wrapper("neon_aarch64"), - ), - ], - "ovr_config//build_mode:arvr_mode": [], - }), + platform_srcs = [ + ( + "(aarch64|arm64)", + prod_srcs_for_arch_wrapper("neon_aarch64"), + ), + ] if not is_arvr_mode() else [], labels = labels, fbandroid_link_whole = True, preferred_linkage = "static", @@ -1475,11 +1374,8 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_neon_fma", srcs = select({ "DEFAULT": [], - "ovr_config//build_mode:arvr_mode": select({ - "DEFAULT": [], - "ovr_config//cpu:arm32": prod_srcs_for_arch_wrapper("neonfma"), - }), - }), + "ovr_config//cpu:arm32": prod_srcs_for_arch_wrapper("neonfma"), + }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -1511,15 +1407,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = select({ - "DEFAULT": [ - ( - "(aarch32|arm32|armv7)", - prod_srcs_for_arch_wrapper("neonfma"), - ), - ], - "ovr_config//build_mode:arvr_mode": [], - }), + platform_srcs = [ + ( + "(aarch32|arm32|armv7)", + prod_srcs_for_arch_wrapper("neonfma"), + ), + ] if not is_arvr_mode() else [], fbandroid_link_whole = True, preferred_linkage = "static", preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, @@ -1533,11 +1426,8 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_neonfma_aarch64", srcs = select({ "DEFAULT": [], - "ovr_config//build_mode:arvr_mode": select({ - "DEFAULT": [], - "ovr_config//cpu:arm64": prod_srcs_for_arch_wrapper("neonfma") + prod_srcs_for_arch_wrapper("neonfma_aarch64"), - }), - }), + "ovr_config//cpu:arm64": prod_srcs_for_arch_wrapper("neonfma") + prod_srcs_for_arch_wrapper("neonfma_aarch64"), + }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -1545,15 +1435,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-O2", ], labels = labels, - platform_srcs = select({ - "DEFAULT": [ - ( - "(arm64|aarch64)$", - prod_srcs_for_arch_wrapper("neonfma") + prod_srcs_for_arch_wrapper("neonfma_aarch64"), - ), - ], - "ovr_config//build_mode:arvr_mode": [], - }), + platform_srcs = [ + ( + "(arm64|aarch64)$", + prod_srcs_for_arch_wrapper("neonfma") + prod_srcs_for_arch_wrapper("neonfma_aarch64"), + ), + ] if not is_arvr_mode() else [], platforms = (APPLE, ANDROID, CXX, WINDOWS), fbandroid_link_whole = True, preferred_linkage = "static", @@ -1568,12 +1455,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_fp16arith", srcs = select({ "DEFAULT": [], - "ovr_config//build_mode:arvr_mode": select({ - "DEFAULT": [], - "ovr_config//cpu:arm32": prod_srcs_for_arch_wrapper("fp16arith"), - "ovr_config//cpu:arm64": prod_srcs_for_arch_wrapper("fp16arith"), - }), - }), + "ovr_config//cpu:arm32": prod_srcs_for_arch_wrapper("fp16arith"), + "ovr_config//cpu:arm64": prod_srcs_for_arch_wrapper("fp16arith"), + }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -1620,19 +1504,16 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ) ], - platform_srcs = select({ - "DEFAULT": [ - ( - "(aarch32|arm32|armv7)", - prod_srcs_for_arch_wrapper("fp16arith"), - ), - ( - "(aarch64|arm64)", - prod_srcs_for_arch_wrapper("fp16arith") + prod_srcs_for_arch_wrapper("fp16arith_aarch64"), - ), - ], - "ovr_config//build_mode:arvr_mode": [], - }), + platform_srcs = [ + ( + "(aarch32|arm32|armv7)", + prod_srcs_for_arch_wrapper("fp16arith"), + ), + ( + "(aarch64|arm64)", + prod_srcs_for_arch_wrapper("fp16arith") + prod_srcs_for_arch_wrapper("fp16arith_aarch64"), + ), + ] if not is_arvr_mode() else [], fbandroid_link_whole = True, preferred_linkage = "static", preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, @@ -1646,12 +1527,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_neon_fp16", srcs = select({ "DEFAULT": [], - "ovr_config//build_mode:arvr_mode": select({ - "DEFAULT": [], - "ovr_config//cpu:arm32": prod_srcs_for_arch_wrapper("neonfp16"), - "ovr_config//cpu:arm64": prod_srcs_for_arch_wrapper("neonfp16"), - }), - }), + "ovr_config//cpu:arm32": prod_srcs_for_arch_wrapper("neonfp16"), + "ovr_config//cpu:arm64": prod_srcs_for_arch_wrapper("neonfp16"), + }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -1676,19 +1554,16 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = select({ - "DEFAULT": [ - ( - "(aarch32|arm32|armv7)$", - prod_srcs_for_arch_wrapper("neonfp16"), - ), - ( - "(arm64|aarch64)", - prod_srcs_for_arch_wrapper("neonfp16"), - ), - ], - "ovr_config//build_mode:arvr_mode": [], - }), + platform_srcs = [ + ( + "(aarch32|arm32|armv7)$", + prod_srcs_for_arch_wrapper("neonfp16"), + ), + ( + "(arm64|aarch64)", + prod_srcs_for_arch_wrapper("neonfp16"), + ), + ] if not is_arvr_mode() else [], fbandroid_link_whole = True, preferred_linkage = "static", preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, @@ -1702,12 +1577,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_neon_v8", srcs = select({ "DEFAULT": [], - "ovr_config//build_mode:arvr_mode": select({ - "DEFAULT": [], - "ovr_config//cpu:arm32": prod_srcs_for_arch_wrapper("neonv8"), - "ovr_config//cpu:arm64": prod_srcs_for_arch_wrapper("neonv8"), - }), - }), + "ovr_config//cpu:arm32": prod_srcs_for_arch_wrapper("neonv8"), + "ovr_config//cpu:arm64": prod_srcs_for_arch_wrapper("neonv8"), + }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -1746,19 +1618,16 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = select({ - "DEFAULT": [ - ( - "(aarch32|arm32|armv7)$", - prod_srcs_for_arch_wrapper("neonv8"), - ), - ( - "(arm64|aarch64)", - prod_srcs_for_arch_wrapper("neonv8"), - ), - ], - "ovr_config//build_mode:arvr_mode": [], - }), + platform_srcs = [ + ( + "(aarch32|arm32|armv7)$", + prod_srcs_for_arch_wrapper("neonv8"), + ), + ( + "(arm64|aarch64)", + prod_srcs_for_arch_wrapper("neonv8"), + ), + ] if not is_arvr_mode() else [], fbandroid_link_whole = True, preferred_linkage = "static", preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, @@ -1772,11 +1641,8 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_neon_dot", srcs = select({ "DEFAULT": [], - "ovr_config//build_mode:arvr_mode": select({ - "DEFAULT": [], - "ovr_config//cpu:arm32": prod_srcs_for_arch_wrapper("neondot"), - }), - }), + "ovr_config//cpu:arm32": prod_srcs_for_arch_wrapper("neondot"), + }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -1801,15 +1667,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = select({ - "DEFAULT": [ - ( - "(aarch32|arm32|armv7)", - prod_srcs_for_arch_wrapper("neondot"), - ), - ], - "ovr_config//build_mode:arvr_mode": [], - }), + platform_srcs = [ + ( + "(aarch32|arm32|armv7)", + prod_srcs_for_arch_wrapper("neondot"), + ), + ] if not is_arvr_mode() else [], fbandroid_link_whole = True, preferred_linkage = "static", preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, @@ -1823,11 +1686,8 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_neon_dot_aarch64", srcs = select({ "DEFAULT": [], - "ovr_config//build_mode:arvr_mode": select({ - "DEFAULT": [], - "ovr_config//cpu:arm64": prod_srcs_for_arch_wrapper("neondot") + prod_srcs_for_arch_wrapper("neondot_aarch64"), - }), - }), + "ovr_config//cpu:arm64": prod_srcs_for_arch_wrapper("neondot") + prod_srcs_for_arch_wrapper("neondot_aarch64"), + }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -1846,15 +1706,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = select({ - "DEFAULT": [ - ( - "(aarch64|arm64)", - prod_srcs_for_arch_wrapper("neondot") + prod_srcs_for_arch_wrapper("neondot_aarch64"), - ), - ], - "ovr_config//build_mode:arvr_mode": [], - }), + platform_srcs = [ + ( + "(aarch64|arm64)", + prod_srcs_for_arch_wrapper("neondot") + prod_srcs_for_arch_wrapper("neondot_aarch64"), + ), + ] if not is_arvr_mode() else [], fbandroid_link_whole = True, preferred_linkage = "static", preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, @@ -1868,11 +1725,8 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_neon_dot_fp16arith", srcs = select({ "DEFAULT": [], - "ovr_config//build_mode:arvr_mode": select({ - "DEFAULT": [], - "ovr_config//cpu:arm32": prod_srcs_for_arch_wrapper("neondotfp16arith"), - }), - }), + "ovr_config//cpu:arm32": prod_srcs_for_arch_wrapper("neondotfp16arith"), + }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -1896,15 +1750,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = select({ - "DEFAULT": [ - ( - "(aarch32|arm32|armv7)", - prod_srcs_for_arch_wrapper("neondotfp16arith"), - ), - ], - "ovr_config//build_mode:arvr_mode": [], - }), + platform_srcs = [ + ( + "(aarch32|arm32|armv7)", + prod_srcs_for_arch_wrapper("neondotfp16arith"), + ), + ] if not is_arvr_mode() else [], labels = labels, fbandroid_link_whole = True, preferred_linkage = "static", @@ -1919,11 +1770,8 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_neon_dot_fp16arith_aarch64", srcs = select({ "DEFAULT": [], - "ovr_config//build_mode:arvr_mode": select({ - "DEFAULT": [], - "ovr_config//cpu:arm64": prod_srcs_for_arch_wrapper("neondotfp16arith") + prod_srcs_for_arch_wrapper("neondotfp16arith_aarch64"), - }), - }), + "ovr_config//cpu:arm64": prod_srcs_for_arch_wrapper("neondotfp16arith") + prod_srcs_for_arch_wrapper("neondotfp16arith_aarch64"), + }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -1943,15 +1791,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = select({ - "DEFAULT": [ - ( - "(aarch64|arm64)", - prod_srcs_for_arch_wrapper("neondotfp16arith") + prod_srcs_for_arch_wrapper("neondotfp16arith_aarch64"), - ), - ], - "ovr_config//build_mode:arvr_mode": [], - }), + platform_srcs = [ + ( + "(aarch64|arm64)", + prod_srcs_for_arch_wrapper("neondotfp16arith") + prod_srcs_for_arch_wrapper("neondotfp16arith_aarch64"), + ), + ] if not is_arvr_mode() else [], labels = labels, fbandroid_link_whole = True, preferred_linkage = "static", @@ -1966,11 +1811,8 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_neon_fp16arith", srcs = select({ "DEFAULT": [], - "ovr_config//build_mode:arvr_mode": select({ - "DEFAULT": [], - "ovr_config//cpu:arm32": prod_srcs_for_arch_wrapper("neonfp16arith"), - }), - }), + "ovr_config//cpu:arm32": prod_srcs_for_arch_wrapper("neonfp16arith"), + }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -1995,15 +1837,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = select({ - "DEFAULT": [ - ( - "(aarch32|arm32|armv7)", - prod_srcs_for_arch_wrapper("neonfp16arith"), - ), - ], - "ovr_config//build_mode:arvr_mode": [], - }), + platform_srcs = [ + ( + "(aarch32|arm32|armv7)", + prod_srcs_for_arch_wrapper("neonfp16arith"), + ), + ] if not is_arvr_mode() else [], fbandroid_link_whole = True, preferred_linkage = "static", preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, @@ -2017,11 +1856,8 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_neon_fp16arith_aarch64", srcs = select({ "DEFAULT": [], - "ovr_config//build_mode:arvr_mode": select({ - "DEFAULT": [], - "ovr_config//cpu:arm64": prod_srcs_for_arch_wrapper("neonfp16arith") + prod_srcs_for_arch_wrapper("neonfp16arith_aarch64"), - }), - }), + "ovr_config//cpu:arm64": prod_srcs_for_arch_wrapper("neonfp16arith") + prod_srcs_for_arch_wrapper("neonfp16arith_aarch64"), + }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -2040,15 +1876,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = select({ - "DEFAULT": [ - ( - "(aarch64|arm64)", - prod_srcs_for_arch_wrapper("neonfp16arith") + prod_srcs_for_arch_wrapper("neonfp16arith_aarch64"), - ), - ], - "ovr_config//build_mode:arvr_mode": [], - }), + platform_srcs = [ + ( + "(aarch64|arm64)", + prod_srcs_for_arch_wrapper("neonfp16arith") + prod_srcs_for_arch_wrapper("neonfp16arith_aarch64"), + ), + ] if not is_arvr_mode() else [], fbandroid_link_whole = True, preferred_linkage = "static", preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, @@ -2062,12 +1895,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_neonfma_i8mm", srcs = select({ "DEFAULT": [], - "ovr_config//build_mode:arvr_mode": select({ - "DEFAULT": [], - "ovr_config//cpu:arm32": prod_srcs_for_arch_wrapper("neonfma_i8mm"), - "ovr_config//cpu:arm64": prod_srcs_for_arch_wrapper("neonfma_i8mm"), - }), - }), + "ovr_config//cpu:arm32": prod_srcs_for_arch_wrapper("neonfma_i8mm"), + "ovr_config//cpu:arm64": prod_srcs_for_arch_wrapper("neonfma_i8mm"), + }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -2101,19 +1931,16 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = select({ - "DEFAULT": [ - ( - "(aarch32|arm32|armv7)$", - prod_srcs_for_arch_wrapper("neonfma_i8mm"), - ), - ( - "(arm64|aarch64)", - prod_srcs_for_arch_wrapper("neonfma_i8mm"), - ), - ], - "ovr_config//build_mode:arvr_mode": [], - }), + platform_srcs = [ + ( + "(aarch32|arm32|armv7)$", + prod_srcs_for_arch_wrapper("neonfma_i8mm"), + ), + ( + "(arm64|aarch64)", + prod_srcs_for_arch_wrapper("neonfma_i8mm"), + ), + ] if not is_arvr_mode() else [], platforms = (APPLE, ANDROID, CXX, WINDOWS), fbandroid_link_whole = True, preferred_linkage = "static", @@ -2128,11 +1955,8 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_neoni8mm", srcs = select({ "DEFAULT": [], - "ovr_config//build_mode:arvr_mode": select({ - "DEFAULT": [], - "ovr_config//cpu:arm64": prod_srcs_for_arch_wrapper("neoni8mm"), - }), - }), + "ovr_config//cpu:arm64": prod_srcs_for_arch_wrapper("neoni8mm"), + }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", apple_sdks = (IOS, MACOSX, APPLETVOS), @@ -2153,15 +1977,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = select({ - "DEFAULT": [ - ( - "(arm64|aarch64)", - prod_srcs_for_arch_wrapper("neoni8mm"), - ), - ], - "ovr_config//build_mode:arvr_mode": [], - }), + platform_srcs = [ + ( + "(arm64|aarch64)", + prod_srcs_for_arch_wrapper("neoni8mm"), + ), + ] if not is_arvr_mode() else [], platforms = (APPLE, ANDROID, CXX, WINDOWS), fbandroid_link_whole = True, preferred_linkage = "static", @@ -2176,11 +1997,8 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_asm_aarch32", srcs = select({ "DEFAULT": [], - "ovr_config//build_mode:arvr_mode": select({ - "DEFAULT": [], - "ovr_config//cpu:arm32": prod_srcs_for_arch_wrapper("aarch32"), - }), - }), + "ovr_config//cpu:arm32": prod_srcs_for_arch_wrapper("aarch32"), + }) if is_arvr_mode() else [], headers = subdir_glob([ ("XNNPACK/src", "xnnpack/assembly.h"), ("XNNPACK/src", "**/*.S"), @@ -2208,15 +2026,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = select({ - "DEFAULT": [ - ( - "(aarch32|arm32|armv7)", - prod_srcs_for_arch_wrapper("aarch32"), - ), - ], - "ovr_config//build_mode:arvr_mode": [], - }), + platform_srcs = [ + ( + "(aarch32|arm32|armv7)", + prod_srcs_for_arch_wrapper("aarch32"), + ), + ] if not is_arvr_mode() else [], platforms = (APPLE, ANDROID, CXX, WINDOWS), fbandroid_link_whole = True, preferred_linkage = "static", @@ -2231,11 +2046,8 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_asm_aarch64", srcs = select({ "DEFAULT": [], - "ovr_config//build_mode:arvr_mode": select({ - "DEFAULT": [], - "ovr_config//cpu:arm64": prod_srcs_for_arch_wrapper("aarch64"), - }), - }), + "ovr_config//cpu:arm64": prod_srcs_for_arch_wrapper("aarch64"), + }) if is_arvr_mode() else [], headers = subdir_glob([ ("XNNPACK/src", "xnnpack/assembly.h"), ("XNNPACK/src", "**/*.S"), @@ -2259,15 +2071,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = select({ - "DEFAULT": [ - ( - "(aarch64|arm64)", - prod_srcs_for_arch_wrapper("aarch64"), - ), - ], - "ovr_config//build_mode:arvr_mode": [], - }), + platform_srcs = [ + ( + "(aarch64|arm64)", + prod_srcs_for_arch_wrapper("aarch64"), + ), + ] if not is_arvr_mode() else [], fbandroid_link_whole = True, preferred_linkage = "static", preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, From 9b4d938f04c95cebe0fbd96974f64c935567e039 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Tue, 22 Jul 2025 06:28:11 +0000 Subject: [PATCH 383/457] [dynamo][fsdp] Consistent behavior of int attributes (#157262) Reimpl of https://github.com/pytorch/pytorch/pull/150954 Pull Request resolved: https://github.com/pytorch/pytorch/pull/157262 Approved by: https://github.com/bdhirsh --- test/distributed/test_dynamo_distributed.py | 82 +++++++++++++++++++++ torch/_dynamo/config.py | 7 ++ torch/_dynamo/utils.py | 9 +++ torch/_guards.py | 11 --- 4 files changed, 98 insertions(+), 11 deletions(-) diff --git a/test/distributed/test_dynamo_distributed.py b/test/distributed/test_dynamo_distributed.py index d3436bbe47548..86410d8919d21 100644 --- a/test/distributed/test_dynamo_distributed.py +++ b/test/distributed/test_dynamo_distributed.py @@ -678,6 +678,88 @@ def test_fsdp_aot_eager(self): outputs = fsdp_m(inputs) self.assertTrue(same(correct_outputs, outputs)) + @config.patch(enable_compiler_collectives=True) + @skip_if_lt_x_gpu(1) + def test_fsdp_dynamism_on_int_attr(self): + global GUARDS_FILE + GUARDS_FILE = StringIO() + + with _dynamo_dist_per_rank_init(self.rank, self.world_size): + + class ToyModelWithIntAttr(nn.Module): + def __init__(self): + super().__init__() + self.attr = 2 + + def forward(self, x): + out = x + self.attr + + @comptime + def _(ctx): + ctx.print_guards(file=GUARDS_FILE) + + return out + + def get_model_with_int_attr(device): + m = ToyModelWithIntAttr().to(device) + inputs = torch.rand(10).to(device) + outputs = m(inputs) + return m, inputs, outputs + + m, inputs, correct_outputs = get_model_with_int_attr(f"cuda:{self.rank}") + fsdp_m = FSDP(m, use_orig_params=True) + compiled_fsdp_m = torch.compile( + fsdp_m, backend="eager", dynamic=True, fullgraph=True + ) + outputs = compiled_fsdp_m(inputs) + self.assertTrue(same(correct_outputs, outputs)) + + FileCheck().check( + """local_fsdp_module "L['fn']._modules['_fsdp_wrapped_module'].attr" EQUALS_MATCH""" + ).run(GUARDS_FILE.getvalue()) + + @config.patch(enable_compiler_collectives=True) + @config.patch(allow_unspec_int_on_fsdp_module=True) + @skip_if_lt_x_gpu(1) + def test_fsdp_dynamism_on_int_attr_unspec(self): + global GUARDS_FILE + GUARDS_FILE = StringIO() + + with _dynamo_dist_per_rank_init(self.rank, self.world_size): + + class ToyModelWithIntAttr(nn.Module): + def __init__(self): + super().__init__() + self.attr = 2 + + def forward(self, x): + out = x + self.attr + + @comptime + def _(ctx): + ctx.print_guards(file=GUARDS_FILE) + + return out + + def get_model_with_int_attr(device): + m = ToyModelWithIntAttr().to(device) + inputs = torch.rand(10).to(device) + outputs = m(inputs) + return m, inputs, outputs + + m, inputs, correct_outputs = get_model_with_int_attr(f"cuda:{self.rank}") + fsdp_m = FSDP(m, use_orig_params=True) + compiled_fsdp_m = torch.compile( + fsdp_m, backend="eager", dynamic=True, fullgraph=True + ) + outputs = compiled_fsdp_m(inputs) + self.assertTrue(same(correct_outputs, outputs)) + + # No presence of EQUALS_MATCH because the guard will be dynamic + FileCheck().check( + """local_fsdp_module "L['fn']._modules['_fsdp_wrapped_module'].attr" TYPE_MATCH""" + ).run(GUARDS_FILE.getvalue()) + @skip_if_lt_x_gpu(2) @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") def test_ddp_optimizer_cudagraph(self): diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index 5fd37d5392742..7ef748b85f3e3 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -282,6 +282,13 @@ # Defaults to False for BC. allow_unspec_int_on_nn_module = False +# Mirrors `allow_unspec_int_on_nn_module`, but for FSDP: for <=2.8 versions, +# integer attributes on FSDP modules were treated as dynamic, while the same +# attributes on plain nn.Modules were static. We unified the behaviour by making +# FSDP ints static too. Set this flag to True to restore the legacy dynamic +# handling if needed. +allow_unspec_int_on_fsdp_module = False + # Specify how to optimize a compiled DDP module. The flag accepts a boolean # value or a string. There are 3 modes. # 1. "ddp_optimizer" (or True): with "ddp_optimizer", Dynamo will automatically diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index f850e3ecb7c31..cf3ed5d135e6a 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -2439,6 +2439,15 @@ def is_int_specialization_case(value, source): source.guard_source().is_specialized_nn_module() and not config.allow_unspec_int_on_nn_module ) + # integers coming from FSDP modules are considered static. This is + # purely empirical and perhaps we should have a better heuristic. + or ( + source.guard_source().is_fsdp_module() + and not ( + config.allow_unspec_int_on_nn_module + or config.allow_unspec_int_on_fsdp_module + ) + ) or ( source.guard_source().is_unspecialized_builtin_nn_module() and not config.allow_unspec_int_on_nn_module diff --git a/torch/_guards.py b/torch/_guards.py index 9d64513d01be4..bce574df3feb0 100644 --- a/torch/_guards.py +++ b/torch/_guards.py @@ -155,17 +155,6 @@ def is_fsdp_module(self) -> bool: return self in (GuardSource.GLOBAL_FSDP_MODULE, GuardSource.LOCAL_FSDP_MODULE) def is_specialized_nn_module(self) -> bool: - import torch._dynamo.config as config - - if config._unsafe_skip_fsdp_module_guards: - return ( - self - in ( - GuardSource.GLOBAL_SPECIALIZED_NN_MODULE, - GuardSource.LOCAL_SPECIALIZED_NN_MODULE, - ) - or self.is_fsdp_module() - ) return self in ( GuardSource.GLOBAL_SPECIALIZED_NN_MODULE, GuardSource.LOCAL_SPECIALIZED_NN_MODULE, From 8e99714204271a6c60866c10c4360166c24ae68e Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Tue, 22 Jul 2025 06:14:12 -0700 Subject: [PATCH 384/457] [EZ][BE][MPS] Remove unused `ndArrayFromTensor` (#158823) And `printTensorNDArray`, both of which according to https://github.com/search?type=code&q=ndArrayFromTensor+org%3Apytorch are not used anywhere Pull Request resolved: https://github.com/pytorch/pytorch/pull/158823 Approved by: https://github.com/dcci ghstack dependencies: #158690 --- aten/src/ATen/native/mps/OperationUtils.h | 2 -- aten/src/ATen/native/mps/OperationUtils.mm | 30 ---------------------- 2 files changed, 32 deletions(-) diff --git a/aten/src/ATen/native/mps/OperationUtils.h b/aten/src/ATen/native/mps/OperationUtils.h index 976b62c7ac4b4..e6f87f5499a47 100644 --- a/aten/src/ATen/native/mps/OperationUtils.h +++ b/aten/src/ATen/native/mps/OperationUtils.h @@ -145,8 +145,6 @@ MPSGraphTensorData* getMPSGraphTensorData(MPSGraph* mpsGraph, MPSStream* mpsStre MPSGraphTensorData* getMPSGraphTensorFromScalar(MPSStream* mpsStream, MPSScalar& scalar); MPSGraph* make_mps_graph(); -void printTensorNDArray(const TensorBase& t); -MPSNDArray* ndArrayFromTensor(const TensorBase& tensor, MPSShape* shape, MPSDataType mpsType); MPSGraphTensor* mpsGraphUnrankedPlaceHolder(MPSGraph* mpsGraph, MPSDataType dataType); MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph* mpsGraph, MPSDataType dataType, MPSShape* mpsShape); diff --git a/aten/src/ATen/native/mps/OperationUtils.mm b/aten/src/ATen/native/mps/OperationUtils.mm index a58b334307f2f..dad70cfc299cf 100644 --- a/aten/src/ATen/native/mps/OperationUtils.mm +++ b/aten/src/ATen/native/mps/OperationUtils.mm @@ -377,36 +377,6 @@ Tensor getTensorView(const Tensor& t, MPSShape* shape) { return [NSArray arrayWithObjects:numbers.data() count:numbers.size()]; } -void printTensorNDArray(const TensorBase& t) { - if (!t.is_mps()) - return; - if (t.numel() == 0) - return; - // Get shape and data type - auto selfShape = getMPSShape(t); - auto selfDType = getMPSDataType(t.scalar_type()); - - // Initialize data - id selfBuf = getMTLBufferStorage(t); - MPSGraphTensorData* tdata = [[[MPSGraphTensorData alloc] initWithMTLBuffer:selfBuf shape:selfShape - dataType:selfDType] autorelease]; - C10_CLANG_DIAGNOSTIC_PUSH() -#if C10_CLANG_HAS_WARNING("-Wobjc-method-access") - C10_CLANG_DIAGNOSTIC_IGNORE("-Wobjc-method-access") -#endif - [tdata printNDArray]; - C10_CLANG_DIAGNOSTIC_POP() -} - -MPSNDArray* ndArrayFromTensor(const TensorBase& tensor, MPSShape* shape, MPSDataType mpsType) { - id buffer = getMTLBufferStorage(tensor); - MPSGraphTensorData* tmpGraphTensorData = [[[MPSGraphTensorData alloc] initWithMTLBuffer:buffer - shape:shape - dataType:mpsType] autorelease]; - - return [tmpGraphTensorData mpsndarray]; -} - static std::vector getSortedStrides(const IntArrayRef& s) { std::vector idx(s.size()); iota(idx.begin(), idx.end(), 0); From 1b772de3974ee24f7d3ebcb2b35278d6e3356096 Mon Sep 17 00:00:00 2001 From: James Wu Date: Mon, 21 Jul 2025 14:04:17 -0700 Subject: [PATCH 385/457] Still run TritonBundler with BundledAOTAutogradCache, save autotune results (#158048) When running BundledAOTAutogradCache with precompile, we still need to run triton bundling so that the precompiled CompiledFxGraph has triton cuda kernels. We also pre save the autotune results in the precompile artifact. It would be even better to pre trim the cuda kernels on save and apply them, which we can work on later. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158048 Approved by: https://github.com/zhxchen17 --- test/dynamo/test_package.py | 36 +++++++++++++++++++++++ torch/_dynamo/precompile_context.py | 9 ++++-- torch/_inductor/compile_fx.py | 29 +++++++++++++++++- torch/_inductor/runtime/autotune_cache.py | 10 +++++++ 4 files changed, 81 insertions(+), 3 deletions(-) diff --git a/test/dynamo/test_package.py b/test/dynamo/test_package.py index 3160007774090..d43a8d6c5564c 100644 --- a/test/dynamo/test_package.py +++ b/test/dynamo/test_package.py @@ -15,10 +15,12 @@ from torch._dynamo.package import CompilePackage, DiskDynamoStore, DynamoCache from torch._dynamo.precompile_context import PrecompileContext from torch._functorch import config as functorch_config +from torch._inductor.mock_cache import global_stats, PatchCaches, Stats from torch._inductor.runtime.runtime_utils import cache_dir from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, + skipIfRocm, ) from torch.testing._internal.inductor_utils import HAS_CUDA, HAS_XPU @@ -428,6 +430,40 @@ def fn2(x): self.assertEqual(expected, [result1, result2]) self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames) + @parametrize("device", ("cuda", "xpu")) + @torch._dynamo.config.patch(caching_precompile=True) + @skipIfRocm + def test_automatic_dynamo_autotune_cache(self, device): + if device == "cuda" and not HAS_CUDA: + raise unittest.SkipTest("Requires CUDA/Triton") + if device == "xpu" and not HAS_XPU: + raise unittest.SkipTest("Requires XPU/Triton") + + def fn(x, y): + return x.sin() + y + + arg1 = torch.randn(3, 3, device=device) + arg2 = torch.randn(3, 3, device=device) + expected = fn(arg1, arg2).clone() + + with PatchCaches(): + compiled_fn1 = torch.compile(fn, mode="max-autotune") + result = compiled_fn1(arg1, arg2).clone() + self.assertEqual(expected, result) + self.assertEqual(global_stats.autotune_local, Stats(1, 0, 1)) + DynamoCache.clear() + + total_frames = torch._dynamo.convert_frame.FRAME_COUNTER + self._save_and_reload( + expected_backends=1, expected_dynamo=1, expected_autotune=1 + ) + compiled_fn1 = torch.compile(fn, mode="max-autotune") + with torch.compiler.set_stance("fail_on_recompile"): + result1 = compiled_fn1(arg1, arg2).clone() + self.assertEqual(expected, result1) + self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames) + self.assertEqual(global_stats.autotune_local, Stats(2, 1, 1)) + @parametrize("device", ("cpu", "cuda", "xpu")) @torch._dynamo.config.patch(caching_precompile=True) def test_automatic_dynamo_recompiles(self, device): diff --git a/torch/_dynamo/precompile_context.py b/torch/_dynamo/precompile_context.py index 6bb42bb34bc35..040f54ce70db2 100644 --- a/torch/_dynamo/precompile_context.py +++ b/torch/_dynamo/precompile_context.py @@ -70,7 +70,8 @@ class PrecompileContext(CacheArtifactManager): The following artifact types are supported by PrecompileContext: - BundledAOTAutogradCacheArtifact - - CodeStateArtifact (from torch._dynamo.package once available) + - DynamoCodeStateArtifact + - AutotuneCacheArtifact (regular autotune results, same as Megacache) """ # Protected by the compile_lock @@ -149,8 +150,12 @@ def populate_caches(artifacts: CacheArtifactsResult) -> CacheInfo: artifacts_by_key = {} cache_info = CacheInfo() for artifact in chain(*artifacts.values()): + if artifact.type() == "autotune": + # Populate autotune cache artifacts + artifact.populate_cache() + else: + artifacts_by_key[artifact.key] = artifact cache_info.add(artifact) - artifacts_by_key[artifact.key] = artifact from torch._dynamo.package import _BackendId, DynamoCache diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index 95c12d12c7850..8e712a28a3b0f 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -909,10 +909,37 @@ def _compile_fx_inner( else: log.debug("Failed to generate FX cache key") + if torch._functorch.config.bundled_autograd_cache: + assert mb_compiled_graph is None + assert cache_info is None + # When using bundled autograd cache, we still want + # to use the TritonBundler, but we don't want to save + # the results here. The results will get saved directly + # to AOTAutogradCache. + TritonBundler.begin_compile() + try: + mb_compiled_graph = fx_codegen_and_compile( + gm, example_inputs, inputs_to_check, **graph_kwargs + ) + assert mb_compiled_graph is not None + ( + triton_bundle, + triton_bundler_meta, + ) = TritonBundler.collect() + mb_compiled_graph.set_triton_bundle(triton_bundle) + except (ShortenTraceback, SkipFrame): + raise + except Exception as e: + raise InductorError(e, currentframe()).with_traceback( + e.__traceback__ + ) from None + finally: + TritonBundler.end_compile() + # CACHE BYPASS: Compile the graph, don't save it to the cache # (this can happen either because cache was disabled, or we # determined the input is uncacheable) - if cache_info is None or cache_info["cache_state"] == "bypass": + elif cache_info is None or cache_info["cache_state"] == "bypass": assert mb_compiled_graph is None log.debug( "FX cache bypass reason: %s", diff --git a/torch/_inductor/runtime/autotune_cache.py b/torch/_inductor/runtime/autotune_cache.py index 01d038aab8e7b..88b9c80c77146 100644 --- a/torch/_inductor/runtime/autotune_cache.py +++ b/torch/_inductor/runtime/autotune_cache.py @@ -35,6 +35,7 @@ from typing_extensions import override import torch +from torch._dynamo.precompile_context import PrecompileContext from torch._inductor.runtime.runtime_utils import cache_dir from torch.compiler._cache import ( CacheArtifact, @@ -125,6 +126,7 @@ def create( ) -> Optional[AutotuneCache]: cache = AutotuneCache(configs_hash) key = AutotuneCache._prepare_key(filename) + cache._setup_local_cache(inductor_meta, os.path.dirname(filename), key) cache._setup_remote_autotune_cache(inductor_meta, key) if cache.local_cache or cache.remote_cache: @@ -300,6 +302,10 @@ def save( CacheArtifactManager.record_artifact( AutotuneCacheArtifact.type(), autotune_artifact_key, data ) + if torch._dynamo.config.caching_precompile: + PrecompileContext.record_artifact( + AutotuneCacheArtifact.type(), autotune_artifact_key, data + ) if log.isEnabledFor(logging.DEBUG): type_str = "coordesc" if found_by_coordesc else "heuristic" @@ -625,6 +631,10 @@ def _get(self, key: str, sample: Optional[Sample]) -> Optional[JsonDataTy]: CacheArtifactManager.record_artifact( AutotuneCacheArtifact.type(), autotune_artifact_key, result ) + if torch._dynamo.config.caching_precompile: + PrecompileContext.record_artifact( + AutotuneCacheArtifact.type(), autotune_artifact_key, result + ) return result @override From 371ffaf415baf6251b9d98466c8ee970b3556282 Mon Sep 17 00:00:00 2001 From: IvanKobzarev Date: Tue, 22 Jul 2025 04:51:46 -0700 Subject: [PATCH 386/457] [bucketing] Support case of several pgs in graph (#158632) Main changes: - bucketing collectives only from the same process_group by group_name - Support of groups like [0,2,4,6], [0,1,3,5] using `rank_idx_dict` for in pass operations for slice idxs etc. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158632 Approved by: https://github.com/wconstab --- torch/_inductor/fx_passes/bucketing.py | 487 +++++++++++++------------ 1 file changed, 252 insertions(+), 235 deletions(-) diff --git a/torch/_inductor/fx_passes/bucketing.py b/torch/_inductor/fx_passes/bucketing.py index 8f5bb5ffd3248..1794ce3a2a294 100644 --- a/torch/_inductor/fx_passes/bucketing.py +++ b/torch/_inductor/fx_passes/bucketing.py @@ -1,6 +1,7 @@ import logging import math import operator +from collections import defaultdict from typing import Any, Callable, Optional, Union import torch @@ -77,13 +78,9 @@ def bucket_all_gather_by_mb( ) -> list[list[torch.fx.Node]]: """ Identifies all all_gather nodes and groups them into buckets based on size limit `all_gather_bucket_cap_mb_callback`. - - Returns a list of buckets, where each bucket is a list of all_gather nodes. """ - node_list = gm.graph.nodes - # Prerequisite: Check if there is any all_gather node found_all_gather = False for node in node_list: @@ -92,48 +89,53 @@ def bucket_all_gather_by_mb( break if not found_all_gather: return [] - - ag_nodes: list[torch.fx.Node] = [] - + group_name_ag_nodes: dict[tuple[str, torch.dtype], list[torch.fx.Node]] = ( # type: ignore[name-defined] + defaultdict(list) + ) # Step 1: Find all all_gather nodes for node in node_list: if is_wait_tensor(node) and is_all_gather_into_tensor(node.args[0]): if (filter_wait_node is None) or filter_wait_node(node): ag_node = node.args[0] - ag_nodes.append(ag_node) - + _, group_size, group_name = ag_node.args + dtype = ag_node.meta["val"].dtype + assert isinstance(group_name, str) + group_name_ag_nodes[(group_name, dtype)].append(ag_node) # Step 2: Put all_gather nodes into buckets ag_buckets: list[list[torch.fx.Node]] = [] - cur_bucket: list[torch.fx.Node] = [] - cur_bucket_size_bytes: int = 0 - cur_bucket_id: int = 0 - # Convert MiB to bytes - all_gather_bucket_size_bytes = int( - all_gather_bucket_cap_mb_callback(cur_bucket_id) * 1024 * 1024 - ) - for ag_node in ag_nodes: - assert is_all_gather_into_tensor(ag_node) - assert "val" in ag_node.meta - ag_output_size_bytes = ( - ag_node.meta["val"].numel() - * torch.finfo(ag_node.meta["val"].dtype).bits - // 8 + for (group_name, dtype), ag_nodes in group_name_ag_nodes.items(): + cur_bucket: list[torch.fx.Node] = [] + cur_bucket_recursive_users: OrderedSet[torch.fx.Node] = OrderedSet() + cur_bucket_size_bytes: int = 0 + cur_bucket_id: int = 0 + all_gather_bucket_size_bytes = int( + all_gather_bucket_cap_mb_callback(cur_bucket_id) * 1024 * 1024 ) - if ( - cur_bucket_size_bytes + ag_output_size_bytes > all_gather_bucket_size_bytes - and cur_bucket - ): - # Current bucket is full, create new bucket + for ag_node in ag_nodes: + assert is_all_gather_into_tensor(ag_node) + if ag_node in cur_bucket_recursive_users: + # We can not bucket successors with the node + continue + assert "val" in ag_node.meta + ag_n_val = ag_node.meta["val"] + ag_output_size_bytes = ag_n_val.numel() * ag_n_val.element_size() + if ( + cur_bucket_size_bytes + ag_output_size_bytes + > all_gather_bucket_size_bytes + and cur_bucket + ): + # Current bucket is full, create new bucket + if len(cur_bucket) > 1: + ag_buckets.append(cur_bucket) + cur_bucket = [] + cur_bucket_size_bytes = 0 + cur_bucket_id += 1 + cur_bucket_size_bytes += ag_output_size_bytes + cur_bucket.append(ag_node) + find_recursive_users_of_fx_node(ag_node, cur_bucket_recursive_users) + if len(cur_bucket) > 1: + # add remaining nodes in the last bucket ag_buckets.append(cur_bucket) - cur_bucket = [] - cur_bucket_size_bytes = 0 - cur_bucket_id += 1 - cur_bucket_size_bytes += ag_output_size_bytes - cur_bucket.append(ag_node) - if cur_bucket: - # add remaining nodes in the last bucket - ag_buckets.append(cur_bucket) - return ag_buckets @@ -143,13 +145,9 @@ def bucket_reduce_scatter_by_mb( ) -> list[list[torch.fx.Node]]: """ Identifies all reduce_scatter nodes and groups them into buckets based on size limit `reduce_scatter_bucket_cap_mb_callback`. - - Returns a list of buckets, where each bucket is a list of reduce_scatter nodes. """ - node_list = list(gm.graph.nodes) - # Prerequisite: Check if there is any reduce_scatter node found_reduce_scatter = False for node in node_list: @@ -158,64 +156,71 @@ def bucket_reduce_scatter_by_mb( break if not found_reduce_scatter: return [] - - rs_nodes: list[torch.fx.Node] = [] - + group_name_rs_nodes: dict[tuple[str, str, torch.dtype], list[torch.fx.Node]] = ( # type: ignore[name-defined] + defaultdict(list) + ) # Step 1: Find all reduce_scatter nodes for node in node_list: if is_wait_tensor(node) and is_reduce_scatter_tensor(node.args[0]): rs_node = node.args[0] - rs_nodes.append(rs_node) - + _, reduce_op, group_size, group_name = rs_node.args + dtype = rs_node.meta["val"].dtype + assert isinstance(group_name, str) + assert isinstance(reduce_op, str) + group_name_rs_nodes[(group_name, reduce_op, dtype)].append(rs_node) # Step 2: Put reduce_scatter nodes into buckets rs_buckets: list[list[torch.fx.Node]] = [] - cur_bucket: list[torch.fx.Node] = [] - cur_bucket_size_bytes: int = 0 - cur_bucket_id: int = 0 - # Convert MiB to bytes - reduce_scatter_bucket_size_bytes = int( - reduce_scatter_bucket_cap_mb_callback(cur_bucket_id) * 1024 * 1024 - ) - for rs_node in rs_nodes: - assert is_reduce_scatter_tensor(rs_node) - rs_input = rs_node.args[0] - assert "val" in rs_input.meta # type: ignore[union-attr] - rs_input_size_bytes = ( - rs_input.meta["val"].numel() # type: ignore[union-attr] - * torch.finfo(rs_input.meta["val"].dtype).bits # type: ignore[union-attr] - // 8 + for (group_name, reduce_op, dtype), rs_nodes in group_name_rs_nodes.items(): + cur_bucket: list[torch.fx.Node] = [] + cur_bucket_recursive_users: OrderedSet[torch.fx.Node] = OrderedSet() + cur_bucket_size_bytes: int = 0 + cur_bucket_id: int = 0 + # Convert MiB to bytes + reduce_scatter_bucket_size_bytes = int( + reduce_scatter_bucket_cap_mb_callback(cur_bucket_id) * 1024 * 1024 ) - if ( - cur_bucket_size_bytes + rs_input_size_bytes - > reduce_scatter_bucket_size_bytes - and cur_bucket - ): - # Current bucket is full, create new bucket - total_size = cur_bucket_size_bytes + rs_input_size_bytes + for rs_node in rs_nodes: + assert is_reduce_scatter_tensor(rs_node) + if rs_node in cur_bucket_recursive_users: + # We can not bucket successors with the node + continue + rs_input = rs_node.args[0] + assert "val" in rs_input.meta # type: ignore[union-attr] + rs_in_val = rs_input.meta["val"] # type: ignore[union-attr] + rs_input_size_bytes = rs_in_val.numel() * rs_in_val.element_size() + if ( + cur_bucket_size_bytes + rs_input_size_bytes + > reduce_scatter_bucket_size_bytes + and cur_bucket + ): + # Current bucket is full, create new bucket + total_size = cur_bucket_size_bytes + rs_input_size_bytes + logger.info( + f"Reduce scatter bucket {cur_bucket_id} full: " # noqa: G004 + f"total_size = {total_size} = cur_bucket_size_bytes + rs_input_size_bytes = " + f"{cur_bucket_size_bytes} + {rs_input_size_bytes}," + f"bucket_cap = {reduce_scatter_bucket_size_bytes}" + ) + if len(cur_bucket) > 1: + rs_buckets.append(cur_bucket) + cur_bucket = [] + cur_bucket_size_bytes = 0 + cur_bucket_id += 1 + reduce_scatter_bucket_size_bytes = int( + reduce_scatter_bucket_cap_mb_callback(cur_bucket_id) * 1024 * 1024 + ) + cur_bucket_size_bytes += rs_input_size_bytes + cur_bucket.append(rs_node) + find_recursive_users_of_fx_node(rs_node, cur_bucket_recursive_users) + if cur_bucket: + # add remaining nodes in the last bucket logger.info( - f"Reduce scatter bucket {cur_bucket_id} full: " # noqa: G004 - f"total_size = {total_size} = cur_bucket_size_bytes + rs_input_size_bytes = " - f"{cur_bucket_size_bytes} + {rs_input_size_bytes}," + f"Reduce scatter last bucket {cur_bucket_id}: " # noqa: G004 + f"total_size = {cur_bucket_size_bytes}, " f"bucket_cap = {reduce_scatter_bucket_size_bytes}" ) - rs_buckets.append(cur_bucket) - cur_bucket = [] - cur_bucket_size_bytes = 0 - cur_bucket_id += 1 - reduce_scatter_bucket_size_bytes = int( - reduce_scatter_bucket_cap_mb_callback(cur_bucket_id) * 1024 * 1024 - ) - cur_bucket_size_bytes += rs_input_size_bytes - cur_bucket.append(rs_node) - if cur_bucket: - # add remaining nodes in the last bucket - logger.info( - f"Reduce scatter last bucket {cur_bucket_id}: " # noqa: G004 - f"total_size = {cur_bucket_size_bytes}, " - f"bucket_cap = {reduce_scatter_bucket_size_bytes}" - ) - rs_buckets.append(cur_bucket) - + if len(cur_bucket) > 1: + rs_buckets.append(cur_bucket) return rs_buckets @@ -260,6 +265,18 @@ def env_lookup( # type: ignore[no-untyped-def] return env[x] +def _rank_idx_dict(group_name: str) -> dict[int, int]: + from torch.distributed.distributed_c10d import ( + _resolve_process_group, + get_process_group_ranks, + ) + + pg = _resolve_process_group(group_name) + ranks = get_process_group_ranks(pg) + rank_idx_dict: dict[int, int] = {rank: idx for idx, rank in enumerate(ranks)} + return rank_idx_dict + + def merge_all_gather( gm: torch.fx.GraphModule, ag_buckets: list[list[torch.fx.Node]] ) -> None: @@ -297,15 +314,13 @@ def merge_all_gather( bucket_id_is_scheduled = {} cast_bucket_id_is_scheduled = {} _, group_size, group_name = next(iter(ag_node_to_wait_node.keys())).args + + group_name_to_rank_idx_dict: dict[str, dict[int, int]] = {} + for bucket_id, ag_bucket in enumerate(ag_buckets): ag_input_nodes = [] wait_nodes = [] for ag_node in ag_bucket: - assert ( - ag_node in ag_node_to_wait_node - and ag_node.args[1] == group_size - and ag_node.args[2] == group_name - ) ag_input_nodes.append(ag_node.args[0]) wait_nodes.append(ag_node_to_wait_node[ag_node]) bucket_id_to_bucketed_op_info[bucket_id] = ( @@ -314,6 +329,8 @@ def merge_all_gather( group_name, wait_nodes, ) + if group_name not in group_name_to_rank_idx_dict: + group_name_to_rank_idx_dict[group_name] = _rank_idx_dict(group_name) # type: ignore[arg-type, index] ag_wait_nodes = list(ag_node_to_wait_node.values()) ag_and_wait_nodes = OrderedSet(ag_nodes + ag_wait_nodes) @@ -334,9 +351,6 @@ def merge_all_gather( ag_input_nodes, group_size, group_name, orig_wait_nodes = ( bucket_id_to_bucketed_op_info[bucket_id] ) - # device = ag_input_nodes[0].meta["val"].device - # rank = device.index - # dtype = ag_input_nodes[0].meta["val"].dtype if all( n.op == "call_function" # type: ignore[union-attr] and n.target == torch.ops.prims.convert_element_type.default # type: ignore[union-attr] @@ -398,6 +412,7 @@ def merge_all_gather( ag_input_nodes, group_size, group_name, orig_wait_nodes = ( bucket_id_to_bucketed_op_info[bucket_id] ) + rank_idx_dict = group_name_to_rank_idx_dict[group_name] # type: ignore[index] device = ag_input_nodes[0].meta["val"].device # type: ignore[union-attr] rank = device.index dtype = ag_input_nodes[0].meta["val"].dtype # type: ignore[union-attr] @@ -468,7 +483,7 @@ def merge_all_gather( all_gather_output, inp_split_sizes, all_gather_input_numel, - rank, + rank_idx_dict[rank], ), {}, ) @@ -585,6 +600,7 @@ def merge_reduce_scatter( # Prepare bucketed operation info bucket_id_to_bucketed_op_info = {} bucket_id_is_scheduled = {} + group_name_to_rank_idx_dict: dict[str, dict[int, int]] = {} for bucket_id, rs_bucket in enumerate(rs_buckets): _, reduce_op, group_size, group_name = next( iter(rs_node_to_wait_node.keys()) @@ -612,6 +628,8 @@ def merge_reduce_scatter( wait_nodes, wait_node_recursive_users, ) + if group_name not in group_name_to_rank_idx_dict: + group_name_to_rank_idx_dict[group_name] = _rank_idx_dict(group_name) # type: ignore[arg-type, index] new_graph: torch.fx.Graph = torch.fx.Graph() env: dict[torch.fx.Node, torch.fx.Node] = {} @@ -624,155 +642,154 @@ def merge_reduce_scatter( elif node in rs_node_to_wait_node: assert node in rs_node_to_bucket_id bucket_id = rs_node_to_bucket_id[node] - if ( + if not ( bucket_id not in bucket_id_is_scheduled and rs_buckets[bucket_id][-1] == node ): - # If we are at the last node in the bucket, we can start to schedule the bucketed reduce_scatter node - ( - rs_input_nodes, - reduce_op, - group_size, - group_name, - orig_wait_nodes, - orig_wait_node_recursive_users, - ) = bucket_id_to_bucketed_op_info[bucket_id] - # parents of rs have been scheduled, so we can directly use the env - unsharded_grads = [env[x] for x in rs_input_nodes] # type: ignore[index] - reduce_dtype = unsharded_grads[0].meta["val"].dtype - # Only float32 and bfloat16 are supported for now. - # To support fp16, please see FSDP2 `_get_gradient_divide_factors`. - assert reduce_dtype in ( - torch.float32, - torch.bfloat16, - ), f"reduce_dtype {reduce_dtype} is not supported" - assert all( - grad.meta["val"].dtype == reduce_dtype for grad in unsharded_grads - ) - device = unsharded_grads[0].meta["val"].device - rank = device.index - shard_dim = 0 + continue - def _get_dim0_padded_size( - tensor_size: torch.Size, dim0_factor: int - ) -> torch.Size: - padded_dim0 = math.ceil(tensor_size[0] / dim0_factor) * dim0_factor - return torch.Size([padded_dim0]) + tensor_size[1:] + # If we are at the last node in the bucket, we can start to schedule the bucketed reduce_scatter node + ( + rs_input_nodes, + reduce_op, + group_size, + group_name, + orig_wait_nodes, + orig_wait_node_recursive_users, + ) = bucket_id_to_bucketed_op_info[bucket_id] + rank_idx_dict = group_name_to_rank_idx_dict[group_name] # type: ignore[index] + # parents of rs have been scheduled, so we can directly use the env + unsharded_grads = [env[x] for x in rs_input_nodes] # type: ignore[index] + reduce_dtype = unsharded_grads[0].meta["val"].dtype + # Only float32 and bfloat16 are supported for now. + # To support fp16, please see FSDP2 `_get_gradient_divide_factors`. + assert reduce_dtype in ( + torch.float32, # type: ignore[attr-defined] + torch.bfloat16, # type: ignore[attr-defined] + ), f"reduce_dtype {reduce_dtype} is not supported" + assert all( + grad.meta["val"].dtype == reduce_dtype for grad in unsharded_grads + ) + device = unsharded_grads[0].meta["val"].device + rank = device.index + rank_idx = rank_idx_dict[rank] + shard_dim = 0 + + def _get_dim0_padded_size( + tensor_size: torch.Size, + dim0_factor: int, # type: ignore[name-defined] + ) -> torch.Size: # type: ignore[name-defined] + padded_dim0 = math.ceil(tensor_size[0] / dim0_factor) * dim0_factor # type: ignore[attr-defined] + return torch.Size([padded_dim0]) + tensor_size[1:] + + padded_unsharded_sizes = tuple( + _get_dim0_padded_size(grad.meta["val"].size(), group_size) # type: ignore[arg-type] + for grad in unsharded_grads + ) + reduce_scatter_input_numel = sum(s.numel() for s in padded_unsharded_sizes) + + """ + NOTE: the relationship between the next few nodes is tricky: + - reduce_scatter_input_reshaped is a view of reduce_scatter_input + (same storage, same # elems, different shape). + - chunk_cat writes into reduce_scatter_input_reshaped, + which indirectly writes into reduce_scatter_input + (since they share the same storage). + - reduce_scatter_tensor reads from reduce_scatter_input. + """ + reduce_scatter_input = new_graph_call_function( + new_graph, + torch.ops.aten.empty.memory_format, + ([reduce_scatter_input_numel],), + { + "dtype": reduce_dtype, + "device": device, + "pin_memory": False, + }, + ) + reduce_scatter_input_reshaped = new_graph_call_function( + new_graph, + torch.ops.aten.reshape.default, + (reduce_scatter_input, [group_size, -1]), + {}, + ) + new_graph_call_function( + new_graph, + torch.ops.fsdp.chunk_cat.default, + (unsharded_grads,), + { + "dim": 0, + "num_chunks": group_size, + "out": reduce_scatter_input_reshaped, + }, + ) + reduce_scatter_tensor = new_graph_call_function( + new_graph, + torch.ops._c10d_functional.reduce_scatter_tensor.default, + (reduce_scatter_input, reduce_op, group_size, group_name), + {}, + ) - padded_unsharded_sizes = tuple( - _get_dim0_padded_size(grad.meta["val"].size(), group_size) # type: ignore[arg-type] - for grad in unsharded_grads - ) - reduce_scatter_input_numel = sum( - s.numel() for s in padded_unsharded_sizes - ) + wait_tensor = new_graph_call_function( + new_graph, + torch.ops._c10d_functional.wait_tensor.default, + (reduce_scatter_tensor,), + {}, + ) - """ - NOTE: the relationship between the next few nodes is tricky: - - reduce_scatter_input_reshaped is a view of reduce_scatter_input - (same storage, same # elems, different shape). - - chunk_cat writes into reduce_scatter_input_reshaped, - which indirectly writes into reduce_scatter_input - (since they share the same storage). - - reduce_scatter_tensor reads from reduce_scatter_input. - """ - reduce_scatter_input = new_graph_call_function( - new_graph, - torch.ops.aten.empty.memory_format, - ([reduce_scatter_input_numel],), - { - "dtype": reduce_dtype, - "device": device, - "pin_memory": False, - }, + def _chunk_with_empty( + tensor: torch.Tensor, num_chunks: int, dim: int + ) -> list[torch.Tensor]: + chunks = list(torch.chunk(tensor, num_chunks, dim=dim)) + while len(chunks) < num_chunks: + chunks.append(chunks[0].new_empty(0)) + return chunks + + reduce_output = wait_tensor + # View out and accumulate sharded gradients + new_sharded_grads = [] + flat_grad_offset = 0 # [0, reduce_scatter_output_numel - 1] + for padded_unsharded_size, unsharded_grad in zip( + padded_unsharded_sizes, unsharded_grads + ): + # NOTE: we only care about the shape of tensors in `chunks`, so using meta tensor here + chunks = _chunk_with_empty( + torch.empty_like(unsharded_grad.meta["val"], device="meta"), + group_size, # type: ignore[arg-type] + dim=shard_dim, ) - reduce_scatter_input_reshaped = new_graph_call_function( - new_graph, - torch.ops.aten.reshape.default, - (reduce_scatter_input, [group_size, -1]), - {}, + sharded_param = chunks[rank_idx] + sharded_size = sharded_param.size() + contiguous_sharded_stride = ( + torch._prims_common.make_contiguous_strides_for(sharded_size) ) - new_graph_call_function( + # Assume even sharding for Shard(i), i > 0; otherwise would require + # copy-out for contiguous strides + new_sharded_grad = new_graph_call_function( new_graph, - torch.ops.fsdp.chunk_cat.default, - (unsharded_grads,), + torch.ops.aten.as_strided.default, + (reduce_output,), { - "dim": 0, - "num_chunks": group_size, - "out": reduce_scatter_input_reshaped, + "size": sharded_size, + "stride": contiguous_sharded_stride, + "storage_offset": flat_grad_offset, }, ) - reduce_scatter_tensor = new_graph_call_function( - new_graph, - torch.ops._c10d_functional.reduce_scatter_tensor.default, - (reduce_scatter_input, reduce_op, group_size, group_name), - {}, - ) - - wait_tensor = new_graph_call_function( - new_graph, - torch.ops._c10d_functional.wait_tensor.default, - (reduce_scatter_tensor,), - {}, - ) - - def _chunk_with_empty( - tensor: torch.Tensor, num_chunks: int, dim: int - ) -> list[torch.Tensor]: - chunks = list(torch.chunk(tensor, num_chunks, dim=dim)) - while len(chunks) < num_chunks: - chunks.append(chunks[0].new_empty(0)) - return chunks - - reduce_output = wait_tensor - # View out and accumulate sharded gradients - new_sharded_grads = [] - flat_grad_offset = 0 # [0, reduce_scatter_output_numel - 1] - for padded_unsharded_size, unsharded_grad in zip( - padded_unsharded_sizes, unsharded_grads - ): - # NOTE: we only care about the shape of tensors in `chunks`, so using meta tensor here - chunks = _chunk_with_empty( - torch.empty_like(unsharded_grad.meta["val"], device="meta"), - group_size, # type: ignore[arg-type] - dim=shard_dim, - ) - sharded_param = chunks[rank] - sharded_size = sharded_param.size() - contiguous_sharded_stride = ( - torch._prims_common.make_contiguous_strides_for(sharded_size) - ) - # Assume even sharding for Shard(i), i > 0; otherwise would require - # copy-out for contiguous strides - new_sharded_grad = new_graph_call_function( - new_graph, - torch.ops.aten.as_strided.default, - (reduce_output,), - { - "size": sharded_size, - "stride": contiguous_sharded_stride, - "storage_offset": flat_grad_offset, - }, - ) - new_sharded_grads.append(new_sharded_grad) - padded_sharded_numel = padded_unsharded_size.numel() // group_size # type: ignore[operator] - flat_grad_offset += padded_sharded_numel # type: ignore[assignment] - assert len(orig_wait_nodes) == len(new_sharded_grads) - assert len(orig_wait_nodes) > 0 - for new_sharded_grad, orig_wait_node in zip( - new_sharded_grads, orig_wait_nodes - ): - env[orig_wait_node] = new_sharded_grad # noqa: PERF403 - for user in sorted( - orig_wait_node_recursive_users, key=lambda x: order[x] - ): - # We skip output node here, because output node will be inserted (later) - # as the last node in the new graph. - if user.op != "output": - node_copy( - env, new_graph, user, lambda x: env_lookup(env, x, user) - ) - bucket_id_is_scheduled[bucket_id] = True + new_sharded_grads.append(new_sharded_grad) + padded_sharded_numel = padded_unsharded_size.numel() // group_size # type: ignore[operator] + flat_grad_offset += padded_sharded_numel # type: ignore[assignment] + assert len(orig_wait_nodes) == len(new_sharded_grads) + assert len(orig_wait_nodes) > 0 + for new_sharded_grad, orig_wait_node in zip( + new_sharded_grads, orig_wait_nodes + ): + env[orig_wait_node] = new_sharded_grad # noqa: PERF403 + for user in sorted(orig_wait_node_recursive_users, key=lambda x: order[x]): + # We skip output node here, because output node will be inserted (later) + # as the last node in the new graph. + if user.op != "output": + node_copy(env, new_graph, user, lambda x: env_lookup(env, x, user)) + bucket_id_is_scheduled[bucket_id] = True else: continue assert node_list[-1].op == "output" From d0c00d9a69df296cdcc659e6e25b1bdc0ac5317c Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Tue, 22 Jul 2025 07:36:01 -0700 Subject: [PATCH 387/457] [MPS] Do not crash if tensor dim > INT_MAX (#158824) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Looks like all MPS operations will crash if one of tensor dimentions are greater than `2**31-1` Change it into a structured exception, by checking tensor size before attempting to create MPS Tensor Add regression test for it. Before this change running following will abort with exception ``` % python3 -c "import torch; torch.randint(0, 10, (2**31,), dtype=torch.uint8, device='mps')" /AppleInternal/Library/BuildRoots/1c8f7852-1ca9-11f0-b28b-226177e5bb69/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSCore/Types/MPSNDArray.mm:829: failed assertion `[MPSNDArray initWithDevice:descriptor:isTextureBacked:] Error: NDArray dimension length > INT_MAX' zsh: abort python3 -c· ``` Skip the test on MacOS-13, as it crashes somewhere deep in MPSGraph framework with ``` /AppleInternal/Library/BuildRoots/c651a45f-806e-11ed-a221-7ef33c48bc85/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSCore/Types/MPSNDArray.mm:724: failed assertion `[MPSTemporaryNDArray initWithDevice:descriptor:] Error: total bytes of NDArray > 2**32' ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/158824 Approved by: https://github.com/dcci ghstack dependencies: #158690, #158823 --- aten/src/ATen/native/mps/OperationUtils.mm | 19 ++++++++++++++++--- test/test_mps.py | 12 ++++++++++++ 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/native/mps/OperationUtils.mm b/aten/src/ATen/native/mps/OperationUtils.mm index dad70cfc299cf..29d07a3f7aa3d 100644 --- a/aten/src/ATen/native/mps/OperationUtils.mm +++ b/aten/src/ATen/native/mps/OperationUtils.mm @@ -427,12 +427,22 @@ Tensor getTensorView(const Tensor& t, MPSShape* shape) { return result; } +// Should be called before initWithBuffer to prevent hard crashes with +// '[MPSNDArray initWithDevice:descriptor:isTextureBacked:] Error: NDArray dimension length > INT_MAX' +static void check_mps_shape(MPSShape* shape) { + for (NSNumber* elem in shape) { + const auto val = [elem longValue]; + TORCH_CHECK(val <= std::numeric_limits::max(), "MPSGaph does not support tensor dims larger than INT_MAX"); + } +} + MPSNDArray* getMPSNDArray(const TensorBase& t, MPSShape* sizes, MPSShape* strides) { id srcBuf = getMTLBufferStorage(t); MPSDataType mpsDataType = getMPSDataType(t.scalar_type()); MPSNDArrayDescriptor* srcTensorDesc = [MPSNDArrayDescriptor descriptorWithDataType:mpsDataType shape:sizes]; srcTensorDesc.preferPackedRows = YES; + check_mps_shape(sizes); MPSNDArray* srcNDArray = [[[MPSNDArray alloc] initWithBuffer:srcBuf offset:t.storage_offset() * t.element_size() descriptor:srcTensorDesc] autorelease]; @@ -542,9 +552,9 @@ Tensor getTensorView(const Tensor& t, MPSShape* shape) { // Tensor is contiguous and has no storage offset. // Wrap it directly inside MPSGraphTensorData if ((_tensor.is_contiguous() && !_tensor.storage_offset()) || !useMPSStridedAPI || !is_macOS_15_0_or_newer) { - _value = [[[MPSGraphTensorData alloc] initWithMTLBuffer:srcBuf - shape:mpsShape_ ? mpsShape_ : getMPSShape(_tensor) - dataType:dataType] autorelease]; + auto shape = mpsShape_ ? mpsShape_ : getMPSShape(_tensor); + check_mps_shape(shape); + _value = [[[MPSGraphTensorData alloc] initWithMTLBuffer:srcBuf shape:shape dataType:dataType] autorelease]; } else { IntArrayRef view_shape; if (mpsShape_) { @@ -553,8 +563,11 @@ Tensor getTensorView(const Tensor& t, MPSShape* shape) { MPSShape* mpsShape = getMPSShape(_tensor); MPSShape* mpsStrides = getMPSShape(_tensor.strides()); + check_mps_shape(mpsShape); auto storage_numel = src.storage().nbytes() / src.element_size(); + TORCH_CHECK(storage_numel <= std::numeric_limits::max(), + "MPSGaph does not support tensor dims larger than INT_MAX"); MPSNDArrayDescriptor* srcTensorDesc = [MPSNDArrayDescriptor descriptorWithDataType:dataType shape:@[ @(storage_numel) ]]; srcTensorDesc.preferPackedRows = YES; diff --git a/test/test_mps.py b/test/test_mps.py index d9e4b7a9f037c..ea1013c972135 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -8013,6 +8013,18 @@ def test_64bit_index_select(self): gc.collect() torch.mps.empty_cache() + @serialTest() + def test_rand_2b_raises(self): + if MACOS_VERSION < 14.0: + raise unittest.SkipTest("Crashes on MacOS-13") + int32_max = torch.iinfo(torch.int32).max + with self.assertRaises(RuntimeError): + # This used to crash with NDArray dimension length > INT_MAX + x = torch.randint(0, 10, (int32_max + 1,), dtype=torch.int8, device='mps') + x = torch.randint(0, 10, (int32_max,), dtype=torch.int8, device='mps') + self.assertEqual(x.numel(), int32_max) + del x + class TestLogical(TestCaseMPS): def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False): From 9a28e23d9792551d5a070cec8c67d0e499358825 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 22 Jul 2025 15:45:34 +0000 Subject: [PATCH 388/457] Revert "removed zero dim cpu logic from fake_tensor.py (#147501)" This reverts commit 9e0473b56621162bd85e94943a516be4727e5651. Reverted https://github.com/pytorch/pytorch/pull/147501 on behalf of https://github.com/ZainRizvi due to Seems to have broken ROCm. See inductor/test_aot_inductor_package.py::TestAOTInductorPackageCpp_cuda::test_compile_standalone_cos [GH job link](https://github.com/pytorch/pytorch/actions/runs/16428359564/job/46426243808) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/a991e285ae35159680b0ad4be24669906a6fa256) ([comment](https://github.com/pytorch/pytorch/pull/147501#issuecomment-3103494041)) --- test/test_fake_tensor.py | 16 ---------------- torch/_subclasses/fake_tensor.py | 13 ++----------- 2 files changed, 2 insertions(+), 27 deletions(-) diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py index 7aa530ae3296b..e8c28cadbf829 100644 --- a/test/test_fake_tensor.py +++ b/test/test_fake_tensor.py @@ -211,22 +211,6 @@ def test_zero_dim(self): self.assertEqual(out.device, y.device) self.assertTrue(isinstance(out, FakeTensor)) - @unittest.skipIf(not RUN_CUDA, "requires cuda") - def test_op_with_zero_dim_bypassed(self): - if torch._functorch.config.fake_tensor_propagate_real_tensors: - return - shape_env = ShapeEnv() - mode = FakeTensorMode(shape_env=shape_env) - x = torch.tensor(1.0, device="cuda") - y = torch.tensor(2.0) - fake_x = mode.from_tensor(x) - fake_y = mode.from_tensor(y) - - with self.assertRaisesRegex( - RuntimeError, "Unhandled FakeTensor Device Propagation for.*" - ) as exc: - torch.nextafter(fake_x, fake_y) - def test_nan_to_num(self): with FakeTensorMode(): for dtype in [torch.float16, torch.float32]: diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index 77cf89e9186b6..c17de15f46eac 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -889,11 +889,6 @@ def _find_common_device( aten._foreach_copy.default, ) - # list of ops not using zero dim cpu tensor logic to align with the eager mode. - bypass_zero_dim_cpu_tensor_check_ops = ordered_set( - aten.nextafter.default, - ) - def check_cpu_device(device: torch.device) -> bool: return device.type == "cpu" @@ -917,17 +912,13 @@ def merge_devices(t: object) -> None: is_cpu_zero_dim = t_is_cpu_zero_dim return - is_bypass_zero_dim_cpu_tensor_check_op = ( - func in bypass_zero_dim_cpu_tensor_check_ops - ) - # mismatching devices ! # if current tensor is cpu 0 dim, defer to existing device - if t_is_cpu_zero_dim and not is_bypass_zero_dim_cpu_tensor_check_op: + if t_is_cpu_zero_dim: return # current device is from cpu 0 dim tensor, overwrite - if is_cpu_zero_dim and not is_bypass_zero_dim_cpu_tensor_check_op: + if is_cpu_zero_dim: common_device = t.device is_cpu_zero_dim = t_is_cpu_zero_dim return From 4060f3004264dc4414239cdc3145b7e46fa3729f Mon Sep 17 00:00:00 2001 From: Benjamin Glass Date: Mon, 21 Jul 2025 22:50:48 +0000 Subject: [PATCH 389/457] [AOTI] Convert C-struct zip handling to RAII container (#158687) Attempts to fix a memory leak reported in #158614 by wrapping manually managed MiniZ C-structs in an RAII container. I have been unable to reproduce the reported leak, but this seems like the most likely candidate. Fixes #158614 (hopefully) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158687 Approved by: https://github.com/desertfire --- test/cpp/aoti_inference/test.cpp | 2 + .../aoti_package/model_package_loader.cpp | 102 +++++++++++------- 2 files changed, 68 insertions(+), 36 deletions(-) diff --git a/test/cpp/aoti_inference/test.cpp b/test/cpp/aoti_inference/test.cpp index 59d575b2cc2bb..bff3827f8e8ac 100644 --- a/test/cpp/aoti_inference/test.cpp +++ b/test/cpp/aoti_inference/test.cpp @@ -144,6 +144,8 @@ void test_aoti_package_loader_multi_gpu( const std::string& device, bool use_runtime_constant_folding) { torch::NoGradGuard no_grad; + // Ensure that this test will reset the default CUDA device on exit. + torch::DeviceGuard device_guard(c10::Device("cuda")); std::string data_path = (std::filesystem::path(STRINGIZE(CMAKE_CURRENT_BINARY_DIR)) / "data.pt") diff --git a/torch/csrc/inductor/aoti_package/model_package_loader.cpp b/torch/csrc/inductor/aoti_package/model_package_loader.cpp index 4018d9b00a75e..ed4d302bb7b34 100644 --- a/torch/csrc/inductor/aoti_package/model_package_loader.cpp +++ b/torch/csrc/inductor/aoti_package/model_package_loader.cpp @@ -443,6 +443,69 @@ void AOTIModelPackageLoader::load_metadata(const std::string& cpp_filename) { } } +class RAIIMinizArchive { + public: + RAIIMinizArchive(const std::string& zip_path) { + mz_zip_zero_struct(&_zip_archive); + if (!mz_zip_reader_init_file(&_zip_archive, zip_path.c_str(), 0)) { + throw std::runtime_error(fmt::format( + "Failed to initialize zip archive: {}", + mz_zip_get_error_string(mz_zip_get_last_error(&_zip_archive)))); + } + } + RAIIMinizArchive(const RAIIMinizArchive&) = delete; + RAIIMinizArchive& operator=(const RAIIMinizArchive&) = delete; + RAIIMinizArchive(RAIIMinizArchive&&) noexcept = delete; + RAIIMinizArchive& operator=(RAIIMinizArchive&&) noexcept = delete; + ~RAIIMinizArchive() { + // Unconditionally close the file. We can't handle any errors here without + // terminating the program. + mz_zip_reader_end(&_zip_archive); + } + + std::vector get_filenames() { + const unsigned num_zip_files{mz_zip_reader_get_num_files(&_zip_archive)}; + std::vector zip_filenames{}; + zip_filenames.reserve(num_zip_files); + + for (unsigned i{0}; i < num_zip_files; ++i) { + // filename_buf_size == 0 returns the filename length, including null + // terminator + const auto zip_filename_len{ + mz_zip_reader_get_filename(&_zip_archive, i, nullptr, 0)}; + if (!zip_filename_len) { + throw std::runtime_error( + fmt::format("Failed to read zip filename length at index {}", i)); + } + // std::string implicitly appends a character for the null terminator + std::string zip_filename(zip_filename_len - 1, '\0'); + if (!mz_zip_reader_get_filename( + &_zip_archive, i, zip_filename.data(), zip_filename_len)) { + throw std::runtime_error( + fmt::format("Failed to read zip filename at index {}", i)); + } + zip_filenames.emplace_back(zip_filename); + } + + return zip_filenames; + } + + void extract_file( + const std::string& zip_filename, + const std::string& dest_filename) { + if (!mz_zip_reader_extract_file_to_file( + &_zip_archive, zip_filename.c_str(), dest_filename.c_str(), 0)) { + throw std::runtime_error(fmt::format( + "Failed to extract zip file {} to destination file {}", + zip_filename, + dest_filename)); + } + } + + private: + mz_zip_archive _zip_archive{}; +}; + AOTIModelPackageLoader::AOTIModelPackageLoader( const std::string& model_package_path, const std::string& model_name, @@ -462,32 +525,8 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( } // Extract all files within the zipfile to a temporary directory - mz_zip_archive zip_archive; - memset(&zip_archive, 0, sizeof(zip_archive)); - - if (!mz_zip_reader_init_file(&zip_archive, model_package_path.c_str(), 0)) { - throw std::runtime_error( - std::string("Failed to initialize zip archive: ") + - mz_zip_get_error_string(mz_zip_get_last_error(&zip_archive))); - } - - std::vector found_filenames; - for (uint32_t i = 0; i < zip_archive.m_total_files; i++) { - uint32_t filename_len = - mz_zip_reader_get_filename(&zip_archive, i, nullptr, 0); - if (filename_len == 0) { - throw std::runtime_error("Failed to read filename"); - } - // filename_len returned by mz_zip_reader_get_filename includes the null - // terminator, so we need to subtract 1 here - std::string filename_str(filename_len - 1, '\0'); - if (!mz_zip_reader_get_filename( - &zip_archive, i, filename_str.data(), filename_len)) { - throw std::runtime_error("Failed to read filename"); - } - found_filenames.push_back(normalize_path_separator(filename_str)); - } - + RAIIMinizArchive zip_archive{model_package_path}; + auto found_filenames{zip_archive.get_filenames()}; if (found_filenames.empty()) { throw std::runtime_error("No files found in zip archive."); } @@ -560,8 +599,7 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( } // Extracts file to the temp directory - mz_zip_reader_extract_file_to_file( - &zip_archive, filename_str.c_str(), output_path_str.c_str(), 0); + zip_archive.extract_file(filename_str, output_path_str); // Save the file for bookkeeping size_t extension_idx = output_path_str.find_last_of('.'); @@ -578,14 +616,6 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( } } - // Close the zip archive as we have extracted all files to the temp - // directory - if (!mz_zip_reader_end(&zip_archive)) { - throw std::runtime_error( - std::string("Failed to close zip archive: {}") + - mz_zip_get_error_string(mz_zip_get_last_error(&zip_archive))); - } - if (cpp_filename.empty() && so_filename.empty()) { std::string found_filenames_str; for (const std::string& filename : found_filenames) { From 7d6f3402380de06cee9d10f708e373e36aa9bd9c Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 22 Jul 2025 16:20:17 +0000 Subject: [PATCH 390/457] Revert "[AOTI] Add more default options to compile_standalone (#158560)" This reverts commit a991e285ae35159680b0ad4be24669906a6fa256. Reverted https://github.com/pytorch/pytorch/pull/158560 on behalf of https://github.com/jeffdaily due to broke rocm CI, no test signal was available from rocm ciflow/trunk, need to add ciflow/rocm to reland ([comment](https://github.com/pytorch/pytorch/pull/158560#issuecomment-3103633964)) --- test/inductor/test_aot_inductor.py | 10 +-- test/inductor/test_aot_inductor_package.py | 36 --------- torch/_inductor/codecache.py | 21 +++-- torch/_inductor/codegen/cpp_wrapper_cpu.py | 18 +++-- torch/_inductor/codegen/triton.py | 5 -- torch/_inductor/config.py | 8 +- torch/_inductor/cpp_builder.py | 93 +++++----------------- torch/_inductor/utils.py | 40 ++++------ 8 files changed, 61 insertions(+), 170 deletions(-) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index c8281e1b505a0..f71c27d92cf82 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -6662,19 +6662,11 @@ def test_compile_standalone_sets_package_cpp(self): result = maybe_aoti_standalone_config({"aot_inductor.compile_standalone": True}) self.assertEqual(result["aot_inductor.package_cpp_only"], True) self.assertEqual(result["aot_inductor.compile_standalone"], True) - self.assertEqual(result["aot_inductor.embed_kernel_binary"], True) - self.assertEqual(result["aot_inductor.emit_multi_arch_kernel"], True) - self.assertEqual( - result["aot_inductor.model_name_for_generated_files"], "aoti_model" - ) - def test_compile_standalone_explicit_set(self): + def test_compile_standalone_package_cpp_already_true(self): patches = { "aot_inductor.compile_standalone": True, "aot_inductor.package_cpp_only": True, - "aot_inductor.embed_kernel_binary": True, - "aot_inductor.emit_multi_arch_kernel": True, - "aot_inductor.model_name_for_generated_files": "aoti_model", } result = maybe_aoti_standalone_config(patches) self.assertEqual(result, patches) diff --git a/test/inductor/test_aot_inductor_package.py b/test/inductor/test_aot_inductor_package.py index 2809f5533bd9c..51343b6b1883e 100644 --- a/test/inductor/test_aot_inductor_package.py +++ b/test/inductor/test_aot_inductor_package.py @@ -15,7 +15,6 @@ from parameterized import parameterized_class import torch -import torch._inductor.config from torch._inductor.codecache import get_kernel_bin_format from torch._inductor.package import load_package, package_aoti from torch._inductor.test_case import TestCase @@ -364,7 +363,6 @@ def forward(self, x, y): ) @unittest.skipIf(IS_FBCODE, "cmake won't work in fbcode") @skipIfXpu # build system may be different - @torch._inductor.config.patch("test_configs.use_libtorch", True) def test_compile_after_package_static(self): # compile_standalone will set package_cpp_only=True self.check_package_cpp_only() @@ -421,46 +419,12 @@ def forward(self, x, y): with self.assertRaisesRegex(Exception, "Invalid AOTI model name"): self.cmake_compile(model, example_inputs, options, "") - @unittest.skipIf(IS_FBCODE, "cmake won't work in fbcode") - @skipIfXpu # build system may be different - @torch._inductor.config.patch("test_configs.use_libtorch", True) - def test_compile_standalone_cos(self): - # compile_standalone will set package_cpp_only=True - self.check_package_cpp_only() - - class Model(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - - def forward(self, x): - return torch.cos(x) - - with torch.no_grad(): - example_inputs = (torch.randn(8, 32, device=self.device),) - model = Model().to(device=self.device) - - # Test compilation when model name is passed in - options = { - "aot_inductor.compile_standalone": True, - "aot_inductor.model_name_for_generated_files": "cos", - } - with ( - tempfile.TemporaryDirectory() as tmp_dir, - ): - build_path, _ = self.cmake_compile( - model, example_inputs, options, tmp_dir - ) - # Check if the .a file was build successfully - a_path = build_path / "libcos.a" - self.assertTrue(a_path.exists()) - @unittest.skipIf( _get_torch_cuda_version() < (12, 6), "Test is only supported on CUDA 12.6+" ) @unittest.skipIf(IS_FBCODE, "cmake won't work in fbcode") @skipIfRocm # doesn't support multi-arch binary @skipIfXpu # doesn't support multi-arch binary - @torch._inductor.config.patch("test_configs.use_libtorch", True) def test_compile_with_exporter(self): self.check_package_cpp_only() diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index dd5a591f421f3..c8b23aded15c2 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -1674,6 +1674,12 @@ def compile( wrapper_code = "\n".join((wrapper_code, kernel_code)) kernel_code = "" + from .utils import aoti_model_name_from_config + + model_class_name = "" + if config.aot_inductor.compile_standalone: + model_class_name = aoti_model_name_from_config() + wrapper_key, wrapper_path = write( wrapper_code, "wrapper.cpp", @@ -1706,8 +1712,6 @@ def compile( "model.h", ) ) as f: - # model_name_for_generated_files is guaranteed to be non-empty when compile_standalone - model_class_name = config.aot_inductor.model_name_for_generated_files class_name = f"AOTInductorModel{model_class_name}" header_code = f.read() @@ -1722,7 +1726,7 @@ def compile( header_code, "h", specified_dir=specified_output_path, - key=model_class_name, + key=f"{model_class_name}", ) # Log the AOTInductor wrapper and kernel code, if needed. @@ -1836,7 +1840,7 @@ def format_consts_to_asm( consts_asm += f"\t.space {len(consts) - 8}\n" consts_asm += f".globl\t{symbol_prefix}_binary_constants_bin_end\n" consts_asm += f"{symbol_prefix}_binary_constants_bin_end:\n" - return consts_asm, "weights.S" + return consts_asm, "S" # Use c++ to convert consts to object file can support more compilers, such as msvc and icx. def format_consts_to_cpp( @@ -1861,7 +1865,7 @@ def format_consts_to_cpp( const_cpp += "\t\n" const_cpp += "};\t\n" const_cpp += f"alignas({align_bytes}) extern unsigned char * {symbol_prefix}_binary_constants_bin_end;\t\n" - return const_cpp, "weights.cpp" + return const_cpp, "cpp" if use_asm_build: consts_code, code_ext = format_consts_to_asm( @@ -1876,7 +1880,6 @@ def format_consts_to_cpp( consts_code, code_ext, specified_dir=str(specified_sub_dir), - key=config.aot_inductor.model_name_for_generated_files, ) consts_s = Path(consts_s) object_build_options = CppTorchDeviceOptions( @@ -2170,13 +2173,7 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes: asm_files = [] if not _IS_WINDOWS: ld, objcopy = get_ld_and_objcopy(use_relative_path) - kernels = getattr(V.graph.wrapper_code, "_kernel_name_to_body", {}) for kernel_name, value in CudaKernelParamCache.cache.items(): - if kernel_name not in kernels: - # It is possible that CudaKernelParamCache contains more Triton kernels - # than what the current graph uses - continue - if asm_file := value["asm"]: asm_files.append(asm_file) diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 9abdcce44f6c9..56d6f40dade81 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -22,7 +22,13 @@ from torch.utils._sympy.symbol import symbol_is_type, SymT from .. import config, ir -from ..utils import _align, DeferredLineBase, LineContext, normalize_name +from ..utils import ( + _align, + aoti_model_name_from_config, + DeferredLineBase, + LineContext, + normalize_name, +) from ..virtualized import V from .aoti_hipify_utils import maybe_hipify_code_wrapper from .common import get_device_op_overrides, IndentedBuffer, Kernel @@ -58,15 +64,11 @@ def __init__(self): self.device = "cpu" # must be initialized prior to calling super().__init__() self.included_devices: OrderedSet[str] = OrderedSet() - self.model_class_name_suffix = ( - config.aot_inductor.model_name_for_generated_files - if config.aot_inductor.compile_standalone - else "" - ) + self.model_class_name_suffix = "" + if config.aot_inductor.compile_standalone: + self.model_class_name_suffix = aoti_model_name_from_config() self.aoti_model_class_name = f"AOTInductorModel{self.model_class_name_suffix}" - super().__init__() - self.declare = "auto " self.declare_maybe_reference = "decltype(auto) " self.ending = ";" diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 4aaff61e77d4f..a34665d720f47 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -4479,11 +4479,6 @@ def define_kernel(self, src_code, node_schedule, kernel): kernel_name = "_".join( ["triton", kernel_category, fused_name, wrapper.next_kernel_suffix()] ) - if config.aot_inductor.model_name_for_generated_files: - # When AOTI compiles multiple submodules, we need to use the model name to - # distinguish kernel related symbols. - kernel_name = f"{config.aot_inductor.model_name_for_generated_files}_{kernel_name}" - # use the original src_code as the key wrapper.src_to_kernel[src_code] = kernel_name subs_name = kernel_name if config.triton.unique_kernel_names else "triton_" diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 0fb3237dac32b..ae2ee6a574c73 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -1458,12 +1458,12 @@ class aot_inductor: precompile_headers: bool = not is_fbcode() # Embed generated kernel binary files into model.so - embed_kernel_binary: Optional[bool] = None + embed_kernel_binary: bool = False # Generate kernel files that support multiple archs # For CUDA, this means generating fatbin files for kernels, and the fatbin files # contains PTX and SASS for the current architecture. - emit_multi_arch_kernel: Optional[bool] = None + emit_multi_arch_kernel: bool = False # If not None, the generated files with use this name in file stem. # If None, we will use a hash to name files. @@ -1850,10 +1850,6 @@ class test_configs: graphsafe_rng_func_ignores_fallback_random = False - # If set to True, AOTI-generated CMakelists.txt will still use libtorch - # for unit testing - use_libtorch = False - if TYPE_CHECKING: from torch.utils._config_typing import * # noqa: F401, F403 diff --git a/torch/_inductor/cpp_builder.py b/torch/_inductor/cpp_builder.py index 64140542d9ba0..47820d3d77250 100644 --- a/torch/_inductor/cpp_builder.py +++ b/torch/_inductor/cpp_builder.py @@ -28,6 +28,7 @@ from torch._inductor import config, exc from torch._inductor.cpu_vec_isa import invalid_vec_isa, VecISA from torch._inductor.runtime.runtime_utils import cache_dir +from torch._inductor.utils import aoti_model_name_from_config from torch.torch_version import TorchVersion @@ -1544,9 +1545,7 @@ def __init__( self._aot_mode: bool = False self._name = name - self._target_name = ( - config.aot_inductor.model_name_for_generated_files or "aoti_model" - ) + self._target_name = aoti_model_name_from_config() # Code start here, initial self internal variables firstly. self._build_option = BuildOption @@ -1772,13 +1771,9 @@ def save_compile_cmd_to_cmake( """ definitions = " ".join(self._build_option.get_definitions()) - if config.aot_inductor.compile_standalone: - if config.test_configs.use_libtorch: - add_target = f"add_library({self._target_name} STATIC)" - else: - add_target = f"add_executable({self._target_name} ${{CMAKE_CURRENT_SOURCE_DIR}}/main.cpp)" - else: - add_target = f"add_library({self._target_name} SHARED)" + target_library_type = ( + "STATIC" if config.aot_inductor.compile_standalone else "SHARED" + ) contents = textwrap.dedent( f""" @@ -1786,54 +1781,22 @@ def save_compile_cmd_to_cmake( project({self._target_name} LANGUAGES CXX) set(CMAKE_CXX_STANDARD 17) - # Set target - {add_target} - - """ - ) - - if ( - not config.aot_inductor.compile_standalone - or config.test_configs.use_libtorch - ): - # When compile_standalone is True, the generated cpp project should - # not use Torch. But for unit testing purpose, we need to use Torch here. - contents += textwrap.dedent( - """ - # May need to point CMAKE_PREFIX_PATH to the right torch location - find_package(Torch REQUIRED) - - """ - ) - # flags and macros here are mostly CPU specific. Not emitting them for GPU models - # will make the generated CMake file more portable and won't really hurt performance. - # NOTE: standalone focuses on GPU now. For CPU, some of the flags and macros may - # be still needed. - contents += textwrap.dedent( - f""" - # Add macro definitions - target_compile_definitions({self._target_name} PRIVATE {definitions}) - - # Add compile flags - target_compile_options({self._target_name} PRIVATE {self._cflags_args}) + # May need to point CMAKE_PREFIX_PATH to the right torch location + find_package(Torch REQUIRED) - # Backend-specific flags - target_compile_options({self._target_name} PRIVATE {self._passthrough_parameters_args} -c) + # Set a shared library target + add_library({self._target_name} {target_library_type}) - """ - ) - else: - # When compile_standalone is True, use TorchStandalone instead of Torch - contents += textwrap.dedent( - """ - find_package(TorchStandalone REQUIRED) - # Set up include directories to find headers at the correct paths - target_include_directories(cos PRIVATE ${TorchStandalone_INCLUDE_DIRS}) - target_include_directories(cos PRIVATE ${TorchStandalone_INCLUDE_DIRS}/standalone) + # Add macro definitions + target_compile_definitions({self._target_name} PRIVATE {definitions}) - """ - ) + # Add compile flags + target_compile_options({self._target_name} PRIVATE {self._cflags_args}) + # Backend specific flags + target_compile_options({self._target_name} PRIVATE {self._passthrough_parameters_args} -c) + """ + ) if device_type == "cuda" and torch.version.hip is None: from torch._inductor.codecache import _nvcc_arch_as_compile_option @@ -1841,11 +1804,7 @@ def save_compile_cmd_to_cmake( contents += textwrap.dedent( f""" enable_language(CUDA) - set(CMAKE_CUDA_STANDARD 17) find_package(CUDAToolkit REQUIRED) - target_include_directories({self._target_name} PRIVATE ${{CUDAToolkit_INCLUDE_DIRS}}) - target_compile_definitions({self._target_name} PRIVATE USE_CUDA) - target_link_libraries({self._target_name} PRIVATE cuda CUDA::cudart_static) find_program(OBJCOPY_EXECUTABLE objcopy) if(NOT OBJCOPY_EXECUTABLE) @@ -1874,7 +1833,7 @@ def save_compile_cmd_to_cmake( add_custom_command( OUTPUT ${{FATBIN_FILE}} COMMAND ${{CUDAToolkit_NVCC_EXECUTABLE}} --fatbin ${{PTX_FILE}} -o ${{FATBIN_FILE}} ${{NVCC_GENCODE_FLAGS}} - -gencode arch=compute_{current_arch},code=compute_{current_arch} + -gencode arch=compute_80,code=compute_80 -gencode arch=compute_{current_arch},code=sm_{current_arch} DEPENDS ${{PTX_FILE}} ) @@ -1923,20 +1882,12 @@ def save_kernel_asm_to_cmake(self, cmake_path: str, asm_files: list[str]) -> Non """ ) f.write(contents) - if asm_files: - f.write(f"add_dependencies({self._target_name} ${{KERNEL_TARGETS}})\n") - f.write( - f"target_link_libraries({self._target_name} PRIVATE ${{KERNEL_OBJECT_FILES}})\n" - ) + f.write(f"add_dependencies({self._target_name} ${{KERNEL_TARGETS}})\n") + f.write( + f"target_link_libraries({self._target_name} PRIVATE ${{KERNEL_OBJECT_FILES}})\n" + ) def save_link_cmd_to_cmake(self, cmake_path: str) -> None: - if ( - config.aot_inductor.compile_standalone - and not config.test_configs.use_libtorch - ): - # When compile_standalone is True, do not link with libtorch - return - lflags = " ".join(self._build_option.get_ldflags()) libs = " ".join(self._build_option.get_libraries()) contents = textwrap.dedent( diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 22c533a5a03c4..aef81712d17eb 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -3305,34 +3305,20 @@ def maybe_aoti_standalone_config(config_patches: dict[str, Any]) -> dict[str, An Returns: dict[str, Any]: The possibly-updated `config_patches` dictionary. """ - - def patch_config( - config_patches: dict[str, Any], config_name: str, config_value: Any - ) -> None: - value = config_patches.get(config_name, getattr(config, config_name)) - if value is None: - config_patches[config_name] = config_value - elif not value: - raise RuntimeError( - f"Invalid config: {config_name}={config_value} when aot_inductor.compile_standalone is True." - ) - compile_standalone = config_patches.get( "aot_inductor.compile_standalone", config.aot_inductor.compile_standalone ) - # Make a copy of the config_patches to avoid modifying the original dictionary, needed for testing - config_patches = config_patches.copy() if compile_standalone: - # Standlaone AOTInductor means only generate cpp project for building a standalone binary - patch_config(config_patches, "aot_inductor.package_cpp_only", True) - # Standlaone AOTInductor needs to embed the kernel code in the binary - patch_config(config_patches, "aot_inductor.embed_kernel_binary", True) - # Default to use multi-arch kernel codegen - patch_config(config_patches, "aot_inductor.emit_multi_arch_kernel", True) - patch_config( - config_patches, "aot_inductor.model_name_for_generated_files", "aoti_model" + package_cpp_only = config_patches.get( + "aot_inductor.package_cpp_only", config.aot_inductor.package_cpp_only ) - + if package_cpp_only is None: + config_patches = {**config_patches, "aot_inductor.package_cpp_only": True} + elif not package_cpp_only: + raise RuntimeError( + "compile_standalone=True requires package_cpp_only=True. " + "Please set aot_inductor.package_cpp_only=True in your inductor config." + ) return config_patches @@ -3361,3 +3347,11 @@ def is_valid_aoti_model_name() -> bool: ) return True + + +def aoti_model_name_from_config() -> str: + from torch._inductor import config + + model_name = config.aot_inductor.model_name_for_generated_files + model_name = "aoti_model" if model_name is None else model_name + return model_name From 0971637c115d2a41aff08d75deca02751a24f709 Mon Sep 17 00:00:00 2001 From: Alexander Novikov <79649566+novikov-alexander@users.noreply.github.com> Date: Tue, 22 Jul 2025 16:32:45 +0000 Subject: [PATCH 391/457] Fix torch.tensor warning in ONNX symbolic_opset10 export (#158835) Fix PyTorch tensor copying warning in ONNX export ## Problem PyTorch ONNX exporter was generating a warning about incorrect tensor copying method: ``` UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor). ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/158835 Approved by: https://github.com/justinchuby --- torch/onnx/symbolic_opset10.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torch/onnx/symbolic_opset10.py b/torch/onnx/symbolic_opset10.py index 469d7a80f77dc..0b8e2478ce339 100644 --- a/torch/onnx/symbolic_opset10.py +++ b/torch/onnx/symbolic_opset10.py @@ -513,7 +513,9 @@ def to_slice_input(list_or_value, default_value=None): if is_none_value(list_or_value) and default_value is not None: list_or_value = [default_value] - if isinstance(list_or_value, (list, torch.Tensor)): + if isinstance(list_or_value, torch.Tensor): + return g.op("Constant", value_t=list_or_value.clone().detach()) + elif isinstance(list_or_value, list): return g.op("Constant", value_t=torch.tensor(list_or_value)) rank = symbolic_helper._get_tensor_rank(list_or_value) From 52c294008ee764d1931d4f0c1aece984431e4596 Mon Sep 17 00:00:00 2001 From: Yidi Wu Date: Mon, 21 Jul 2025 16:08:59 -0700 Subject: [PATCH 392/457] [hop] allow non fake inputs when check input alias and mutation (#158798) https://github.com/pytorch/pytorch/pull/154193 gets reverted due to a test failure. The root cause being that: an executorch pass turns int inputs into a scalar tensor in cond's subgraph. The pass have been around on the critical path of executorch since two years ago. Changing it would be difficult. So we just allow non-fake inputs for check input mutation and aliasing, which shoudn't affect the correctness of the analysis. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158798 Approved by: https://github.com/pianpwk --- torch/_higher_order_ops/utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/torch/_higher_order_ops/utils.py b/torch/_higher_order_ops/utils.py index bf3dc83f8608f..580d66551dd42 100644 --- a/torch/_higher_order_ops/utils.py +++ b/torch/_higher_order_ops/utils.py @@ -859,7 +859,11 @@ def _get_shape_env( # the runtime assertions for unbacked symbols. new_fake_mode = torch._subclasses.FakeTensorMode( shape_env=_get_shape_env(fake_args), - allow_non_fake_inputs=False, + # In executorch, there's an scalar_to_tensor pass that turns scalar inputs into a tensor constant + # e.g. add(a, 1) 1 is turned into a tensor, which becomes a constant tensor attribute in the graph. + # We allow non fake inputs for this purpose. This is fine for mutation detection purpose: + # inputs are all fake and all mutations/aliasing are still detected. + allow_non_fake_inputs=True, ) # We need to temporarily turn inference_mode off because # under inference mode, tensor version counter is not tracked. From 2a249f1967d29626fe6ac6a07f28440348d1cc93 Mon Sep 17 00:00:00 2001 From: albanD Date: Tue, 22 Jul 2025 10:40:18 -0700 Subject: [PATCH 393/457] We do support 3.14 This has been added a bit back. --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index b41ae87621f0f..523fed351b5cc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ build-backend = "setuptools.build_meta" name = "torch" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" readme = "README.md" -requires-python = ">=3.9,<3.14" +requires-python = ">=3.9" # TODO: change to `license = "BSD-3-Clause"` and enable PEP 639 after pinning setuptools>=77 # FIXME: As of 2025.06.20, it is hard to ensure the minimum version of setuptools in our CI environment. # TOML-table-based license deprecated in setuptools>=77, and the deprecation warning will be changed From 7d2ceaff21a6faf430470fc88642dad0e80386b4 Mon Sep 17 00:00:00 2001 From: William Wen Date: Mon, 21 Jul 2025 16:42:14 -0700 Subject: [PATCH 394/457] [dynamo] skip tracing functions registered in sys.monitoring (#158171) Fixes https://github.com/pytorch/pytorch/issues/158164 This was fixed by applying `skip_code_recursive` to any function registered to `sys.monitoring` (via `PyThreadState_GET()->interp->monitoring_callables`). This check is done whenever we attempt to set the eval frame callback from Python. Microbenchmark: `benchmarks/dynamo/microbenchmarks/overheads.py`: BEFORE: ``` requires_grad=False eager 7.1us (warmup=0.0s) compiled 24.6us (warmup=10.0s) requires_grad=True eager 8.9us (warmup=0.0s) compiled 57.8us (warmup=0.1s) inference_mode() eager 6.5us (warmup=0.0s) compiled 23.4us (warmup=0.1s) ``` AFTER: ``` requires_grad=False eager 7.0us (warmup=0.0s) compiled 23.2us (warmup=15.2s) requires_grad=True eager 9.0us (warmup=0.0s) compiled 55.1us (warmup=0.1s) inference_mode() eager 6.4us (warmup=0.0s) compiled 22.2us (warmup=0.1s) ``` Followup thought: how do we let users know that a frame is skipped because the code object is a callable registered to sys.monitoring? (or any other reason?) Differential Revision: [D78530528](https://our.internmc.facebook.com/intern/diff/D78530528) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158171 Approved by: https://github.com/jansel --- test/dynamo/test_repros.py | 52 +++++++++++++++++++++++++++- torch/csrc/dynamo/eval_frame.c | 25 +++++++++++++ torch/csrc/dynamo/eval_frame.h | 5 +++ torch/csrc/dynamo/eval_frame_cpp.cpp | 15 ++++++++ torch/csrc/dynamo/eval_frame_cpp.h | 1 + 5 files changed, 97 insertions(+), 1 deletion(-) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index cc702ad542cee..e0b2fdbf8611a 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -43,7 +43,13 @@ import torch.utils._pytree as pytree from torch import nn from torch._dynamo.debug_utils import same_two_models -from torch._dynamo.testing import CompileCounter, rand_strided, same, skipIfPy312 +from torch._dynamo.testing import ( + CompileCounter, + rand_strided, + same, + skipIfNotPy312, + skipIfPy312, +) from torch._inductor.utils import fresh_cache from torch.nn import functional as F from torch.profiler import profile, ProfilerActivity @@ -7072,6 +7078,50 @@ def f(x, out): torch.compile(f, backend="eager", fullgraph=True)(x, out_res) self.assertEqual(out_ref, out_res) + @skipIfNotPy312 + def test_sys_monitoring(self): + found_dynamo = False + found_compiled_graph = False + compiled_graph = None + + def backend(gm, _): + nonlocal compiled_graph + compiled_graph = gm + return gm + + def callback(code, offset): + nonlocal found_dynamo + nonlocal found_compiled_graph + torch._dynamo.graph_break() + if ( + code + is torch._dynamo.symbolic_convert.InstructionTranslator.run.__code__ + ): + found_dynamo = True + elif compiled_graph and code is compiled_graph.__call__.__code__: + found_compiled_graph = True + + sys.monitoring.use_tool_id(0, "test") + old_callback = sys.monitoring.register_callback( + 0, sys.monitoring.events.PY_START, callback + ) + sys.monitoring.set_events(0, sys.monitoring.events.PY_START) + try: + + @torch.compile(backend=backend, fullgraph=True) + def fn(x): + return x + 1 + + fn(torch.ones(3)) + # sys.monitoring should still run in Python dynamo + self.assertTrue(found_dynamo) + # sys.monitoring should still run on the compiled graph + self.assertTrue(found_compiled_graph) + finally: + sys.monitoring.register_callback( + 0, sys.monitoring.events.PY_START, old_callback + ) + def test_unbind_copy_out(self): def f(eye, out): torch.unbind_copy(eye, out=out) diff --git a/torch/csrc/dynamo/eval_frame.c b/torch/csrc/dynamo/eval_frame.c index 72bb8839bac35..7d00c7ba1abf3 100644 --- a/torch/csrc/dynamo/eval_frame.c +++ b/torch/csrc/dynamo/eval_frame.c @@ -34,6 +34,14 @@ void eval_frame_callback_set(PyObject* obj) { PyThread_tss_set(&eval_frame_callback_key, obj); } +#if IS_PYTHON_3_12_PLUS +const size_t sys_monitoring_num_callables = + sizeof((PyInterpreterState){0}.monitoring_callables) / sizeof(PyObject*); +PyObject** get_monitoring_callables(PyInterpreterState* interp) { + return (PyObject**)interp->monitoring_callables; +} +#endif + // 3.14 Not supported at all. See cpython_defs.c for hints #if !(IS_PYTHON_3_14_PLUS) @@ -582,6 +590,23 @@ static PyObject* set_eval_frame_py(PyObject* module, PyObject* callback) { "python enabled=%d and is run_only=%d", callback != Py_None, callback == Py_False); +#if IS_PYTHON_3_12_PLUS + // skip tracing sys.monitoring callables + if (callback != Py_None && callback != Py_False) { + PyInterpreterState* interp = PyThreadState_GET()->interp; + PyObject** monitoring_callables_flat = + (PyObject**)interp->monitoring_callables; + for (size_t i = 0; i < sys_monitoring_num_callables; ++i) { + PyObject* callable = monitoring_callables_flat[i]; + if (callable != NULL && PyFunction_Check(callable)) { + PyFunctionObject* func = (PyFunctionObject*)callable; + if (func->func_code != NULL) { + skip_code_recursive((PyCodeObject*)func->func_code); + } + } + } + } +#endif return set_eval_frame(callback, PyThreadState_GET(), module); } diff --git a/torch/csrc/dynamo/eval_frame.h b/torch/csrc/dynamo/eval_frame.h index 870603262ddb6..e8742e37fb635 100644 --- a/torch/csrc/dynamo/eval_frame.h +++ b/torch/csrc/dynamo/eval_frame.h @@ -11,6 +11,11 @@ PyObject* torch_c_dynamo_eval_frame_init(void); #endif +#if IS_PYTHON_3_12_PLUS +extern const size_t sys_monitoring_num_callables; +PyObject** get_monitoring_callables(PyInterpreterState* interp); +#endif + // All the eval APIs change in 3.11 so we need to decide which one to use on the // fly https://docs.python.org/3/c-api/init.html#c._PyFrameEvalFunction #if IS_PYTHON_3_11_PLUS diff --git a/torch/csrc/dynamo/eval_frame_cpp.cpp b/torch/csrc/dynamo/eval_frame_cpp.cpp index e05de24259e0b..1d42722afaf9c 100644 --- a/torch/csrc/dynamo/eval_frame_cpp.cpp +++ b/torch/csrc/dynamo/eval_frame_cpp.cpp @@ -7,6 +7,10 @@ #include #include +#include +#include +#include + extern "C" { extern PyObject* guard_complete_hook; } @@ -335,3 +339,14 @@ PyObject* set_code_exec_strategy(PyObject* dummy, PyObject* args) { extra_state_set_exec_strategy(extra, strategy); Py_RETURN_NONE; } + +void skip_code_recursive(PyCodeObject* code) { + ExtraState* extra = get_extra_state(code); + if (extra == nullptr) { + extra = init_and_set_extra_state(code); + } + + FrameExecStrategy strategy = + FrameExecStrategy{FrameAction::SKIP, FrameAction::SKIP}; + extra_state_set_exec_strategy(extra, strategy); +} diff --git a/torch/csrc/dynamo/eval_frame_cpp.h b/torch/csrc/dynamo/eval_frame_cpp.h index ebbad47ef81b8..2f3587094f763 100644 --- a/torch/csrc/dynamo/eval_frame_cpp.h +++ b/torch/csrc/dynamo/eval_frame_cpp.h @@ -17,6 +17,7 @@ PyObject* dynamo__custom_eval_frame( PyObject* callback); PyObject* set_code_exec_strategy(PyObject* dummy, PyObject* obj); +void skip_code_recursive(PyCodeObject* code); #ifdef __cplusplus From 55ff4f85e9f31a3fced069ad526ced16e543cef3 Mon Sep 17 00:00:00 2001 From: Eddie Yan Date: Tue, 22 Jul 2025 18:39:50 +0000 Subject: [PATCH 395/457] [FP8][CUTLASS] xFail `honor_sm_carveout` on `sm100` (#152378) CUTLASS only supports SM carveout via green contexts on `sm100` Pull Request resolved: https://github.com/pytorch/pytorch/pull/152378 Approved by: https://github.com/Skylion007, https://github.com/albanD, https://github.com/nWEIdia --- test/test_matmul_cuda.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index 31f36681bc3a4..943c0ae1f5500 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -1424,8 +1424,16 @@ def test_honor_sm_carveout(self) -> None: ] self.assertEqual(no_carveout, no_carveout_again) - self.assertNotEqual(no_carveout, carveout_66) - self.assertNotEqual(carveout_66, carveout_0) + capability = torch.cuda.get_device_capability() + if capability == (10, 0): + # expected failure + # CUTLASS only supports SM carveout via green contexts on SM100 + self.assertEqual(no_carveout, carveout_66) + self.assertEqual(carveout_66, carveout_0) + else: + # correct behavior + self.assertNotEqual(no_carveout, carveout_66) + self.assertNotEqual(carveout_66, carveout_0) def test_pack_uint4(self): """ From 56df025d5156a84c846a4e469c922c0f08fb3265 Mon Sep 17 00:00:00 2001 From: Jay Tsou Date: Tue, 22 Jul 2025 19:19:13 +0000 Subject: [PATCH 396/457] Add caching for `_rename_without_collisions` (#158594) Fixes #158357 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158594 Approved by: https://github.com/pianpwk --- torch/_export/utils.py | 65 ++++++++++++++++++++++++-------- torch/export/exported_program.py | 12 +++++- 2 files changed, 60 insertions(+), 17 deletions(-) diff --git a/torch/_export/utils.py b/torch/_export/utils.py index 1e2f84e5a3bd9..5594d72aa2f2c 100644 --- a/torch/_export/utils.py +++ b/torch/_export/utils.py @@ -8,6 +8,7 @@ import math import operator import re +from collections import defaultdict from collections.abc import Iterable from contextlib import contextmanager from inspect import ismethod, Parameter @@ -255,6 +256,8 @@ def _get_shape_env_from_gm(gm: torch.fx.GraphModule): def _rename_without_collisions( name_map: dict[str, str], + find_available: dict[str, int], + used_names: set[str], orig_name: str, name: str, is_placeholder: bool = False, @@ -262,23 +265,32 @@ def _rename_without_collisions( """ Renames nodes to avoid name collisions, with suffixing. name_map: map from original name to new name + find_available: map prefix to available suffix + used_names: cache of used names orig_name: mapping key name: candidate name (potentially suffixed, e.g. mul_2) is_placeholder: if the node is a placeholder, avoid detecting suffix """ - if name in name_map.values(): - # non-placeholder nodes may be suffixed with the count - # instead of adding another suffix, we will try to increment it - match = re.match(r"(.*)_(\d+)", name) - if match and not is_placeholder: - name, n = match.group(1), int(match.group(2)) - else: - n = 0 - while (dup_name := f"{name}_{n + 1}") in name_map.values(): - n += 1 - name_map[orig_name] = dup_name - else: - name_map[orig_name] = name + match = re.match(r"(.*)_(\d+)", name) + key = name + + if match and not is_placeholder: + prefix, n = match.group(1), match.group(2) + key = prefix + + new_name = name + if new_name in used_names: + new_name = f"{key}_{find_available[key] + 1}" + + match = re.match(r"(.*)_(\d+)", new_name) + if match: + prefix, n = match.group(1), match.group(2) + if int(n) > find_available[prefix]: + find_available[prefix] = int(n) + + name_map[orig_name] = new_name + used_names.add(new_name) + return name_map[orig_name] @@ -867,6 +879,15 @@ def _bind_signature_to_inputs(mod, fake_args, fake_kwargs): return {**sig.bind_partial(*fake_args).arguments, **fake_kwargs} +def _build_cache(name, find_available, used_names): + used_names.add(name) + match = re.match(r"(.*)_(\d+)", name) + if match: + prefix, n = match.group(1), match.group(2) + if int(n) > find_available[prefix]: + find_available[prefix] = int(n) + + def _name_hoo_subgraph_placeholders(gm: torch.fx.GraphModule) -> None: """ Propagate placeholder names from the top-level graph into HigherOrderOp subgraphs, @@ -874,6 +895,7 @@ def _name_hoo_subgraph_placeholders(gm: torch.fx.GraphModule) -> None: Different HOO subgraph types have different input schemas, so we first enumerate them and gather the top-level named placeholder nodes. """ + # gather all HOO subgraphs and their top-level named placeholder nodes subgraph_ph_tuples: list[tuple[torch.fx.GraphModule, list[torch.fx.Node]]] = [] for node in gm.graph.nodes: @@ -897,12 +919,17 @@ def _name_hoo_subgraph_placeholders(gm: torch.fx.GraphModule) -> None: # propagate names for subgraph, hoo_phs in subgraph_ph_tuples: name_map: dict[str, str] = {} + find_available: dict[str, int] = defaultdict(int) + used_names: set[str] = set() for i, node in enumerate(subgraph.graph.nodes): if i < len(hoo_phs): # placeholder, retain name name_map[node.name] = hoo_phs[i].name node.name = node.target = hoo_phs[i].name + _build_cache(node.name, find_available, used_names) else: # non-placeholder, check for collisions - node.name = _rename_without_collisions(name_map, node.name, node.name) + node.name = _rename_without_collisions( + name_map, find_available, used_names, node.name, node.name + ) # recurse and recompile _name_hoo_subgraph_placeholders(subgraph) @@ -962,6 +989,8 @@ def _extract_pytree_key(x): raise RuntimeError(f"Pytree key of type {type(x)} not handled for {x}") name_map: dict[str, str] = {} + find_available: dict[str, int] = defaultdict(int) + used_names: set[str] = set() # map user input names with mod.forward() signature combined_args = _bind_signature_to_inputs(mod, fake_args, fake_kwargs) @@ -978,6 +1007,8 @@ def _extract_pytree_key(x): if user_input_name: _rename_without_collisions( name_map, + find_available, + used_names, user_input_name, placeholder_prefixes[InputKind.USER_INPUT] + "_".join(_extract_pytree_key(x).lower() for x in arg_path), @@ -997,6 +1028,8 @@ def _extract_pytree_key(x): _rename_without_collisions( name_map, + find_available, + used_names, spec.arg.name, placeholder_prefixes[spec.kind] + base_name, is_placeholder=True, @@ -1015,7 +1048,9 @@ def _extract_pytree_key(x): for node in gm.graph.nodes: if node.op == "placeholder": continue - _rename_without_collisions(name_map, node.name, node.name) + _rename_without_collisions( + name_map, find_available, used_names, node.name, node.name + ) # assign new node names for node in gm.graph.nodes: diff --git a/torch/export/exported_program.py b/torch/export/exported_program.py index 85900dd5e5ea0..4aee86b099e12 100644 --- a/torch/export/exported_program.py +++ b/torch/export/exported_program.py @@ -40,6 +40,7 @@ import torch import torch.utils._pytree as pytree from torch._export.utils import ( + _build_cache, _collect_all_valid_cia_ops, _collect_and_set_constant_attrs, _collect_param_buffer_metadata, @@ -620,11 +621,18 @@ def update_arg(old_arg, new_ph): new_ph.name = new_ph.target = old_ph.name # handle name collisions with newly decomposed graph nodes - name_map = {ph.name: ph.name for ph in new_placeholders} + name_map = {} + find_available: dict[str, int] = defaultdict(int) + used_names: set[str] = set() + for ph in new_placeholders: + name_map[ph.name] = ph.name + _build_cache(ph.name, find_available, used_names) for node in gm.graph.nodes: if node.op == "placeholder": continue - node.name = _rename_without_collisions(name_map, node.name, node.name) + node.name = _rename_without_collisions( + name_map, find_available, used_names, node.name, node.name + ) # propagate names to higher order op subgraphs _name_hoo_subgraph_placeholders(gm) From 832ab990c99545ab5c80eefbc30ab2e14c617a4b Mon Sep 17 00:00:00 2001 From: Electron4444 Date: Tue, 22 Jul 2025 19:28:42 +0000 Subject: [PATCH 397/457] Use init_device_mesh API for select tests where possible (#158675) This addresses reviews made for: #158538 #108749 It interchanged all the specific DevideMesh constructor calls with the API provided by the test cases, to improve abstraction Pull Request resolved: https://github.com/pytorch/pytorch/pull/158675 Approved by: https://github.com/wconstab --- test/distributed/tensor/test_api.py | 14 ++--- .../tensor/test_convolution_ops.py | 10 ++-- test/distributed/tensor/test_dtensor.py | 53 +++++++++---------- .../tensor/test_experimental_ops.py | 8 +-- test/distributed/tensor/test_init.py | 4 +- test/distributed/tensor/test_matrix_ops.py | 28 +++++----- test/distributed/tensor/test_optimizers.py | 21 ++++---- test/distributed/tensor/test_redistribute.py | 22 ++++---- test/distributed/tensor/test_tensor_ops.py | 50 ++++++++--------- .../distributed/_tensor/common_dtensor.py | 5 +- 10 files changed, 105 insertions(+), 110 deletions(-) diff --git a/test/distributed/tensor/test_api.py b/test/distributed/tensor/test_api.py index dd9f163ab4faf..a4efd6d5b6bed 100644 --- a/test/distributed/tensor/test_api.py +++ b/test/distributed/tensor/test_api.py @@ -48,7 +48,7 @@ def world_size(self) -> int: def test_distribute_tensor_rank(self): comm_mode = CommDebugMode() - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Shard(0)] for requires_grad in [True, False]: @@ -134,7 +134,7 @@ def test_distribute_tensor_errors(self): @with_comms def test_distribute_tensor_uneven_sharding(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() input_sizes_and_shard_dims = [ ((self.world_size * 3 + 1, 3, 3), 0), ((self.world_size * 3 + 2, 3, 3), 0), @@ -156,7 +156,7 @@ def test_distribute_tensor_uneven_sharding(self): @with_comms def test_distribute_module(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() # fully shard all linear modules on dim 0 module_to_shard = MyModel(5 * self.world_size, 20, device=self.device_type) shard_spec = [Shard(0)] @@ -219,7 +219,7 @@ def shard_fn(name, module, device_mesh): @with_comms def test_distribute_module_input_fn_output_fn(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() # fully replicate all linear modules module_to_replicate = MyModel(20, 1, device=self.device_type) @@ -264,7 +264,7 @@ def replicate_input_fn(mod, inputs, device_mesh): @with_comms def test_distribute_module_input_fn_output_fn_warning(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() # fully replicate all linear modules module_to_replicate = MyModel(20, 1, device=self.device_type) @@ -292,7 +292,7 @@ def output_fn(outputs, device_mesh): @with_comms def test_distribute_module_casting(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() # check DTensor casting dt = DTensor.from_local(torch.rand(10), device_mesh, [Replicate()]) @@ -335,7 +335,7 @@ def test_distribute_module_casting(self): def test_distribute_module_meta(self): # If the model is too big, the user may first the create entire model on the meta device and then initialize # it on the device in the partition function. - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() # fully shard all parameters on dim 0 module_to_shard = MyModel(5 * self.world_size, 20, device="meta") diff --git a/test/distributed/tensor/test_convolution_ops.py b/test/distributed/tensor/test_convolution_ops.py index b6588c2ad95eb..d249a6d2ff772 100644 --- a/test/distributed/tensor/test_convolution_ops.py +++ b/test/distributed/tensor/test_convolution_ops.py @@ -5,7 +5,7 @@ import torch import torch.nn as nn -from torch.distributed import DeviceMesh, init_device_mesh +from torch.distributed import DeviceMesh from torch.distributed.tensor import ( distribute_module, distribute_tensor, @@ -48,7 +48,7 @@ def world_size(self) -> int: @with_comms def test_downsampling_convolution(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Shard(3)] input_list = torch.rand(ITER_TIME, 7, 3, 512, 1024) @@ -118,7 +118,7 @@ def test_downsampling_convolution(self): @with_comms @skip_if_lt_x_gpu(2) def test_depthwise_convolution(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Shard(3)] input_list = torch.rand(ITER_TIME, 7, 256, 128, 256) @@ -186,9 +186,7 @@ def test_depthwise_convolution(self): @with_comms @skip_if_lt_x_gpu(2) def test_conv_backward_none_grad_inp(self): - device_mesh = init_device_mesh( - device_type=self.device_type, mesh_shape=(self.world_size,) - ) + device_mesh = self.build_device_mesh() conv = nn.Conv2d(64, 64, 3, padding=1).train() x = torch.randn(1, 64, 32, 32) x_dt = DTensor.from_local(x, device_mesh, [Replicate()]) diff --git a/test/distributed/tensor/test_dtensor.py b/test/distributed/tensor/test_dtensor.py index b82661454bfc9..73f4b709103f3 100644 --- a/test/distributed/tensor/test_dtensor.py +++ b/test/distributed/tensor/test_dtensor.py @@ -11,7 +11,6 @@ import torch import torch.nn.functional as F from torch.distributed._functional_collectives import AsyncCollectiveTensor -from torch.distributed.device_mesh import init_device_mesh from torch.distributed.tensor import ( DeviceMesh, distribute_tensor, @@ -61,7 +60,7 @@ def reset_parameters(self, *args, **kwargs): class DTensorTest(DTensorTestBase): @with_comms def test_dtensor_constructor(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() placements = [Shard(0)] local_tensor = torch.randn(3, 3, requires_grad=True) @@ -149,7 +148,7 @@ def test_modules_w_meta_dtensor(self): @with_comms def test_dtensor_stride(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard0_spec = [Shard(0)] local_tensor = torch.randn(4, 8) dist_tensor = DTensor.from_local(local_tensor, device_mesh, shard0_spec) @@ -172,7 +171,7 @@ def test_dtensor_stride(self): @with_comms def test_from_local(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() placements = [Shard(0)] local_tensor = torch.randn(3, 3) sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements) @@ -209,8 +208,7 @@ def test_from_local(self): @with_comms def test_from_local_uneven_sharding(self): - mesh_shape = (self.world_size,) - device_mesh = init_device_mesh(self.device_type, mesh_shape) + device_mesh = self.build_device_mesh() uneven_dim0_size = self.world_size + 1 global_tensor = torch.randn(uneven_dim0_size, 2) @@ -235,8 +233,7 @@ def test_from_local_uneven_sharding(self): @with_comms def test_from_local_uneven_sharding_raise_error(self): - mesh_shape = (self.world_size,) - device_mesh = init_device_mesh(self.device_type, mesh_shape) + device_mesh = self.build_device_mesh() uneven_dim0_size = self.world_size + 1 global_tensor = torch.randn(uneven_dim0_size, 2) @@ -270,7 +267,7 @@ def test_from_local_uneven_sharding_raise_error(self): @with_comms def test_from_local_negative_dim(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() placements = [Shard(-1)] local_tensor = torch.randn(3, 3) sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements) @@ -278,7 +275,7 @@ def test_from_local_negative_dim(self): @with_comms def test_to_local(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() placements = (Shard(0),) local_tensor_with_grad = torch.randn( 3, 3, device=self.device_type, requires_grad=True @@ -338,7 +335,7 @@ def test_to_local(self): @with_comms def test_to_local_grad_hint(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() placements = (Shard(0),) global_tensor = torch.ones(8, 3, requires_grad=True) @@ -363,7 +360,7 @@ def test_to_local_grad_hint(self): @with_comms def test_full_tensor_sync(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() placements = (Shard(0),) global_tensor = torch.ones(8, 3, requires_grad=True) @@ -374,7 +371,7 @@ def test_full_tensor_sync(self): @with_comms def test_full_tensor_grad_hint(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() placements = (Shard(0),) global_tensor = torch.ones(8, 3, requires_grad=True) @@ -387,7 +384,7 @@ def test_full_tensor_grad_hint(self): @with_comms def test_dtensor_new_empty_strided(self): - device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + device_mesh = self.build_device_mesh() local_tensor = torch.randn(8, 8, requires_grad=True, device=self.device_type) my_dtensor = distribute_tensor(local_tensor, device_mesh, [Shard(0)]) new_strided_dtensor = my_dtensor.new_empty_strided( @@ -413,7 +410,7 @@ def test_dtensor_async_output(self): # Tests that if the output of some dtensor operations isn't used in any compute, # the output should be an AsyncCollectiveTensor (representing the fact that # we haven't synced the collective yet). - mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + mesh = self.build_device_mesh() def fn(dt): dt_out_redistribute = dt.redistribute(mesh, [Replicate()], async_op=True) @@ -453,7 +450,7 @@ def fn(dt): @with_comms def test_from_local_then_to_local(self): # this test ensure end to end from torch.Tensor -> dist tensor -> torch.Tensor works - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() placements = [Shard(0)] # step 1. construct from construct local tensor @@ -485,7 +482,7 @@ def test_from_local_then_to_local(self): @with_comms def test_dtensor_spec_read_only_after_set(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() placements = [Shard(0)] local_tensor = torch.randn(3, 3) sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements) @@ -497,7 +494,7 @@ def test_dtensor_spec_read_only_after_set(self): @with_comms def test_dtensor_spec_hash(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() placements = [Shard(0)] local_tensor = torch.randn(3, 3) local_tensor2 = torch.randn(3, 3) @@ -517,7 +514,7 @@ def test_dtensor_spec_hash(self): @with_comms def test_dtensor_properties(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() placements = [Shard(0)] local_tensor = torch.randn(3, 3) sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements) @@ -571,7 +568,7 @@ def test_dtensor_save_load_import(self): @with_comms def test_shard_tensor(self): ws = self.world_size - device_mesh = DeviceMesh(self.device_type, list(range(ws))) + device_mesh = self.build_device_mesh() full_tensor = torch.arange(ws * ws).reshape(ws, ws) # Shard by row @@ -622,7 +619,7 @@ def sub_mesh_assert_equal(self, mesh, exp_in_mesh, exp_out_of_mesh, tensor): @with_comms def test_dtensor_device_mesh_device_conversion(self): # construct a cuda device mesh - mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + mesh = self.build_device_mesh() # construct from a cpu local tensor with cuda device mesh # should automatically convert the dist tensor to cuda @@ -634,14 +631,14 @@ def test_dtensor_device_mesh_device_conversion(self): @with_comms def test_dtensor_api_device_mesh_context_manager(self): - with DeviceMesh(self.device_type, list(range(self.world_size))) as mesh: + with self.build_device_mesh() as mesh: placements = [Shard(0)] local_tensor = torch.randn(3, 3) sharded_tensor = DTensor.from_local( local_tensor, device_mesh=mesh, placements=placements ) - with DeviceMesh(self.device_type, list(range(self.world_size))): + with self.build_device_mesh(): placements = [Shard(0)] local_tensor = torch.randn(3, 3) sharded_tensor = DTensor.from_local(local_tensor, placements=placements) @@ -651,7 +648,7 @@ def test_dtensor_api_device_mesh_context_manager(self): replica_tensor.size(), torch.Size([3 * self.world_size, 3]) ) - with DeviceMesh(self.device_type, torch.arange(self.world_size)): + with self.build_device_mesh(): placements = [Shard(0)] global_shape = torch.Size([3 * self.world_size, 3]) global_tensor = torch.randn(global_shape) @@ -837,7 +834,7 @@ def test_redistribute_sub_mesh(self): @with_comms def test_implicit_replication(self): - mesh = init_device_mesh(self.device_type, (self.world_size,)) + mesh = self.build_device_mesh() local_tensor1 = torch.ones(4, 3) sharded_dtensor = DTensor.from_local(local_tensor1, mesh, [Shard(0)]) @@ -853,7 +850,7 @@ def test_implicit_replication(self): @with_comms def test_auto_implicit_replication(self): - mesh = init_device_mesh(self.device_type, (self.world_size,)) + mesh = self.build_device_mesh() local_tensor = torch.ones(self.world_size, 3, device=self.device_type) sharded_dtensor = DTensor.from_local(local_tensor, mesh, [Shard(0)]) @@ -879,7 +876,7 @@ def add_scalar_tensor_with_dtensor(): @with_comms def test_metadata_consistency_check(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() placements = [Shard(0)] # Create a local tensor with specific metadata and check dtype change @@ -941,7 +938,7 @@ def _create_tensor(self, size): @with_comms def test_split_tensor_1D(self) -> None: - mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + mesh = self.build_device_mesh() shard_placement = Shard(0) for size in range(8): diff --git a/test/distributed/tensor/test_experimental_ops.py b/test/distributed/tensor/test_experimental_ops.py index d5d7f2406adb2..ec4229a47b19c 100644 --- a/test/distributed/tensor/test_experimental_ops.py +++ b/test/distributed/tensor/test_experimental_ops.py @@ -4,7 +4,7 @@ import torch import torch.distributed as dist -from torch.distributed.tensor import DeviceMesh, distribute_tensor, Replicate +from torch.distributed.tensor import distribute_tensor, Replicate from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, @@ -24,7 +24,7 @@ def world_size(self) -> int: @with_comms def test_slice(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Replicate()] input_list = torch.rand(ITER_TIME, 1024, 10) @@ -76,7 +76,7 @@ def test_slice(self): @with_comms def test_bernoulli(self): rank = dist.get_rank() - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Replicate()] input_list = torch.rand(ITER_TIME, 1024, 10) @@ -138,7 +138,7 @@ def test_bernoulli(self): @with_comms def test_nll(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Replicate()] pred_list = torch.rand(ITER_TIME, 1024, 10) diff --git a/test/distributed/tensor/test_init.py b/test/distributed/tensor/test_init.py index 5409949548332..4212b6fc2c9bd 100644 --- a/test/distributed/tensor/test_init.py +++ b/test/distributed/tensor/test_init.py @@ -37,7 +37,7 @@ def world_size(self): def _run_init_op(self, init_op, dist_init_op, eq_op, *args, **kwargs): # 1d mesh test - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() placements_list = [[Shard(0)], [Shard(1)], [Shard(2)], [Replicate()]] # even sharding @@ -132,7 +132,7 @@ def test_zeros(self): @with_comms def test_zeros_full_mesh(self): # construct a cuda device 1d mesh - mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + mesh = self.build_device_mesh() placements = [Shard(0)] size = [32, 3] dist_tensor = zeros(size, device_mesh=mesh, placements=placements) diff --git a/test/distributed/tensor/test_matrix_ops.py b/test/distributed/tensor/test_matrix_ops.py index 523908c5e6bc4..e9baf2102b25d 100644 --- a/test/distributed/tensor/test_matrix_ops.py +++ b/test/distributed/tensor/test_matrix_ops.py @@ -7,7 +7,7 @@ import torch import torch.nn.functional as F -from torch.distributed import DeviceMesh, init_device_mesh +from torch.distributed import init_device_mesh from torch.distributed.tensor import ( distribute_tensor, DTensor, @@ -52,7 +52,7 @@ def scale_for_fp8( class DistMatrixOpsTest(DTensorTestBase): @with_comms def test_addmm(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Shard(0)] replica_spec = [Replicate()] @@ -69,7 +69,7 @@ def test_addmm(self): @with_comms def test_addmm_empty_operand(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Shard(0)] replica_spec = [Replicate()] @@ -86,7 +86,7 @@ def test_addmm_empty_operand(self): @with_comms def test_addmm_auto_redistribute(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard0_spec = [Shard(0)] shard1_spec = [Shard(1)] replica_spec = [Replicate()] @@ -117,7 +117,7 @@ def test_addmm_auto_redistribute(self): @with_comms def test_mm(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard0_spec = Shard(0) shard1_spec = Shard(1) replica_spec = Replicate() @@ -152,7 +152,7 @@ def test_placement_comb( "FP8 is only supported on H100+, SM 8.9 and MI300+ devices", ) def test_scaled_mm(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shrd0 = Shard(0) shrd1 = Shard(1) repl = Replicate() @@ -222,7 +222,7 @@ def test_scaled_mm(self): @with_comms def test_matmul(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() dim = 128 x = torch.randn(8, dim) A = torch.randn(dim, dim) @@ -241,7 +241,7 @@ def test_matmul(self): @with_comms def test_t(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Shard(0)] tensor_to_transpose = torch.randn(12, 8, requires_grad=True) @@ -255,7 +255,7 @@ def test_t(self): @with_comms def test_t_partial(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() a = torch.randn(12, 8) b = torch.randn(8, 4) @@ -280,7 +280,7 @@ def test_t_partial(self): @with_comms @skip_unless_torch_gpu def test_baddbmm(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() tensor = torch.rand(4, 4, 8, device=self.device_type, requires_grad=True) batch_1 = torch.rand(4, 4, 8, device=self.device_type, requires_grad=True) batch_2 = torch.rand(4, 8, 8, device=self.device_type, requires_grad=True) @@ -344,7 +344,7 @@ def test_placement_comb( @with_comms def test_bmm(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() mat1 = torch.rand(4, 8, 4, device=self.device_type, requires_grad=True) mat2 = torch.rand(4, 4, 8, device=self.device_type, requires_grad=True) local_result = torch.bmm(mat1, mat2) @@ -389,7 +389,7 @@ def test_placement_comb( @with_comms @skip_unless_torch_gpu def test_scaled_dot_product_attention(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() comm_mode = CommDebugMode() # bsz, n_heads, slen, head_dim query = torch.rand( @@ -492,7 +492,7 @@ def test_tensordot_shampoo(self): """ Create a simple test for Shampoo's use case. """ - device_mesh = init_device_mesh(self.device_type, (self.world_size,)) + device_mesh = self.build_device_mesh() local_a = torch.randn(4, 4) local_b = torch.randn(4, 15) @@ -545,7 +545,7 @@ def test_tensordot_shampoo(self): def test_grouped_mm(self, kwargs): # TODO: torch._grouped_mm can take inputs of dimension (2D, 3D) x (2D, 3D) # More tests need to be added. - device_mesh = init_device_mesh(self.device_type, (self.world_size,)) + device_mesh = self.build_device_mesh() comm_mode = CommDebugMode() dtype = torch.bfloat16 inp = torch.rand( diff --git a/test/distributed/tensor/test_optimizers.py b/test/distributed/tensor/test_optimizers.py index 7e69f362183dd..c876f28e165b3 100644 --- a/test/distributed/tensor/test_optimizers.py +++ b/test/distributed/tensor/test_optimizers.py @@ -5,7 +5,6 @@ import torch import torch.nn as nn from torch.distributed.tensor import ( - DeviceMesh, distribute_module, distribute_tensor, DTensor, @@ -89,7 +88,7 @@ def test_optimizer_foreach_supported_types_include_DTensor(self): @with_comms def test_adam_1d_sharding(self): - mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + mesh = self.build_device_mesh() # lr as a Tensor is not supported for capturable=False and foreach=True adam_float_lr_configs = [ @@ -148,7 +147,7 @@ def test_adam_1d_sharding(self): @with_comms def test_adamw_1d_sharding(self): - mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + mesh = self.build_device_mesh() # lr as a Tensor is not supported for capturable=False and foreach=True adamw_float_lr_configs = [ @@ -224,7 +223,7 @@ def test_adamw_1d_sharding(self): @with_comms def test_sgd_1d_sharding(self): - mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + mesh = self.build_device_mesh() sgd_configs = [ {"lr": 0.1, "foreach": False}, @@ -264,7 +263,7 @@ def test_sgd_1d_sharding(self): @with_comms def test_adagrad_1d_sharding(self): - mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + mesh = self.build_device_mesh() adagrad_configs = [ {"lr": 0.1, "foreach": False}, @@ -320,7 +319,7 @@ def test_adagrad_1d_sharding(self): @with_comms def test_RMSprop_1d_sharding(self): - mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + mesh = self.build_device_mesh() RMSprop_configs = [ {"lr": 0.1, "foreach": False}, @@ -387,7 +386,7 @@ def test_RMSprop_1d_sharding(self): @with_comms def test_adadelta_1d_sharding(self): - mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + mesh = self.build_device_mesh() adadelta_configs = [ {"lr": 0.1, "foreach": False}, @@ -431,7 +430,7 @@ def test_adadelta_1d_sharding(self): @with_comms def test_nadam_1d_sharding(self): - mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + mesh = self.build_device_mesh() nadam_configs = [ {"lr": 0.1, "foreach": False}, @@ -468,7 +467,7 @@ def test_nadam_1d_sharding(self): @with_comms def test_radam_1d_sharding(self): - mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + mesh = self.build_device_mesh() radam_configs = [ {"lr": 0.1, "foreach": False}, @@ -508,7 +507,7 @@ def test_radam_1d_sharding(self): @with_comms def test_adamax_1d_sharding(self): - mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + mesh = self.build_device_mesh() adamax_configs = [ {"lr": 0.1, "foreach": False}, @@ -552,7 +551,7 @@ def test_adamax_1d_sharding(self): @with_comms def test_asgd_1d_sharding(self): - mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + mesh = self.build_device_mesh() asgd_configs = [ {"lr": 0.1, "foreach": False}, diff --git a/test/distributed/tensor/test_redistribute.py b/test/distributed/tensor/test_redistribute.py index b56f32dbcaea4..fe07b0dd6a241 100644 --- a/test/distributed/tensor/test_redistribute.py +++ b/test/distributed/tensor/test_redistribute.py @@ -40,7 +40,7 @@ def world_size(self): @parametrize("dtype", [torch.float32, torch.cfloat]) def test_shard_to_replicate_forward_backward(self, dtype): # 1) test shard -> replicate forward - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() replica_spec = [Replicate()] input_sizes_and_shard_dim = [ @@ -82,7 +82,7 @@ def test_shard_to_replicate_forward_backward(self, dtype): @with_comms def test_replicate_to_replicate_forward_backward(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() replica_spec = [Replicate()] local_tensor = torch.randn(12, 3, device=self.device_type, requires_grad=True) @@ -111,7 +111,7 @@ def test_replicate_to_replicate_forward_backward(self): @with_comms @parametrize("dtype", [torch.float32, torch.cfloat]) def test_replicate_to_local_partial_grad(self, dtype): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() replica_spec = [Replicate()] local_tensor = torch.randn( 12, 3, device=self.device_type, requires_grad=True, dtype=dtype @@ -132,7 +132,7 @@ def test_replicate_to_local_partial_grad(self, dtype): @with_comms def test_replicate_to_shard_forward_backward(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() replica_spec = [Replicate()] input_sizes_and_shard_dim = [ @@ -185,7 +185,7 @@ def test_partial_to_replicate_forward_backward(self, dtype): # placement (i.e. user can't reshard to partial), we do allow # replicate to partial internally, and also partial to replicate # backward should work as expected - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() partial_local = torch.ones( 12, 3, device=self.device_type, requires_grad=True, dtype=dtype ) @@ -220,7 +220,7 @@ def test_partial_to_replicate_forward_backward(self, dtype): @with_comms def test_replicate_to_replicate_forward_backward_datatype_conversion(self): - device_mesh = init_device_mesh(self.device_type, mesh_shape=(self.world_size,)) + device_mesh = self.build_device_mesh() replica_spec = [Replicate()] forward_datatypes = [ @@ -277,7 +277,7 @@ def test_replicate_to_replicate_forward_backward_datatype_conversion(self): @with_comms def test_shard_to_replicate_forward_backward_datatype_conversion(self): - device_mesh = init_device_mesh(self.device_type, mesh_shape=(self.world_size,)) + device_mesh = self.build_device_mesh() replica_spec = [Replicate()] shard_dim_and_input_sizes = [ @@ -349,7 +349,7 @@ def test_shard_to_replicate_forward_backward_datatype_conversion(self): @with_comms def test_replicate_to_partial(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() local_tensor = torch.randn(12, 3, device=self.device_type, requires_grad=True) partial_spec = Partial() replica_spec = Replicate() @@ -398,7 +398,7 @@ def test_replicate_to_partial(self): @with_comms @parametrize("dtype", [torch.float32, torch.cfloat]) def test_partial_to_shard(self, dtype): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() partial_spec = [Partial()] my_rank = device_mesh.get_rank() @@ -453,7 +453,7 @@ def test_partial_to_shard(self, dtype): @with_comms def test_redistribute_negative_shard_dim(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() local_tensor = torch.randn(12, 3, device=self.device_type, requires_grad=True) shard_spec = [Shard(1)] shard_minus_spec = [Shard(-1)] @@ -491,7 +491,7 @@ def test_redistribute_uneven_sharding(self): @parametrize("dtype", [torch.float32, torch.cfloat]) def test_redistribute_shard_dim_change(self, dtype): # test 1d device mesh - mesh_1d = DeviceMesh(self.device_type, torch.arange(self.world_size)) + mesh_1d = self.build_device_mesh() data_to_test = [ # evenly sharded case torch.randn((8, 8), device=self.device_type, dtype=dtype), diff --git a/test/distributed/tensor/test_tensor_ops.py b/test/distributed/tensor/test_tensor_ops.py index d62da27d43393..0e75748be8a31 100644 --- a/test/distributed/tensor/test_tensor_ops.py +++ b/test/distributed/tensor/test_tensor_ops.py @@ -25,7 +25,7 @@ class DistTensorOpsTest(DTensorTestBase): @with_comms def test_aten_contiguous(self): # this op not covered by dtensor_ops - mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + mesh = self.build_device_mesh() self._test_op( mesh, lambda x: torch.ops.aten.contiguous(x), @@ -34,7 +34,7 @@ def test_aten_contiguous(self): @with_comms def test_detach(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Shard(0)] tensor_to_detach = torch.randn(12, 8, requires_grad=True) @@ -44,7 +44,7 @@ def test_detach(self): @with_comms def test_clone(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() specs = [[Replicate()], [Shard(0)]] tensor_to_clone = torch.randn(12, 8, requires_grad=True) for spec in specs: @@ -95,7 +95,7 @@ def test_copy_(self): @with_comms def test_contiguous(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() tensor = torch.rand(3, 5, 6, requires_grad=True) sharding = [Shard(0)] dist_tensor = DTensor.from_local(tensor, device_mesh, sharding) @@ -121,7 +121,7 @@ def test_contiguous(self): @with_comms def test_inplace_op(self): - mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + mesh = self.build_device_mesh() input_tensor = torch.randn((12, 3), device=self.device_type) dt_to_add = distribute_tensor(input_tensor, mesh, [Shard(0)]) dt_to_mul = dt_to_add.clone() @@ -148,7 +148,7 @@ def test_inplace_op(self): @with_comms def test_op_out_variant(self): - mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + mesh = self.build_device_mesh() input_tensor = torch.randn((12, 3), device=self.device_type) sharded_dt_input = distribute_tensor(input_tensor, mesh, [Shard(0)]) expected_dt = sharded_dt_input.clone() + 3 @@ -169,7 +169,7 @@ def test_op_out_variant(self): @with_comms def test_empty_like(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Shard(0)] input_tensor = torch.randn(4, 8, requires_grad=True) @@ -180,7 +180,7 @@ def test_empty_like(self): @with_comms def test_fill_inplace(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Shard(0)] input_tensor = torch.randn(4, 8, requires_grad=True) @@ -192,7 +192,7 @@ def test_fill_inplace(self): @with_comms def test_full_like(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Shard(0)] input_tensor = torch.randn(4, 8, requires_grad=True) @@ -203,7 +203,7 @@ def test_full_like(self): @with_comms def test_ones_like(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Shard(0)] input_tensor = torch.randn(4, 8, requires_grad=True) @@ -214,7 +214,7 @@ def test_ones_like(self): @with_comms def test_ones_like_partial_sum(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Partial()] input_tensor = torch.randn(4, 8, requires_grad=True) @@ -227,7 +227,7 @@ def test_ones_like_partial_sum(self): @with_comms def test_fill_inplace_partial_sum(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Partial()] input_tensor = torch.randn(4, 8, requires_grad=True) @@ -243,7 +243,7 @@ def test_fill_inplace_partial_sum(self): @with_comms def test_zeros_like_partial_sum(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Partial()] input_tensor = torch.randn(4, 8, requires_grad=True) @@ -256,7 +256,7 @@ def test_zeros_like_partial_sum(self): @with_comms def test_zero_inplace(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Shard(0)] input_tensor = torch.randn(4, 8, requires_grad=True) @@ -268,7 +268,7 @@ def test_zero_inplace(self): @with_comms def test_zeros_like(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Shard(0)] input_tensor = torch.randn(4, 8, requires_grad=True) @@ -320,7 +320,7 @@ def test_stack(self): @with_comms def test_equal(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Shard(0)] input_tensor_1 = torch.ones(4, 4) @@ -370,7 +370,7 @@ def _test_op(self, mesh, op_call, *args, **kwargs): @with_comms def test_new_full(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() comm_mode = CommDebugMode() global_tensor = torch.randn(12, 8) @@ -397,7 +397,7 @@ def test_new_full(self): @with_comms def test_new_empty_strided(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() comm_mode = CommDebugMode() shard_dim = 1 @@ -442,7 +442,7 @@ def test_new_empty_strided(self): @with_comms def test_scatter(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() comm_mode = CommDebugMode() # case 1 all replicate: input replicated, index/src replicated, output replicated @@ -476,7 +476,7 @@ def test_scatter(self): @with_comms def test_gather(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() comm_mode = CommDebugMode() # case 1 all replicate: input replicated, index replicated, output replicated @@ -527,7 +527,7 @@ def test_gather(self): @with_comms def test_index(self): meshes = [ - DeviceMesh(self.device_type, list(range(self.world_size))), # 1D mesh + self.build_device_mesh(), # 1D mesh # TODO(@azzolini): un-comment when DTensorConverter supports N-D mesh # DeviceMesh(self.device_type, torch.arange(self.world_size).reshape(2, -1)), # 2D mesh ] @@ -677,7 +677,7 @@ def test_index_put_tensor(self): @with_comms def test_where_type_promotion(self): - mesh = DeviceMesh(self.device_type, list(range(self.world_size))) # 1D mesh + mesh = self.build_device_mesh() # 1D mesh specs = [[Shard(0)], [Replicate()]] for spec in specs: @@ -689,7 +689,7 @@ def test_where_type_promotion(self): @with_comms def test_dtensor_dtype_conversion(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Shard(0)] # by default we start from bf16 dtype local_tenor = torch.randn(2, 8, dtype=torch.bfloat16) @@ -723,7 +723,7 @@ def test_dtensor_dtype_conversion(self): @with_comms def test_slice(self): - mesh = DeviceMesh(self.device_type, list(range(self.world_size))) # 1D mesh + mesh = self.build_device_mesh() # 1D mesh comm_mode = CommDebugMode() shard_spec = [Shard(1)] @@ -761,7 +761,7 @@ def test_split_on_partial(self): def _test_split_on_partial(self, reduce_op: str, split_size: int, split_dim: int): torch.manual_seed(self.rank) - mesh = init_device_mesh(self.device_type, (self.world_size,)) + mesh = self.build_device_mesh() partial_tensor = torch.randn(8, 8, device=self.device_type) partial_dt = DTensor.from_local( diff --git a/torch/testing/_internal/distributed/_tensor/common_dtensor.py b/torch/testing/_internal/distributed/_tensor/common_dtensor.py index c922e6993af33..94bfead8a0c03 100644 --- a/torch/testing/_internal/distributed/_tensor/common_dtensor.py +++ b/torch/testing/_internal/distributed/_tensor/common_dtensor.py @@ -18,6 +18,7 @@ DeviceMesh, distribute_tensor, DTensor, + init_device_mesh, Placement, Replicate, Shard, @@ -352,7 +353,7 @@ def backend(self) -> str: return backend def build_device_mesh(self) -> DeviceMesh: - return DeviceMesh(self.device_type, list(range(self.world_size))) + return init_device_mesh(self.device_type, (self.world_size,)) def init_pg(self, eager_init) -> None: if "nccl" in self.backend and torch.cuda.device_count() < self.world_size: @@ -483,7 +484,7 @@ def device_type(self) -> str: return DEVICE_TYPE def build_device_mesh(self): - return DeviceMesh(self.device_type, list(range(self.world_size))) + return init_device_mesh(self.device_type, (self.world_size,)) def setUp(self) -> None: super().setUp() From 659bfbf44329c44b9451e197e2b5eb83d48311d2 Mon Sep 17 00:00:00 2001 From: Zain Rizvi Date: Tue, 22 Jul 2025 19:40:53 +0000 Subject: [PATCH 398/457] Revert "We do support 3.14" (#158856) Reverting to fix lint This reverts commit 2a249f1967d29626fe6ac6a07f28440348d1cc93. An emergency fix since the change needed to fix this is a little more complex than expected (see https://github.com/pytorch/pytorch/pull/158853 for reference) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158856 Approved by: https://github.com/Camyll, https://github.com/atalman --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 523fed351b5cc..b41ae87621f0f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ build-backend = "setuptools.build_meta" name = "torch" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" readme = "README.md" -requires-python = ">=3.9" +requires-python = ">=3.9,<3.14" # TODO: change to `license = "BSD-3-Clause"` and enable PEP 639 after pinning setuptools>=77 # FIXME: As of 2025.06.20, it is hard to ensure the minimum version of setuptools in our CI environment. # TOML-table-based license deprecated in setuptools>=77, and the deprecation warning will be changed From c917c63282c467ef942c99da3ce4fa57bceba603 Mon Sep 17 00:00:00 2001 From: "Nichols A. Romero" Date: Tue, 22 Jul 2025 19:45:35 +0000 Subject: [PATCH 399/457] [ROCm][tunableop] UT tolerance increase for matmul_small_brute_force_tunableop at FP16 (#158788) TunableOp will sometimes find a less precise solution due to the small input vectors used in this UT. Bumping op tolerance to eliminate flakiness. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158788 Approved by: https://github.com/jeffdaily --- test/test_linalg.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_linalg.py b/test/test_linalg.py index 8712d65bb493c..f49db43b4ff26 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -4762,6 +4762,7 @@ def test_matmul_small_brute_force_3d_Nd(self, device, dtype): @onlyCUDA @skipCUDAIfNotRocm # Skipping due to SM89 OOM in CI, UT doesn't do much on NV anyways @dtypes(*floating_types_and(torch.half)) + @precisionOverride({torch.float16: 1e-1}) # TunableOp may occasionally find less precise solution def test_matmul_small_brute_force_tunableop(self, device, dtype): # disable tunableop buffer rotation for all tests everywhere, it can be slow # We set the TunableOp numerical check environment variable here because it is From 767791943d5dd325bf3572c45886f7156e69dd5b Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 22 Jul 2025 19:55:01 +0000 Subject: [PATCH 400/457] [ONNX] Set default opset to 20 (#158802) Bump default opset to 20, which is a newer opset and the max torchscript exporter supports. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158802 Approved by: https://github.com/titaiwangms --- test/onnx/exporter/test_api.py | 3 ++- torch/onnx/_constants.py | 2 +- torch/onnx/_internal/exporter/_compat.py | 6 +++--- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/test/onnx/exporter/test_api.py b/test/onnx/exporter/test_api.py index 993da83058684..9a8a171b5fe29 100644 --- a/test/onnx/exporter/test_api.py +++ b/test/onnx/exporter/test_api.py @@ -281,10 +281,11 @@ def forward(self, input): # Use GELU activation function return torch.nn.functional.gelu(input, approximate="tanh") - input = torch.randn(1, 3, 4, 4) + input = (torch.randn(1, 3, 4, 4),) onnx_program_op18 = torch.onnx.export( GeluModel(), input, + opset_version=18, dynamo=True, ) all_nodes_op18 = [n.op_type for n in onnx_program_op18.model.graph] diff --git a/torch/onnx/_constants.py b/torch/onnx/_constants.py index b3c386b701d92..87ff04da8cd1e 100644 --- a/torch/onnx/_constants.py +++ b/torch/onnx/_constants.py @@ -6,7 +6,7 @@ ONNX_MIN_OPSET = 7 ONNX_MAX_OPSET = 23 ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET = 20 -ONNX_DEFAULT_OPSET = 18 +ONNX_DEFAULT_OPSET = 20 ONNX_CONSTANT_FOLDING_MIN_OPSET = 9 PYTORCH_GITHUB_ISSUES_URL = "https://github.com/pytorch/pytorch/issues" diff --git a/torch/onnx/_internal/exporter/_compat.py b/torch/onnx/_internal/exporter/_compat.py index c3a0f26b227d3..cf83aa4061543 100644 --- a/torch/onnx/_internal/exporter/_compat.py +++ b/torch/onnx/_internal/exporter/_compat.py @@ -10,9 +10,9 @@ from typing import Any, Callable, TYPE_CHECKING import torch +from torch.onnx import _constants as onnx_constants from torch.onnx._internal._lazy_import import onnxscript_apis, onnxscript_ir as ir from torch.onnx._internal.exporter import ( - _constants, _core, _dynamic_shapes, _onnx_program, @@ -50,7 +50,7 @@ def export_compat( verbose: bool | None = None, input_names: Sequence[str] | None = None, output_names: Sequence[str] | None = None, - opset_version: int | None = _constants.TORCHLIB_OPSET, + opset_version: int | None = onnx_constants.ONNX_DEFAULT_OPSET, custom_translation_table: dict[Callable, Callable | Sequence[Callable]] | None = None, dynamic_axes: Mapping[str, Mapping[int, str]] @@ -70,7 +70,7 @@ def export_compat( legacy_export_kwargs: dict[str, Any] | None = None, ) -> _onnx_program.ONNXProgram: if opset_version is None: - opset_version = _constants.TORCHLIB_OPSET + opset_version = onnx_constants.ONNX_DEFAULT_OPSET if isinstance(model, torch.export.ExportedProgram): # We know the model is already exported program, so the args, kwargs, and dynamic_shapes From 37ded2ac906c2a15f5613e134d7eeb8a8f953bb7 Mon Sep 17 00:00:00 2001 From: "Goswami, Subrata" Date: Tue, 22 Jul 2025 19:58:45 +0000 Subject: [PATCH 401/457] Using torch.accelerator in comm_mode_features_example.py and visualize_sharding_example.py (#157317) Continuation of https://github.com/pytorch/pytorch/pull/153213 . @guangyey @kwen2501 Pull Request resolved: https://github.com/pytorch/pytorch/pull/157317 Approved by: https://github.com/guangyey, https://github.com/EikanWang, https://github.com/d4l3k Co-authored-by: Yu, Guangye <106960996+guangyey@users.noreply.github.com> --- .../tensor/examples/comm_mode_features_example.py | 9 ++++----- .../tensor/examples/visualize_sharding_example.py | 11 +++++++---- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/torch/distributed/tensor/examples/comm_mode_features_example.py b/torch/distributed/tensor/examples/comm_mode_features_example.py index 3a8ca45b8aaff..8625a3f7dd1d7 100644 --- a/torch/distributed/tensor/examples/comm_mode_features_example.py +++ b/torch/distributed/tensor/examples/comm_mode_features_example.py @@ -27,11 +27,10 @@ def get_device_type() -> str: - return ( - "cuda" - if torch.cuda.is_available() and torch.cuda.device_count() >= 4 - else "cpu" - ) + device_type = "cpu" + if torch.accelerator.device_count() >= 4: + device_type = getattr(torch.accelerator.current_accelerator(), "type", "cpu") + return device_type c10d_functional = torch.ops.c10d_functional diff --git a/torch/distributed/tensor/examples/visualize_sharding_example.py b/torch/distributed/tensor/examples/visualize_sharding_example.py index 7152c928d2f28..7c0ab3adfffae 100644 --- a/torch/distributed/tensor/examples/visualize_sharding_example.py +++ b/torch/distributed/tensor/examples/visualize_sharding_example.py @@ -18,6 +18,9 @@ rank = int(os.environ["RANK"]) +device_type = getattr(torch.accelerator.current_accelerator(), "type", "cpu") + + def section(msg: str) -> None: if rank == 0: rich.print(rich.rule.Rule(msg)) @@ -31,7 +34,7 @@ def visualize(t: dt.DTensor, msg: str = "") -> None: section("[bold]1D Tensor; 1D Mesh[/bold]") -m = dist.init_device_mesh("cuda", (4,)) +m = dist.init_device_mesh(device_type, (4,)) t = torch.ones(4) visualize( dt.distribute_tensor(t, m, [dt.Replicate()]), @@ -43,7 +46,7 @@ def visualize(t: dt.DTensor, msg: str = "") -> None: ) section("[bold]2D Tensor; 1D Mesh[/bold]") -m = dist.init_device_mesh("cuda", (4,)) +m = dist.init_device_mesh(device_type, (4,)) t = torch.ones(4, 4) visualize( dt.distribute_tensor(t, m, [dt.Replicate()]), @@ -59,7 +62,7 @@ def visualize(t: dt.DTensor, msg: str = "") -> None: ) section("[bold]1D Tensor; 2D Mesh[/bold]") -m = dist.init_device_mesh("cuda", (2, 2)) +m = dist.init_device_mesh(device_type, (2, 2)) t = torch.ones(4) visualize( dt.distribute_tensor(t, m, [dt.Replicate(), dt.Replicate()]), @@ -79,7 +82,7 @@ def visualize(t: dt.DTensor, msg: str = "") -> None: ) section("[bold]2D Tensor; 2D Mesh[/bold]") -m = dist.init_device_mesh("cuda", (2, 2)) +m = dist.init_device_mesh(device_type, (2, 2)) t = torch.ones(4, 4) visualize( dt.distribute_tensor(t, m, [dt.Replicate(), dt.Replicate()]), From e17538022a81c453276cb27468223ddbe4e3e883 Mon Sep 17 00:00:00 2001 From: "Goswami, Subrata" Date: Tue, 22 Jul 2025 20:14:02 +0000 Subject: [PATCH 402/457] Making input dynamically adjust. (#157324) Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/157324 Approved by: https://github.com/Skylion007, https://github.com/d4l3k --- test/distributed/fsdp/test_fsdp_uneven.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/distributed/fsdp/test_fsdp_uneven.py b/test/distributed/fsdp/test_fsdp_uneven.py index f74f2ed94ebb4..d0094ce1de71f 100644 --- a/test/distributed/fsdp/test_fsdp_uneven.py +++ b/test/distributed/fsdp/test_fsdp_uneven.py @@ -45,7 +45,7 @@ def _get_ref_results(self, device, model, input, my_lr): def test_one_iteration(self, device): """Test FSDP with uneven divide of parameter shards.""" model = Linear(3, 3, bias=False) - input = torch.rand(8, 3) + input = torch.rand(self.world_size, 3) my_lr = 0.1 ref_forward_output_my_rank, ref_weight_out = self._get_ref_results( From 6499420e45298bad5ef0241d0f04f029825abc60 Mon Sep 17 00:00:00 2001 From: fduwjj Date: Tue, 22 Jul 2025 10:28:32 -0700 Subject: [PATCH 403/457] [DeviceMesh] Make the repr shorter when debug ENV not set (#158822) Users want a shorter repr so this PR is trying to address that when TORCH_DISTRIBUTED_DEBUG is not set to DETAIL. Feedback and discussion is welcomed. Somehow I found that torch.set_printoptions is global, so I am hesitated to use it. Now the print is like image or image or image Pull Request resolved: https://github.com/pytorch/pytorch/pull/158822 Approved by: https://github.com/wz337, https://github.com/wconstab, https://github.com/xmfan --- torch/distributed/device_mesh.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/torch/distributed/device_mesh.py b/torch/distributed/device_mesh.py index dfd07c707a7f9..035b297afe419 100644 --- a/torch/distributed/device_mesh.py +++ b/torch/distributed/device_mesh.py @@ -644,11 +644,15 @@ def __exit__(self, exc_type, exc_value, exc_traceback) -> None: def __repr__(self) -> str: device_mesh_repr = ( - f"DeviceMesh('{self.device_type}', {self.mesh.tolist()})" - if not self.mesh_dim_names - else f"DeviceMesh('{self.device_type}', {self.mesh.tolist()}, mesh_dim_names={self.mesh_dim_names})" + f"({', '.join(f'{k}={v}' for k, v in zip(self.mesh_dim_names, self.mesh.shape))})" + if self.mesh_dim_names + else f"{tuple(self.mesh.shape)}" ) - return device_mesh_repr + device_mesh_repr = f"DeviceMesh({device_mesh_repr}, device: '{self.device_type}', stride: {self.mesh.stride()}" + # We only print the mesh tensor if the debug mode is turned on. + if os.environ.get("TORCH_DISTRIBUTED_DEBUG", "") == "DETAIL": + device_mesh_repr += f", Mesh: {self.mesh.tolist()}" + return f"{device_mesh_repr})" def __hash__(self): # lazily compute hash From 823e2238934fe65133741ad0ab9debaacfd4abe8 Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Tue, 22 Jul 2025 20:32:30 +0000 Subject: [PATCH 404/457] [ROCm] logsumexp on ROCm needs scaling back to natural base. (#156903) Fixes #156012 This is a temporary solution that makes context parallelism working before logsumexp behavior changes landed in AOTriton. After discussion we are not going to release AOTriton 0.10.1 to fix this due to * Even if the interface is not changed, changing the behavior of returned logsumexp tensor should still be considered as an ABI break. Such changes do not fall into the "ABI compatible" category and should be postponed to next release. * AOTriton 0.11 is scheduled to be released before end of July, which is less than five weeks Pull Request resolved: https://github.com/pytorch/pytorch/pull/156903 Approved by: https://github.com/jeffdaily Co-authored-by: Jeff Daily --- .../tensor/experimental/_attention.py | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/torch/distributed/tensor/experimental/_attention.py b/torch/distributed/tensor/experimental/_attention.py index 457624bd6a674..f8e0984d65142 100644 --- a/torch/distributed/tensor/experimental/_attention.py +++ b/torch/distributed/tensor/experimental/_attention.py @@ -44,6 +44,25 @@ class _RotateMethod(Enum): logger = logging.getLogger(__name__) +def _need_scaling() -> bool: + if hasattr(torch.version, "hip") and torch.version.hip is not None: + gcn_arch_name = torch.cuda.get_device_properties("cuda").gcnArchName + _is_ck_supported = False + for arch in ["gfx942", "gfx950"]: + if arch in gcn_arch_name: + _is_ck_supported = True + # Check the function exists + _preferred_rocm_fa_library = torch.backends.cuda.preferred_rocm_fa_library + _CK_BACKEND = torch.backends.cuda._ROCmFABackends["ck"] + # Note: it is possible that CK is selected but not compiled in the binary. + if _is_ck_supported and _preferred_rocm_fa_library() == _CK_BACKEND: + # Unsure about CK's behavior, keep logsumexp untouched + return False + return True + else: + return False + + class _DispatchMode(Enum): MONKEY_PATCH = auto() TORCH_FUNCTION = auto() @@ -446,6 +465,8 @@ def _templated_ring_attention( is_causal=is_causal_behavior.value, **kwargs, ) + if _need_scaling(): + logsumexp *= 0.6931471805599453 sdpa_merger.step(out, logsumexp, partial) return *sdpa_merger.results(), *rest From ddd74d10fcc3e51b9df438faec95f8f207cb1c37 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Mon, 21 Jul 2025 21:54:39 -0700 Subject: [PATCH 405/457] More fixes to `MakeTensor::computeStorageSize()` (#158813) Followup after https://github.com/pytorch/pytorch/pull/158690 that fixessimilar logic if `strides` are not explicitly specified Expanded testing to cover both cases Pull Request resolved: https://github.com/pytorch/pytorch/pull/158813 Approved by: https://github.com/ZainRizvi, https://github.com/Skylion007, https://github.com/albanD ghstack dependencies: #158690 --- aten/src/ATen/templates/Functions.cpp | 2 +- aten/src/ATen/test/basic.cpp | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/templates/Functions.cpp b/aten/src/ATen/templates/Functions.cpp index b3c2164f1707e..f210402e543aa 100644 --- a/aten/src/ATen/templates/Functions.cpp +++ b/aten/src/ATen/templates/Functions.cpp @@ -75,7 +75,7 @@ Tensor TensorMaker::make_tensor() { } auto storage_size = size * itemsize; if (storage_offset_) { - storage_size += storage_offset_.value(); + storage_size += storage_offset_.value() * itemsize; } return storage_size; } diff --git a/aten/src/ATen/test/basic.cpp b/aten/src/ATen/test/basic.cpp index 0e4f461cfd9a4..0937de4552821 100644 --- a/aten/src/ATen/test/basic.cpp +++ b/aten/src/ATen/test/basic.cpp @@ -519,6 +519,15 @@ TEST(BasicTest, BasicStdTestCPU) { } TEST(BasicTest, TestForBlobResizeCPU) { + // Checks that for_blob can correctly create tensors with non-empty offset and resize them + std::array storage; + std::iota(storage.begin(), storage.end(), 1); + auto t = at::for_blob(storage.data(), {3,}).storage_offset(3).options(c10::TensorOptions(kInt)).make_tensor(); + auto te = *at::expand_size(t, {3, 3}); + ASSERT_EQ(te[1][1].item(), 5); +} + +TEST(BasicTest, TestForBlobStridesResizeCPU) { // Checks that for_blob can correctly create tensors with non-empty offset and resize them std::array storage; std::iota(storage.begin(), storage.end(), 1); From e44e05f7ae3ac81675d4636475f562ee1fee9a9c Mon Sep 17 00:00:00 2001 From: Sandeep Narendranath Karjala Date: Wed, 9 Jul 2025 17:14:03 -0700 Subject: [PATCH 406/457] [dynamo] Move skipIf decorator to class level in test_fx_graph_runnable (#157594) Pull Request resolved: https://github.com/pytorch/pytorch/pull/157594 Approved by: https://github.com/xmfan ghstack dependencies: #157162 --- test/dynamo/test_fx_graph_runnable.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/test/dynamo/test_fx_graph_runnable.py b/test/dynamo/test_fx_graph_runnable.py index 74d17dd6825f7..0164b6f9c680d 100644 --- a/test/dynamo/test_fx_graph_runnable.py +++ b/test/dynamo/test_fx_graph_runnable.py @@ -50,6 +50,7 @@ def forward(self, x): return x +@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Skip in fbcode/sandcastle") class FxGraphRunnableTest(TestCase): def setUp(self): super().setUp() @@ -92,7 +93,6 @@ def _exec_and_verify_payload(self): ) # basic tests - @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Skip in fbcode/sandcastle") def test_basic_tensor_add(self): def f(x): return x + 1 @@ -100,7 +100,6 @@ def f(x): torch.compile(f)(torch.randn(4)) self._exec_and_verify_payload() - @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Skip in fbcode/sandcastle") def test_two_inputs_matmul(self): def f(a, b): return (a @ b).relu() @@ -109,7 +108,6 @@ def f(a, b): torch.compile(f)(a, b) self._exec_and_verify_payload() - @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Skip in fbcode/sandcastle") def test_scalar_multiply(self): def f(x): return x * 2 @@ -118,7 +116,6 @@ def f(x): self._exec_and_verify_payload() # testing dynamic shapes - @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Skip in fbcode/sandcastle") def test_dynamic_shapes_run(self): def f(x): return (x @ x.transpose(0, 1)).relu() @@ -130,7 +127,6 @@ def f(x): torch.compile(f)(a) self._exec_and_verify_payload() - @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Skip in fbcode/sandcastle") def test_broadcast_add_dynamic(self): def f(x, y): return x + y * 2 @@ -143,7 +139,6 @@ def f(x, y): torch.compile(f)(x, y) self._exec_and_verify_payload() - @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Skip in fbcode/sandcastle") def test_toy_model_basic(self): model = ToyModel(input_size=8, hidden_size=16, output_size=4) model.eval() # Set to eval mode to avoid dropout randomness @@ -152,7 +147,6 @@ def test_toy_model_basic(self): torch.compile(model)(x) self._exec_and_verify_payload() - @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Skip in fbcode/sandcastle") def test_toy_model_batch_processing(self): model = ToyModel(input_size=12, hidden_size=24, output_size=6) model.eval() @@ -161,7 +155,6 @@ def test_toy_model_batch_processing(self): torch.compile(model)(x) self._exec_and_verify_payload() - @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Skip in fbcode/sandcastle") def test_toy_model_dynamic_batch(self): model = ToyModel(input_size=10, hidden_size=20, output_size=5) model.eval() From fd47401536a35fa5fbf68f0b67fdbf94628f5c23 Mon Sep 17 00:00:00 2001 From: Panagiotis Kourdis Date: Tue, 22 Jul 2025 21:01:38 +0000 Subject: [PATCH 407/457] [doc] Updates to distributed.md for XCCL backend (#155834) Pull Request resolved: https://github.com/pytorch/pytorch/pull/155834 Approved by: https://github.com/guangyey, https://github.com/AlannaBurke, https://github.com/d4l3k Co-authored-by: Yu, Guangye <106960996+guangyey@users.noreply.github.com> --- docs/source/distributed.md | 65 ++++++++++++++------------- torch/distributed/device_mesh.py | 2 +- torch/distributed/distributed_c10d.py | 4 +- 3 files changed, 37 insertions(+), 34 deletions(-) diff --git a/docs/source/distributed.md b/docs/source/distributed.md index 95820f8244c54..9762e79c7ea3b 100644 --- a/docs/source/distributed.md +++ b/docs/source/distributed.md @@ -20,39 +20,41 @@ for a brief introduction to all features related to distributed training. ## Backends -`torch.distributed` supports three built-in backends, each with +`torch.distributed` supports four built-in backends, each with different capabilities. The table below shows which functions are available -for use with CPU / CUDA tensors. +for use with a CPU or GPU for each backend. For NCCL, GPU refers to CUDA GPU +while for XCCL to XPU GPU. + MPI supports CUDA only if the implementation used to build PyTorch supports it. ```{eval-rst} -+----------------+-----------+-----------+-----------+ -| Backend | ``gloo`` | ``mpi`` | ``nccl`` | -+----------------+-----+-----+-----+-----+-----+-----+ -| Device | CPU | GPU | CPU | GPU | CPU | GPU | -+================+=====+=====+=====+=====+=====+=====+ -| send | ✓ | ✘ | ✓ | ? | ✘ | ✓ | -+----------------+-----+-----+-----+-----+-----+-----+ -| recv | ✓ | ✘ | ✓ | ? | ✘ | ✓ | -+----------------+-----+-----+-----+-----+-----+-----+ -| broadcast | ✓ | ✓ | ✓ | ? | ✘ | ✓ | -+----------------+-----+-----+-----+-----+-----+-----+ -| all_reduce | ✓ | ✓ | ✓ | ? | ✘ | ✓ | -+----------------+-----+-----+-----+-----+-----+-----+ -| reduce | ✓ | ✓ | ✓ | ? | ✘ | ✓ | -+----------------+-----+-----+-----+-----+-----+-----+ -| all_gather | ✓ | ✓ | ✓ | ? | ✘ | ✓ | -+----------------+-----+-----+-----+-----+-----+-----+ -| gather | ✓ | ✓ | ✓ | ? | ✘ | ✓ | -+----------------+-----+-----+-----+-----+-----+-----+ -| scatter | ✓ | ✓ | ✓ | ? | ✘ | ✓ | -+----------------+-----+-----+-----+-----+-----+-----+ -| reduce_scatter | ✓ | ✓ | ✘ | ✘ | ✘ | ✓ | -+----------------+-----+-----+-----+-----+-----+-----+ -| all_to_all | ✓ | ✓ | ✓ | ? | ✘ | ✓ | -+----------------+-----+-----+-----+-----+-----+-----+ -| barrier | ✓ | ✘ | ✓ | ? | ✘ | ✓ | -+----------------+-----+-----+-----+-----+-----+-----+ ++----------------+-----------+-----------+-----------+-----------+ +| Backend | ``gloo`` | ``mpi`` | ``nccl`` | ``xccl`` | ++----------------+-----+-----+-----+-----+-----+-----+-----+-----+ +| Device | CPU | GPU | CPU | GPU | CPU | GPU | CPU | GPU | ++================+=====+=====+=====+=====+=====+=====+=====+=====+ +| send | ✓ | ✘ | ✓ | ? | ✘ | ✓ | ✘ | ✓ | ++----------------+-----+-----+-----+-----+-----+-----+-----+-----+ +| recv | ✓ | ✘ | ✓ | ? | ✘ | ✓ | ✘ | ✓ | ++----------------+-----+-----+-----+-----+-----+-----+-----+-----+ +| broadcast | ✓ | ✓ | ✓ | ? | ✘ | ✓ | ✘ | ✓ | ++----------------+-----+-----+-----+-----+-----+-----+-----+-----+ +| all_reduce | ✓ | ✓ | ✓ | ? | ✘ | ✓ | ✘ | ✓ | ++----------------+-----+-----+-----+-----+-----+-----+-----+-----+ +| reduce | ✓ | ✓ | ✓ | ? | ✘ | ✓ | ✘ | ✓ | ++----------------+-----+-----+-----+-----+-----+-----+-----+-----+ +| all_gather | ✓ | ✓ | ✓ | ? | ✘ | ✓ | ✘ | ✓ | ++----------------+-----+-----+-----+-----+-----+-----+-----+-----+ +| gather | ✓ | ✓ | ✓ | ? | ✘ | ✓ | ✘ | ✓ | ++----------------+-----+-----+-----+-----+-----+-----+-----+-----+ +| scatter | ✓ | ✓ | ✓ | ? | ✘ | ✓ | ✘ | ✓ | ++----------------+-----+-----+-----+-----+-----+-----+-----+-----+ +| reduce_scatter | ✓ | ✓ | ✘ | ✘ | ✘ | ✓ | ✘ | ✓ | ++----------------+-----+-----+-----+-----+-----+-----+-----+-----+ +| all_to_all | ✓ | ✓ | ✓ | ? | ✘ | ✓ | ✘ | ✓ | ++----------------+-----+-----+-----+-----+-----+-----+-----+-----+ +| barrier | ✓ | ✘ | ✓ | ? | ✘ | ✓ | ✘ | ✓ | ++----------------+-----+-----+-----+-----+-----+-----+-----+-----+ ``` ### Backends that come with PyTorch @@ -81,8 +83,9 @@ In the past, we were often asked: "which backend should I use?". - Rule of thumb - - Use the NCCL backend for distributed **GPU** training - - Use the Gloo backend for distributed **CPU** training. + - Use the NCCL backend for distributed training with CUDA **GPU**. + - Use the XCCL backend for distributed training with XPU **GPU**. + - Use the Gloo backend for distributed training with **CPU**. - GPU hosts with InfiniBand interconnect diff --git a/torch/distributed/device_mesh.py b/torch/distributed/device_mesh.py index 035b297afe419..370bab11b4dbd 100644 --- a/torch/distributed/device_mesh.py +++ b/torch/distributed/device_mesh.py @@ -1011,7 +1011,7 @@ def init_device_mesh( required for distributed communications behind the scene. Args: - device_type (str): The device type of the mesh. Currently supports: "cpu", "cuda/cuda-like". + device_type (str): The device type of the mesh. Currently supports: "cpu", "cuda/cuda-like", "xpu". Passing in a device type with a GPU index, such as "cuda:0", is not allowed. mesh_shape (Tuple[int]): A tuple defining the dimensions of the multi-dimensional array describing the layout of devices. diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 2bc99a51cd64d..d96cc61a5ac7d 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -1568,12 +1568,12 @@ def init_process_group( Args: backend (str or Backend, optional): The backend to use. Depending on build-time configurations, valid values include ``mpi``, ``gloo``, - ``nccl``, ``ucc``, or one that is registered by a third-party + ``nccl``, ``ucc``, ``xccl`` or one that is registered by a third-party plugin. Since 2.6, if ``backend`` is not provided, c10d will use a backend registered for the device type indicated by the `device_id` kwarg (if provided). The known default registrations today are: ``nccl`` - for ``cuda``, ``gloo`` for ``cpu``. + for ``cuda``, ``gloo`` for ``cpu``, ``xccl`` for ``xpu``. If neither ``backend`` nor ``device_id`` is provided, c10d will detect the accelerator on the run-time machine and use a backend registered for that detected accelerator (or ``cpu``). From a626dc8f1604ce24dadbcdf257f92d0b15bf9367 Mon Sep 17 00:00:00 2001 From: Xu Han Date: Tue, 22 Jul 2025 21:35:57 +0000 Subject: [PATCH 408/457] [AOTI] windows package load dev (#158671) changes: 1. add extract file fail handler for Windows develop. 2. normalize more file paths. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158671 Approved by: https://github.com/angelayi, https://github.com/desertfire --- .../inductor/aoti_package/model_package_loader.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torch/csrc/inductor/aoti_package/model_package_loader.cpp b/torch/csrc/inductor/aoti_package/model_package_loader.cpp index ed4d302bb7b34..ccaeaa9b775be 100644 --- a/torch/csrc/inductor/aoti_package/model_package_loader.cpp +++ b/torch/csrc/inductor/aoti_package/model_package_loader.cpp @@ -53,13 +53,11 @@ std::string normalize_path_separator(const std::string& orig_path) { On Windows, when we input: "C:\Users\Test\file.txt", the output should be: "C:/Users/Test/file.txt". And then, we can process the output like on Linux. */ -#ifdef _WIN32 std::string normalized_path = orig_path; +#ifdef _WIN32 std::replace(normalized_path.begin(), normalized_path.end(), '\\', '/'); - return normalized_path; -#else - return orig_path; #endif + return normalized_path; } bool file_exists(const std::string& path) { @@ -548,7 +546,7 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( << found_filenames[1]; } - temp_dir_ = create_temp_dir(); + temp_dir_ = normalize_path_separator(create_temp_dir()); std::string so_filename; std::string cpp_filename; @@ -581,6 +579,8 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( .append(filename); } + output_path_str = normalize_path_separator(output_path_str); + LOG(INFO) << "Extract file: " << filename_str << " to " << output_path_str; From 04a393507b7e3fea0ef98024ebc14061173369f0 Mon Sep 17 00:00:00 2001 From: AaronWang04 Date: Tue, 22 Jul 2025 22:25:40 +0000 Subject: [PATCH 409/457] Fused RMSNorm implementation (#153666) Relevant #72643 Benchmarked versus unfused torch implementation and torch.compile implementation. Around 9x speedup vs unfused implementation on cuda and slightly faster vs inductor compile on 5090. ```py import torch import torch.nn as nn class RMSNorm(nn.Module): def __init__(self, dim, eps=1e-5): super().__init__() self.eps = eps self.scale = nn.Parameter(torch.ones(dim)) def forward(self, x): norm_x = x.norm(2, dim=-1, keepdim=True) rms_x = norm_x * torch.rsqrt(torch.tensor(x.shape[-1], dtype=x.dtype)) x_normed = x / (rms_x + self.eps) return self.scale * x_normed def benchmark_rmsnorm_cuda(input_shape, normalized_dim, num_iterations=100, warmup_iterations=10, dtype=torch.float16): rms_norm_layer = torch.nn.RMSNorm(normalized_dim, device='cuda', dtype=dtype) input_data = torch.randn(input_shape, device='cuda', dtype=dtype) for _ in range(warmup_iterations): _ = rms_norm_layer(input_data) torch.cuda.synchronize() start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) start_event.record() for _ in range(num_iterations): _ = rms_norm_layer(input_data) end_event.record() torch.cuda.synchronize() elapsed_time_ms = start_event.elapsed_time(end_event) avg_time_ms = elapsed_time_ms / num_iterations print(f"--- RMSNorm CUDA Benchmark ---") print(f"Input Shape: {input_shape}") print(f"Normalized Dimension: {normalized_dim}") print(f"Benchmark Iterations: {num_iterations}") print(f"--- Fused Implementation ---") print(f"Average Time per Iteration: {avg_time_ms:.4f} ms") print(f"Total Time for {num_iterations} Iterations: {elapsed_time_ms:.3f} ms") compiled_rms_norm = torch.compile(RMSNorm(dim=normalized_dim)).cuda() for _ in range(warmup_iterations): _ = compiled_rms_norm(input_data) torch.cuda.synchronize() start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) start_event.record() for _ in range(num_iterations): _ = compiled_rms_norm(input_data) end_event.record() torch.cuda.synchronize() elapsed_time_ms = start_event.elapsed_time(end_event) avg_time_ms = elapsed_time_ms / num_iterations print(f"--- TorchCompile Implementation ---") print(f"Average Time per Iteration: {avg_time_ms:.4f} ms") print(f"Total Time for {num_iterations} Iterations: {elapsed_time_ms:.3f} ms") print("-" * 50) if __name__ == '__main__': parameter_sets = [ {'batch_size': 16, 'sequence_length': 256, 'hidden_features': 512, 'dtype': torch.float16}, {'batch_size': 32, 'sequence_length': 512, 'hidden_features': 768, 'dtype': torch.float16}, {'batch_size': 64, 'sequence_length': 1024, 'hidden_features': 1024, 'dtype': torch.float16}, {'batch_size': 32, 'sequence_length': 512, 'hidden_features': 768, 'dtype': torch.float32}, {'batch_size': 8, 'sequence_length': 2048, 'hidden_features': 2048, 'dtype': torch.float16}, ] num_benchmark_iterations = 200 num_warmup_iterations = 20 for params in parameter_sets: batch_size = params['batch_size'] sequence_length = params['sequence_length'] hidden_features = params['hidden_features'] data_type = params.get('dtype', torch.float16) shape = (batch_size, sequence_length, hidden_features) norm_dim_to_normalize = hidden_features print(f"Benchmarking with: BS={batch_size}, SeqLen={sequence_length}, Hidden={hidden_features}, DType={data_type}") benchmark_rmsnorm_cuda(input_shape=shape, normalized_dim=norm_dim_to_normalize, num_iterations=num_benchmark_iterations, warmup_iterations=num_warmup_iterations, dtype=data_type) ``` Here are the triton compile tests ran on a 5090 (comparing this branch vs main) ```py import torch import torch.nn as nn from torch._inductor.utils import run_and_get_code, run_fw_bw_and_get_code torch.manual_seed(0) device = torch.device("cuda") for batch in range(0, 9): for i in range(9, 16): normalized_shape_arg = (2**batch, 2**i) input_tensor = torch.randn(2**batch, 2**i, device=device, requires_grad=True) weight_tensor = torch.randn(2**batch, 2**i,device=device, requires_grad=True) model = torch.nn.functional.rms_norm compiled_model = torch.compile(model) loss = torch.randn_like(input_tensor) num_iter = 5 for j in range(num_iter): output = compiled_model(input_tensor, normalized_shape_arg, weight_tensor) output.backward(loss) start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) start_event.record() num_iter = 10 for j in range(num_iter): output = compiled_model(input_tensor, normalized_shape_arg, weight_tensor) output.backward(loss) end_event.record() torch.cuda.synchronize() elapsed_time_ms = start_event.elapsed_time(end_event) avg_time_ms = round(elapsed_time_ms / num_iter, 5) print(2**batch, 2**i, avg_time_ms) ``` main ``` 32 512 0.1812 32 1024 0.19021 32 2048 0.18871 32 4096 0.17019 32 8192 0.21944 32 16384 0.38871 32 32768 0.83282 64 512 0.14705 64 1024 0.13987 64 2048 0.14111 64 4096 0.21699 64 8192 0.43141 64 16384 0.90652 64 32768 2.18573 128 512 0.19361 128 1024 0.1963 128 2048 0.20122 128 4096 0.38888 128 8192 0.93795 128 16384 2.23437 128 32768 5.50079 256 512 0.16722 256 1024 0.22856 256 2048 0.39421 256 4096 0.96621 256 8192 2.48746 256 16384 5.53571 256 32768 11.97932 ``` current branch ``` 32 512 0.16328 32 1024 0.18104 32 2048 0.15508 32 4096 0.14356 32 8192 0.20111 32 16384 0.45974 32 32768 0.94799 64 512 0.16874 64 1024 0.18701 64 2048 0.16107 64 4096 0.20152 64 8192 0.46568 64 16384 0.96599 64 32768 2.21661 128 512 0.14982 128 1024 0.15565 128 2048 0.22241 128 4096 0.46128 128 8192 0.88883 128 16384 2.3097 128 32768 5.84448 256 512 0.14346 256 1024 0.2007 256 2048 0.45927 256 4096 0.87876 256 8192 2.10571 256 16384 5.73948 256 32768 12.98581 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/153666 Approved by: https://github.com/ngimel, https://github.com/albanD --- .../functorch/BatchRulesDecompositions.cpp | 1 + .../src/ATen/native/cuda/layer_norm_kernel.cu | 595 +++++++++++++----- aten/src/ATen/native/layer_norm.cpp | 84 ++- aten/src/ATen/native/layer_norm.h | 6 + .../src/ATen/native/mps/operations/RMSNorm.mm | 13 +- aten/src/ATen/native/native_functions.yaml | 8 +- test/distributed/tensor/test_math_ops.py | 98 +++ ...asDecompTest.test_has_decomposition.expect | 1 - .../check_forward_backward_compatibility.py | 2 + test/test_decomp.py | 29 +- tools/autograd/derivatives.yaml | 5 + torch/_decomp/__init__.py | 1 + torch/_decomp/decompositions.py | 75 +++ torch/csrc/autograd/FunctionsManual.cpp | 189 ++++++ torch/csrc/autograd/FunctionsManual.h | 23 + .../aoti_torch/generated/c_shim_cpu.h | 1 + .../aoti_torch/generated/c_shim_cuda.h | 1 + .../aoti_torch/generated/c_shim_mps.h | 2 +- .../aoti_torch/generated/c_shim_xpu.h | 1 + torch/distributed/tensor/_ops/_math_ops.py | 147 +++-- torch/overrides.py | 1 + 21 files changed, 1047 insertions(+), 236 deletions(-) diff --git a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp index 4b66b30b62e7f..d58d436c511d1 100644 --- a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp +++ b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp @@ -158,6 +158,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) { OP_DECOMPOSE(kron); OP_DECOMPOSE(l1_loss); m.impl("layer_norm", native::layer_norm_symint); + m.impl("_fused_rms_norm", native::rms_norm_composite); OP_DECOMPOSE2(ldexp, Tensor); OP_DECOMPOSE2(less_equal, Tensor ); OP_DECOMPOSE2(less, Tensor ); diff --git a/aten/src/ATen/native/cuda/layer_norm_kernel.cu b/aten/src/ATen/native/cuda/layer_norm_kernel.cu index bdb169e26b142..082f4f0a1af4d 100644 --- a/aten/src/ATen/native/cuda/layer_norm_kernel.cu +++ b/aten/src/ATen/native/cuda/layer_norm_kernel.cu @@ -50,7 +50,7 @@ bool can_vectorize(const T * ptr, int alignment) { }; -template +template __global__ void RowwiseMomentsCUDAKernel( int64_t N, T_ACC eps, @@ -84,12 +84,17 @@ __global__ void RowwiseMomentsCUDAKernel( T_ACC m1; T_ACC m2; thrust::tie(m2, m1) = welford_op.project(val); - mean[i] = m1; - rstd[i] = c10::cuda::compat::rsqrt(m2 + eps); + if constexpr (!rms_norm){ + mean[i] = m1; + rstd[i] = c10::cuda::compat::rsqrt(m2 + eps); + } else { + rstd[i] = c10::cuda::compat::rsqrt(m2 + m1 * m1 + eps); + } + } } -template +template __global__ void LayerNormForwardCUDAKernel( int64_t N, const T* X, @@ -103,11 +108,15 @@ __global__ void LayerNormForwardCUDAKernel( const int64_t index = i * N + j; const T_ACC gamma_v = gamma == nullptr ? T_ACC(1) : static_cast(gamma[j]); - const T_ACC beta_v = - beta == nullptr ? T_ACC(0) : static_cast(beta[j]); - Y[index] = (static_cast(X[index]) - static_cast(mean[i])) * - static_cast(rstd[i]) * gamma_v + - beta_v; + if constexpr (!rms_norm){ + const T_ACC beta_v = + beta == nullptr ? T_ACC(0) : static_cast(beta[j]); + Y[index] = (static_cast(X[index]) - static_cast(mean[i])) * + static_cast(rstd[i]) * gamma_v + + beta_v; + } else { + Y[index] = (static_cast(X[index])) * static_cast(rstd[i]) * gamma_v; + } } } @@ -119,40 +128,48 @@ struct WelfordDataLN{ C10_HOST_DEVICE WelfordDataLN(float mean, float sigma2, float count): mean(mean), sigma2(sigma2), count(count) {} }; -template __device__ +template __device__ WelfordDataLN cuWelfordOnlineSum( const U val, const WelfordDataLN& curr_sum) { - U delta = val - curr_sum.mean; - U new_count = curr_sum.count + 1.f; - U new_mean = curr_sum.mean + delta * (1.f/new_count); //proper division is slow, this is less accurate but noticeably faster - return {new_mean, curr_sum.sigma2 + delta * (val - new_mean), new_count}; + if constexpr (!rms_norm){ + U delta = val - curr_sum.mean; + U new_count = curr_sum.count + 1.f; + U new_mean = curr_sum.mean + delta * (1.f/new_count); //proper division is slow, this is less accurate but noticeably faster + return {new_mean, curr_sum.sigma2 + delta * (val - new_mean), new_count}; + } else{ + return {0.f, curr_sum.sigma2 + val * val, 0}; + } } -__device__ +template __device__ WelfordDataLN cuWelfordCombine( const WelfordDataLN dataB, const WelfordDataLN dataA ) { - using U = decltype(dataB.count); - U delta = dataB.mean - dataA.mean; - U count = dataA.count + dataB.count; - U mean, sigma2; - if (count > decltype(dataB.count){0}) { - auto coef = 1.f/count; //NB we don't use --use_fast_math, but this is emulation, 1./count goes to intrinsic, `* coef` is multiplication, instead of slow fp division - auto nA = dataA.count * coef; - auto nB = dataB.count * coef; - mean = nA*dataA.mean + nB*dataB.mean; - sigma2 = dataA.sigma2 + dataB.sigma2 + delta * delta * dataA.count * nB; + if constexpr (!rms_norm){ + using U = decltype(dataB.count); + U delta = dataB.mean - dataA.mean; + U count = dataA.count + dataB.count; + U mean, sigma2; + if (count > decltype(dataB.count){0}) { + auto coef = 1.f/count; //NB we don't use --use_fast_math, but this is emulation, 1./count goes to intrinsic, `* coef` is multiplication, instead of slow fp division + auto nA = dataA.count * coef; + auto nB = dataB.count * coef; + mean = nA*dataA.mean + nB*dataB.mean; + sigma2 = dataA.sigma2 + dataB.sigma2 + delta * delta * dataA.count * nB; + } else { + mean = U(0); + sigma2 = U(0); + } + return {mean, sigma2, count}; } else { - mean = U(0); - sigma2 = U(0); + return {0.f, dataB.sigma2 + dataA.sigma2, 0}; } - return {mean, sigma2, count}; } -template +template __device__ WelfordDataLN compute_stats( const T* __restrict__ X, const int N, @@ -171,14 +188,13 @@ __device__ WelfordDataLN compute_stats( vec_t data = X_vec[i]; #pragma unroll for (int ii=0; ii < vec_size; ii++){ - wd = cuWelfordOnlineSum(static_cast(data.val[ii]), wd); + wd = cuWelfordOnlineSum(static_cast(data.val[ii]), wd); } } // intra-warp reduction for (int offset = (C10_WARP_SIZE >> 1); offset > 0; offset >>= 1) { - WelfordDataLN wdB{WARP_SHFL_DOWN(wd.mean, offset), - WARP_SHFL_DOWN(wd.sigma2, offset), WARP_SHFL_DOWN(wd.count, offset)}; - wd = cuWelfordCombine(wd, wdB); + WelfordDataLN wdB{WARP_SHFL_DOWN(wd.mean, offset), WARP_SHFL_DOWN(wd.sigma2, offset), WARP_SHFL_DOWN(wd.count, offset)}; + wd = cuWelfordCombine(wd, wdB); } // threadIdx.x == 0 has correct values for each warp // inter-warp reductions @@ -199,7 +215,7 @@ __device__ WelfordDataLN compute_stats( WelfordDataLN wdB{meansigmabuf[2*threadIdx.y], meansigmabuf[2*threadIdx.y+1], countbuf[threadIdx.y]}; - wd = cuWelfordCombine(wd, wdB); + wd = cuWelfordCombine(wd, wdB); } __syncthreads(); } @@ -216,7 +232,7 @@ __device__ WelfordDataLN compute_stats( } -template , int> = 0> __device__ __inline__ void vectorized_layer_norm_kernel_impl( const int N, @@ -231,7 +247,7 @@ __device__ __inline__ void vectorized_layer_norm_kernel_impl( //as one thread would have to write 3 consecutive floats auto i1 = blockIdx.x; const T * block_row = X + i1 * N; - WelfordDataLN wd = compute_stats(block_row, N, s_data); + WelfordDataLN wd = compute_stats(block_row, N, s_data); using vec_t = aligned_vector; const vec_t * X_vec = reinterpret_cast(block_row); @@ -254,34 +270,48 @@ __device__ __inline__ void vectorized_layer_norm_kernel_impl( if (gamma_vec != nullptr && beta_vec != nullptr) { #pragma unroll for (int ii=0; ii < vec_size; ii++){ - out.val[ii] = static_cast(gamma_vec[i].val[ii]) * (rstd_val * (static_cast(data.val[ii]) - wd.mean)) - + static_cast(beta_vec[i].val[ii]); + if constexpr (!rms_norm){ + out.val[ii] = static_cast(gamma_vec[i].val[ii]) * (rstd_val * (static_cast(data.val[ii]) - wd.mean)) + + static_cast(beta_vec[i].val[ii]); + } else { + out.val[ii] = static_cast(gamma_vec[i].val[ii]) * (rstd_val * static_cast(data.val[ii])); + } } } else if (gamma_vec != nullptr) { #pragma unroll for (int ii=0; ii < vec_size; ii++){ - out.val[ii] = static_cast(gamma_vec[i].val[ii]) * (rstd_val * (static_cast(data.val[ii]) - wd.mean)); + if constexpr (!rms_norm){ + out.val[ii] = static_cast(gamma_vec[i].val[ii]) * (rstd_val * (static_cast(data.val[ii]) - wd.mean)); + } else { + out.val[ii] = static_cast(gamma_vec[i].val[ii]) * (rstd_val * static_cast(data.val[ii])); + } } } else if (beta_vec != nullptr) { #pragma unroll for (int ii=0; ii < vec_size; ii++){ - out.val[ii] = (rstd_val * (static_cast(data.val[ii]) - wd.mean)) + static_cast(beta_vec[i].val[ii]); + out.val[ii] = (rstd_val * (static_cast(data.val[ii]) - wd.mean)) + static_cast(beta_vec[i].val[ii]); } } else { #pragma unroll for (int ii=0; ii < vec_size; ii++){ - out.val[ii] = rstd_val * (static_cast(data.val[ii]) - wd.mean); + if constexpr (!rms_norm){ + out.val[ii] = rstd_val * (static_cast(data.val[ii]) - wd.mean); + } else { + out.val[ii] = rstd_val * static_cast(data.val[ii]); + } } } Y_vec[i] = out; } if (thrx == 0) { - mean[i1] = wd.mean; + if constexpr (!rms_norm){ + mean[i1] = wd.mean; + } rstd[i1] = rstd_val; } } -template , int> = 0> __device__ __inline__ void vectorized_layer_norm_kernel_impl( const int /*N*/, @@ -296,7 +326,7 @@ __device__ __inline__ void vectorized_layer_norm_kernel_impl( } //to avoid windows SFINAE errors -template +template __global__ void vectorized_layer_norm_kernel( const int N, T_ACC eps, @@ -306,11 +336,11 @@ __global__ void vectorized_layer_norm_kernel( T_ACC* mean, T_ACC* rstd, T* Y){ - vectorized_layer_norm_kernel_impl(N, eps, X, gamma, beta, mean, rstd, Y); + vectorized_layer_norm_kernel_impl(N, eps, X, gamma, beta, mean, rstd, Y); } -template +template __device__ __inline__ void compute_gI( const T* __restrict__ dY, const T* __restrict__ X, @@ -321,7 +351,10 @@ __device__ __inline__ void compute_gI( const int N, T_ACC * buf){ const auto i1 = blockIdx.x; - const T_ACC mean_val = mean[i1]; + T_ACC mean_val = 0; + if constexpr (!rms_norm){ + mean_val = mean[i1]; + } const T_ACC rstd_val = rstd[i1]; T_ACC stats_x1{0}, stats_x2{0}; constexpr int unroll = 4; @@ -337,26 +370,39 @@ __device__ __inline__ void compute_gI( const auto gamma_val = (gamma != nullptr) ? static_cast(gamma[l+k]) : T_ACC(1); const auto c_h = static_cast(X_i[l+k]); const auto c_loss = static_cast(dY_i[l+k]); - stats_x1 += c_loss * gamma_val; - stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; + if constexpr (!rms_norm){ + stats_x1 += c_loss * gamma_val; + stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; + } else { + stats_x2 += c_loss * gamma_val * (c_h) * rstd_val; + } } } for (; l < N; l ++) { const auto gamma_val = (gamma != nullptr) ? static_cast(gamma[l]) : T_ACC(1); const auto c_h = static_cast(X_i[l]); const auto c_loss = static_cast(dY_i[l]); - stats_x1 += c_loss * gamma_val; - stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; + if constexpr (!rms_norm){ + stats_x1 += c_loss * gamma_val; + stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; + } else { + stats_x2 += c_loss * gamma_val * (c_h) * rstd_val; + } + } + if constexpr (!rms_norm){ + stats_x1 = cuda_utils::BlockReduceSum(stats_x1, buf); } - - stats_x1 = cuda_utils::BlockReduceSum(stats_x1, buf); stats_x2 = cuda_utils::BlockReduceSum(stats_x2, buf); if (threadIdx.x == 0) { - buf[0] = stats_x1; + if constexpr (!rms_norm){ + buf[0] = stats_x1; + } buf[1] = stats_x2; } __syncthreads(); - stats_x1 = buf[0]; + if constexpr (!rms_norm){ + stats_x1 = buf[0]; + } stats_x2 = buf[1]; T_ACC fH = N; T_ACC term1 = (T_ACC(1) / fH) * rstd_val; @@ -367,15 +413,20 @@ __device__ __inline__ void compute_gI( const auto gamma_val = (gamma != nullptr) ? static_cast(gamma[l]) : T_ACC(1); T_ACC f_grad_input = fH * gamma_val * dy; - f_grad_input -= (x - mean_val) * rstd_val * stats_x2; - f_grad_input -= stats_x1; + if constexpr (!rms_norm){ + f_grad_input -= (x - mean_val) * rstd_val * stats_x2; + f_grad_input -= stats_x1; + } else { + f_grad_input -= (x) * rstd_val * stats_x2; + } + f_grad_input *= term1; dX_i[l] = f_grad_input; } } -template +template __global__ void layer_norm_grad_input_kernel( const T* __restrict__ dY, const T* __restrict__ X, @@ -387,7 +438,7 @@ __global__ void layer_norm_grad_input_kernel( alignas(sizeof(double)) extern __shared__ char s_data1[]; T_ACC * buf = reinterpret_cast(&s_data1); - compute_gI(dY, X, mean, rstd, gamma, dX, N, buf); + compute_gI(dY, X, mean, rstd, gamma, dX, N, buf); } @@ -396,7 +447,7 @@ __global__ void layer_norm_grad_input_kernel( // faster measured at PT operator level, with cases seeing a 2X speedup (where N >> M). // There are no noticeable regressions on the rest of the sizes. -template +template __global__ void layer_norm_grad_input_kernel_vectorized( const T* __restrict__ dY, const T* __restrict__ X, @@ -409,7 +460,10 @@ __global__ void layer_norm_grad_input_kernel_vectorized( T_ACC* reduce_buf = reinterpret_cast(&shared_data); const auto bIdx = blockIdx.x; - const T_ACC mean_val = mean[bIdx]; + T_ACC mean_val = 0; + if constexpr (!rms_norm){ + mean_val = mean[bIdx]; + } const T_ACC rstd_val = rstd[bIdx]; const T* X_i = X + bIdx * N; const T* dY_i = dY + bIdx * N; @@ -441,8 +495,12 @@ __global__ void layer_norm_grad_input_kernel_vectorized( const auto gamma_val = static_cast(gamma_vec_reg.val[k]); const auto c_h = static_cast(X_i_vec_reg.val[k]); const auto c_loss = static_cast(dY_i_vec_reg.val[k]); - stats_x1 += c_loss * gamma_val; - stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; + if constexpr (!rms_norm){ + stats_x1 += c_loss * gamma_val; + stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; + } else { + stats_x2 += c_loss * gamma_val * (c_h) * rstd_val; + } } } @@ -451,19 +509,29 @@ __global__ void layer_norm_grad_input_kernel_vectorized( const auto gamma_val = (gamma != nullptr) ? static_cast(gamma[l]) : T_ACC(1); const auto c_h = static_cast(X_i[l]); const auto c_loss = static_cast(dY_i[l]); - stats_x1 += c_loss * gamma_val; - stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; + if constexpr (!rms_norm){ + stats_x1 += c_loss * gamma_val; + stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; + } else{ + stats_x2 += c_loss * gamma_val * (c_h) * rstd_val; + } } // Reduction in Shared Memory - stats_x1 = cuda_utils::BlockReduceSum(stats_x1, reduce_buf); + if constexpr (!rms_norm){ + stats_x1 = cuda_utils::BlockReduceSum(stats_x1, reduce_buf); + } stats_x2 = cuda_utils::BlockReduceSum(stats_x2, reduce_buf); if (threadIdx.x == 0) { - reduce_buf[0] = stats_x1; + if constexpr (!rms_norm){ + reduce_buf[0] = stats_x1; + } reduce_buf[1] = stats_x2; } __syncthreads(); - stats_x1 = reduce_buf[0]; + if constexpr (!rms_norm){ + stats_x1 = reduce_buf[0]; + } stats_x2 = reduce_buf[1]; T_ACC fH = N; @@ -485,8 +553,12 @@ __global__ void layer_norm_grad_input_kernel_vectorized( const auto dy = static_cast(dY_i_vec_reg.val[k]); T_ACC f_grad_input = fH * gamma_val * dy; - f_grad_input -= (x - mean_val) * rstd_val * stats_x2; - f_grad_input -= stats_x1; + if constexpr (!rms_norm){ + f_grad_input -= (x - mean_val) * rstd_val * stats_x2; + f_grad_input -= stats_x1; + } else { + f_grad_input -= (x) * rstd_val * stats_x2; + } f_grad_input *= term1; dX_i_vec_reg.val[k] = f_grad_input; } @@ -501,15 +573,19 @@ __global__ void layer_norm_grad_input_kernel_vectorized( const auto gamma_val = (gamma != nullptr) ? static_cast(gamma[l]) : T_ACC(1); T_ACC f_grad_input = fH * gamma_val * dy; - f_grad_input -= (x - mean_val) * rstd_val * stats_x2; - f_grad_input -= stats_x1; + if constexpr (!rms_norm){ + f_grad_input -= (x - mean_val) * rstd_val * stats_x2; + f_grad_input -= stats_x1; + } else { + f_grad_input -= (x) * rstd_val * stats_x2; + } f_grad_input *= term1; dX_i[l] = f_grad_input; } } -template +template __global__ void GammaBetaBackwardSimpleCUDAKernel( int64_t M, int64_t N, @@ -525,17 +601,25 @@ __global__ void GammaBetaBackwardSimpleCUDAKernel( T_ACC sum2 = 0; for (int64_t i = 0; i < M; ++i) { const int64_t index = i * N + j; - sum1 += dg == nullptr ? T_ACC(0) - : static_cast(dY[index]) * - (static_cast(X[index]) - static_cast(mean[i])) * - static_cast(rstd[i]); - sum2 += db == nullptr ? T_ACC(0) : static_cast(dY[index]); + if constexpr (!rms_norm){ + sum1 += dg == nullptr ? T_ACC(0) + : static_cast(dY[index]) * + (static_cast(X[index]) - static_cast(mean[i])) * + static_cast(rstd[i]); + sum2 += db == nullptr ? T_ACC(0) : static_cast(dY[index]); + } else { + sum1 += dg == nullptr ? T_ACC(0) + : static_cast(dY[index]) * + (static_cast(X[index])) * static_cast(rstd[i]); + } } if (dg != nullptr) { dg[j] = sum1; } if (db != nullptr) { - db[j] = sum2; + if constexpr (!rms_norm){ + db[j] = sum2; + } } } } @@ -545,7 +629,8 @@ unsigned int block_dim_x, unsigned int block_dim_y, unsigned int rows_per_block_y, bool check_x, -bool check_y> +bool check_y, +bool rms_norm> __device__ __forceinline__ void @@ -569,7 +654,9 @@ blockReduceGammaBetaBackwardsHelper( int64_t mean_index = M_start + threadIdx.y * rows_per_thread_y; T_ACC warp_mean = 0, warp_rstd = 0; if (lane_id < rows_per_thread_y && mean_index + lane_id < M) { - warp_mean = mean[mean_index + lane_id]; + if constexpr (!rms_norm){ + warp_mean = mean[mean_index + lane_id]; + } warp_rstd = rstd[mean_index + lane_id]; } // We do a WARP_SYNC() here because we use WARP_SHFL below to access @@ -596,10 +683,14 @@ blockReduceGammaBetaBackwardsHelper( #pragma unroll for (int i = 0; i < rows_per_thread_y; ++i) { - T_ACC mean_reg = WARP_SHFL(warp_mean, i, kWarpSize); T_ACC rstd_reg = WARP_SHFL(warp_rstd, i, kWarpSize); - dg_sum += dY_regs[i] * (X_regs[i] - mean_reg) * rstd_reg; - db_sum += dY_regs[i]; + if constexpr (!rms_norm){ + T_ACC mean_reg = WARP_SHFL(warp_mean, i, kWarpSize); + dg_sum += dY_regs[i] * (X_regs[i] - mean_reg) * rstd_reg; + db_sum += dY_regs[i]; + } else{ + dg_sum += dY_regs[i] * (X_regs[i]) * rstd_reg; + } } } @@ -608,7 +699,8 @@ unsigned int block_dim_x, unsigned int block_dim_y, unsigned int rows_per_block_y, bool check_x, -bool check_y> +bool check_y, +bool rms_norm> __device__ __forceinline__ void @@ -629,10 +721,10 @@ blockReduceGammaBetaBackwardsWithChecks( M_start += rows_per_block_y * gridDim.y) { int64_t M_end = M_start + rows_per_block_y - 1; if (!check_y || M_end < M) { - blockReduceGammaBetaBackwardsHelper + blockReduceGammaBetaBackwardsHelper (M_start, M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); } else { - blockReduceGammaBetaBackwardsHelper + blockReduceGammaBetaBackwardsHelper (M_start, M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); } } @@ -654,7 +746,8 @@ template __global__ void @@ -679,7 +772,7 @@ __launch_bounds__(block_dim_x * block_dim_y) // When N and M align perfectly with block_dim_x and block_dim_y, we // can skip boundary condition checks that waste instruction issue slots. blockReduceGammaBetaBackwardsWithChecks - + (M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); } else { // In the general case we need to check boundary conditions in the M @@ -687,11 +780,11 @@ __launch_bounds__(block_dim_x * block_dim_y) // for the inner blocks. So try to avoid those checks when possible. if (blockIdx.x * block_dim_x + block_dim_x - 1 < N) { blockReduceGammaBetaBackwardsWithChecks - + (M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); } else { blockReduceGammaBetaBackwardsWithChecks - + (M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); } } @@ -706,7 +799,7 @@ __launch_bounds__(block_dim_x * block_dim_y) if (dg) { dg[thread_y * N + thread_x] = dg_sum; } - if (db) { + if (db && !rms_norm) { db[thread_y * N + thread_x] = db_sum; } } @@ -752,7 +845,7 @@ __launch_bounds__(block_dim_x * block_dim_y) if (dg) { dg[out_index] = reg_dg; } - if (db) { + if (db && !rms_norm) { db[out_index] = reg_db; } } @@ -763,7 +856,8 @@ __launch_bounds__(block_dim_x * block_dim_y) template +bool partial_reduction, +bool rms_norm> void LaunchAndCheckGammaBetaBackwardKernel( bool aligned_grid, dim3 blocks, @@ -779,7 +873,7 @@ void LaunchAndCheckGammaBetaBackwardKernel( T* dgamma_data, T* dbeta_data) { if (aligned_grid) { - GammaBetaBackwardCUDAKernelTemplate + GammaBetaBackwardCUDAKernelTemplate <<>>( M, N, @@ -790,7 +884,7 @@ if (aligned_grid) { dgamma_data, dbeta_data); } else { - GammaBetaBackwardCUDAKernelTemplate + GammaBetaBackwardCUDAKernelTemplate <<>>( M, N, @@ -806,7 +900,7 @@ if (aligned_grid) { template +int rows_per_block_y, bool rms_norm> void ConfigureAndLaunchGammaBetaBackwardKernel( const T* dY_data, const T* X_data, @@ -829,16 +923,16 @@ void ConfigureAndLaunchGammaBetaBackwardKernel( if (blocks.y == 1 && threads.y == 1) { // Optimization: since there is just one thread doing all the summation, we don't need a reduction // across threads. So we set partial_reduction to true. - LaunchAndCheckGammaBetaBackwardKernel( + LaunchAndCheckGammaBetaBackwardKernel( aligned_grid, blocks, threads, shmem_sz, cuda_stream, dY_data, X_data, mean_data, rstd_data, M, N, dgamma_data, dbeta_data); } else { - LaunchAndCheckGammaBetaBackwardKernel( + LaunchAndCheckGammaBetaBackwardKernel( aligned_grid, blocks, threads, shmem_sz, cuda_stream, dY_data, X_data, mean_data, rstd_data, M, N, dgamma_data, dbeta_data); } } -template +template void LaunchGammaBetaBackwardCUDAKernel( const T* dY_data, const T* X_data, @@ -876,19 +970,21 @@ void LaunchGammaBetaBackwardCUDAKernel( dgamma_blocks = at::empty({blocks.y * threads.y, dgamma->size(-1)}, options); dgamma_blocks_ptr = dgamma_blocks.data_ptr(); } - if (dbeta->defined()) { + if (dbeta->defined() && !rms_norm) { auto options = dbeta->options(); dbeta_blocks = at::empty({blocks.y * threads.y, dgamma->size(-1)}, options); dbeta_blocks_ptr = dbeta_blocks.data_ptr(); } - LaunchAndCheckGammaBetaBackwardKernel( + LaunchAndCheckGammaBetaBackwardKernel( aligned_grid, blocks, threads, 0, cuda_stream, dY_data, X_data, mean_data, rstd_data, M, N, dgamma_blocks_ptr, dbeta_blocks_ptr); if (dgamma_blocks.defined()) { *dgamma = dgamma_blocks.sum(0); } - if (dbeta_blocks.defined()) { - *dbeta = dbeta_blocks.sum(0); + if constexpr (!rms_norm){ + if (dbeta_blocks.defined()) { + *dbeta = dbeta_blocks.sum(0); + } } } else { // We are in the normal case where M is not that large. @@ -896,18 +992,18 @@ void LaunchGammaBetaBackwardCUDAKernel( // For small M it is faster to have a smaller tile, otherwise we could have idle threads. // For larger M we use a bigger tile size. if (M < 64) { - ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); + ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); } else if (M < 128) { - ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); + ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); } else if (M < 256) { - ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); + ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); } else { - ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); + ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); } } } -template +template void launch_vectorized_layer_norm_kernel( int N, int64_t M, @@ -936,7 +1032,7 @@ void launch_vectorized_layer_norm_kernel( TORCH_INTERNAL_ASSERT_DEBUG_ONLY(threads.y % 2 == 0 || threads.y == 1); int nshared = threads.y > 1 ? threads.y * 3/2 *sizeof(T_ACC) : 0; - vectorized_layer_norm_kernel<<>>(N, eps, X_data, + vectorized_layer_norm_kernel<<>>(N, eps, X_data, gamma_data, beta_data, mean_data, rstd_data, Y_data); C10_CUDA_KERNEL_LAUNCH_CHECK(); @@ -958,7 +1054,7 @@ void launch_vectorized_layer_norm_kernel( blocks.x = (remaining > blocks.x) ? blocks.x : remaining; - vectorized_layer_norm_kernel<<>>(N, eps, X_data2, + vectorized_layer_norm_kernel<<>>(N, eps, X_data2, gamma_data, beta_data, mean_data2, rstd_data2, Y_data2); C10_CUDA_KERNEL_LAUNCH_CHECK(); @@ -968,7 +1064,7 @@ void launch_vectorized_layer_norm_kernel( } -template +template void LayerNormKernelImplInternal( const Tensor& X, const Tensor& gamma, @@ -987,7 +1083,7 @@ void LayerNormKernelImplInternal( const T* gamma_data = gamma.defined() ? gamma.const_data_ptr() : nullptr; const T* beta_data = beta.defined() ? beta.const_data_ptr() : nullptr; T* Y_data = Y->data_ptr(); - T_ACC* mean_data = mean->data_ptr(); + T_ACC* mean_data = !rms_norm ? mean->data_ptr() : nullptr; T_ACC* rstd_data = rstd->data_ptr(); // check if can take fast path - all tensors are properly aligned, N is less than 2^24 (to use float count), @@ -1002,14 +1098,14 @@ void LayerNormKernelImplInternal( if ((std::is_same_v || std::is_same_v || std::is_same_v) && N <= static_cast(1ULL << std::numeric_limits::digits) && N % num_vec_elems == 0 && can_vec_X && can_vec_Y && can_vec_gamma && can_vec_beta) { - launch_vectorized_layer_norm_kernel(static_cast(N), M, eps, X_data, gamma_data, beta_data, Y_data, mean_data, rstd_data); + launch_vectorized_layer_norm_kernel(static_cast(N), M, eps, X_data, gamma_data, beta_data, Y_data, mean_data, rstd_data); } else { cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream(); - RowwiseMomentsCUDAKernel + RowwiseMomentsCUDAKernel <<>>( N, eps, X_data, mean_data, rstd_data); C10_CUDA_KERNEL_LAUNCH_CHECK(); - LayerNormForwardCUDAKernel<<>>( + LayerNormForwardCUDAKernel<<>>( N, X_data, mean_data, rstd_data, gamma_data, beta_data, Y_data); C10_CUDA_KERNEL_LAUNCH_CHECK(); } @@ -1037,7 +1133,29 @@ void LayerNormKernelImpl( }); } -template __device__ +void RmsNormKernelImpl( + const Tensor& X, + const Tensor& gamma, + int64_t M, + int64_t N, + double eps, + Tensor* Y, + Tensor* rstd) { +AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + X.scalar_type(), + "LayerNormKernelImpl", + [&]() { + using acc_t = acc_type; + // rms_norm = true + LayerNormKernelImplInternal( + // pass in at::Tensor() for gamma and nullptr for mean, it won't be accessed with rms_norm = True + X, gamma, at::Tensor(), M, N, static_cast(eps), Y, nullptr, rstd); + }); +} + +template __device__ void cuLoadWriteStridedInputs( const int i1_block, const int thr_load_row_off, @@ -1055,7 +1173,10 @@ void cuLoadWriteStridedInputs( { int i1 = i1_block+thr_load_row_off; if (i1 < i1_end) { - T_ACC curr_mean = mean[i1]; + T_ACC curr_mean = 0; + if constexpr (!rms_norm){ + curr_mean = mean[i1]; + } T_ACC curr_rstd = rstd[i1]; for (int k = 0; k < blockDim.y; ++k) { int i2 = i2_off + k; @@ -1080,7 +1201,7 @@ void cuLoadWriteStridedInputs( } } -template __device__ +template __device__ void cuLoadAddStridedInputs( const int i1_block, const int thr_load_row_off, @@ -1098,7 +1219,11 @@ void cuLoadAddStridedInputs( { int i1 = i1_block+thr_load_row_off; if (i1 < i1_end) { - T_ACC curr_mean = mean[i1]; + + T_ACC curr_mean = 0; + if constexpr (!rms_norm){ + curr_mean = mean[i1]; + } T_ACC curr_rstd = rstd[i1]; for (int k = 0; k < blockDim.y; ++k) { int i2 = i2_off + k; @@ -1114,7 +1239,7 @@ void cuLoadAddStridedInputs( } } -template __global__ +template __global__ void cuComputePartGradGammaBeta( const T* __restrict__ dout, const T* __restrict__ input, @@ -1140,9 +1265,9 @@ void cuComputePartGradGammaBeta( T_ACC* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride; // compute partial sums from strided inputs // do this to increase number of loads in flight - cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,N,mean,rstd); + cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,N,mean,rstd); for (int i1_block = i1_beg+blockDim.y*blockDim.y; i1_block < i1_end; i1_block+=blockDim.y*blockDim.y) { - cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,N,mean,rstd); + cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,N,mean,rstd); } __syncthreads(); // inter-warp reductions @@ -1181,7 +1306,7 @@ void cuComputePartGradGammaBeta( } } -template __global__ +template __global__ void cuComputeGradGammaBeta( const T_ACC* part_grad_gamma, const T_ACC* part_grad_beta, @@ -1206,7 +1331,9 @@ void cuComputeGradGammaBeta( if (i2 < N) { for (int warp_offset = 0; warp_offset < num_warp_reductions; ++warp_offset) { sum_gamma += part_grad_gamma_ptr[warp_offset*N]; - sum_beta += part_grad_beta_ptr[warp_offset*N]; + if constexpr (!rms_norm){ + sum_beta += part_grad_beta_ptr[warp_offset*N]; + } } } @@ -1224,7 +1351,9 @@ void cuComputeGradGammaBeta( if (threadIdx.y < offset) { const int read_idx = threadIdx.y * blockDim.x + threadIdx.x; sum_gamma += buf[read_idx]; - sum_beta += buf[read_idx+nbsize3]; + if constexpr (!rms_norm){ + sum_beta += buf[read_idx+nbsize3]; + } } __syncthreads(); } @@ -1235,12 +1364,14 @@ void cuComputeGradGammaBeta( grad_gamma[i2] = sum_gamma; } if (grad_beta) { - grad_beta[i2] = sum_beta; + if constexpr (!rms_norm){ + grad_beta[i2] = sum_beta; + } } } } -template __global__ +template __global__ void cuComputeGradInput( const T* __restrict__ dout, const T* __restrict__ input, @@ -1254,7 +1385,10 @@ void cuComputeGradInput( for (int i1=blockIdx.y; i1 < M; i1 += gridDim.y) { T_ACC sum_loss1 = T_ACC(0); T_ACC sum_loss2 = T_ACC(0); - T_ACC c_mean = mean[i1]; + T_ACC c_mean = 0; + if constexpr (!rms_norm){ + c_mean = mean[i1]; + } const T_ACC c_rstd = rstd[i1]; const T* k_input = input + i1*N; const T* k_dout = dout + i1*N; @@ -1267,21 +1401,31 @@ void cuComputeGradInput( const T_ACC gamma_idx = static_cast((idx((idx((idx((idx((idx 0; mask /= 2) { - sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask); + if constexpr (!rms_norm){ + sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask); + } sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask); } // inter-warp reductions @@ -1292,25 +1436,33 @@ void cuComputeGradInput( // upper half of warps write to shared if (threadIdx.y >= offset && threadIdx.y < 2*offset) { const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x; - buf[2*wrt_i] = sum_loss1; + if constexpr (!rms_norm){ + buf[2*wrt_i] = sum_loss1; + } buf[2*wrt_i+1] = sum_loss2; } __syncthreads(); // lower half merges if (threadIdx.y < offset) { const int read_i = threadIdx.y * blockDim.x + threadIdx.x; - sum_loss1 += buf[2*read_i]; + if constexpr (!rms_norm){ + sum_loss1 += buf[2*read_i]; + } sum_loss2 += buf[2*read_i+1]; } __syncthreads(); } if (threadIdx.y == 0) { - buf[2*threadIdx.x] = sum_loss1; + if constexpr (!rms_norm){ + buf[2*threadIdx.x] = sum_loss1; + } buf[2*threadIdx.x+1] = sum_loss2; } __syncthreads(); if (threadIdx.y !=0) { - sum_loss1 = buf[2*threadIdx.x]; + if constexpr (!rms_norm){ + sum_loss1 = buf[2*threadIdx.x]; + } sum_loss2 = buf[2*threadIdx.x+1]; } } @@ -1323,8 +1475,12 @@ void cuComputeGradInput( const T_ACC c_h = static_cast(k_input[l]); const T_ACC c_loss = static_cast(k_dout[l]); T_ACC f_grad_input = fH * c_loss * gamma[l]; - f_grad_input -= sum_loss1; - f_grad_input -= (c_h - c_mean) * c_rstd * sum_loss2; + if constexpr (!rms_norm){ + f_grad_input -= sum_loss1; + f_grad_input -= (c_h - c_mean) * c_rstd * sum_loss2; + } else { + f_grad_input -= (c_h) * c_rstd * sum_loss2; + } f_grad_input *= term1; k_grad_input[l] = static_cast(f_grad_input); } @@ -1333,8 +1489,12 @@ void cuComputeGradInput( const T_ACC c_h = static_cast(k_input[l]); const T_ACC c_loss = static_cast(k_dout[l]); T_ACC f_grad_input = fH * c_loss; - f_grad_input -= sum_loss1; - f_grad_input -= (c_h - c_mean) * c_rstd * sum_loss2; + if constexpr (!rms_norm){ + f_grad_input -= sum_loss1; + f_grad_input -= (c_h - c_mean) * c_rstd * sum_loss2; + } else { + f_grad_input -= (c_h) * c_rstd * sum_loss2; + } f_grad_input *= term1; k_grad_input[l] = static_cast(f_grad_input); } @@ -1344,7 +1504,7 @@ void cuComputeGradInput( } } -template +template void LayerNormBackwardKernelImplInternal( const Tensor& dY, const Tensor& X, @@ -1358,7 +1518,9 @@ void LayerNormBackwardKernelImplInternal( Tensor* dbeta) { using T_ACC = acc_type; TORCH_CHECK(dY.numel() == M * N); - TORCH_CHECK(mean.numel() == M); + if constexpr (!rms_norm){ + TORCH_CHECK(mean.numel() == M); + } TORCH_CHECK(rstd.numel() == M); TORCH_CHECK(M <= at::cuda::getCurrentDeviceProperties()->maxGridSize[0], "M should be less than maximum CUDA grid size, \ file a support request to support bigger batches"); @@ -1384,7 +1546,7 @@ void LayerNormBackwardKernelImplInternal( threads1.y > 1 ? threads1.y*threads1.x*sizeof(T_ACC) : 0; - cuComputeGradInput<<>>( + cuComputeGradInput<<>>( dY_data, X_data, M, N, @@ -1396,7 +1558,7 @@ void LayerNormBackwardKernelImplInternal( } else { const dim3 blocks(M); int nshared = (num_threads()/warp_size) * sizeof(T_ACC); - layer_norm_grad_input_kernel<<>>(dY_data, + layer_norm_grad_input_kernel<<>>(dY_data, X_data, mean_data, rstd_data, gamma_data, dX_data, N); C10_CUDA_KERNEL_LAUNCH_CHECK(); } @@ -1410,13 +1572,12 @@ void LayerNormBackwardKernelImplInternal( const unsigned int alignment = sizeof(T) * vec_size; bool bAlignedBuffers = can_vectorize(dY_data, alignment) && can_vectorize(X_data, alignment) && can_vectorize(gamma_data, alignment) && can_vectorize(dX_data, alignment); - if (bAlignedBuffers && bTargetDataTypes && bVectorSizeMultiple) { - layer_norm_grad_input_kernel_vectorized<<>>(dY_data, + layer_norm_grad_input_kernel_vectorized<<>>(dY_data, X_data, mean_data, rstd_data, gamma_data, dX_data, N); C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { - layer_norm_grad_input_kernel<<>>(dY_data, + layer_norm_grad_input_kernel<<>>(dY_data, X_data, mean_data, rstd_data, gamma_data, dX_data, N); C10_CUDA_KERNEL_LAUNCH_CHECK(); } @@ -1432,7 +1593,7 @@ void LayerNormBackwardKernelImplInternal( if (M < 128) { // For small batch size, do colwise reduce directly. const int64_t B = (N + kCUDANumThreads - 1) / kCUDANumThreads; - GammaBetaBackwardSimpleCUDAKernel + GammaBetaBackwardSimpleCUDAKernel <<>>( M, N, @@ -1456,7 +1617,7 @@ void LayerNormBackwardKernelImplInternal( Tensor part_grad_gamma = at::empty({part_size,N}, gamma.options().dtype(part_grad_dtype)); Tensor part_grad_beta = at::native::empty_like(part_grad_gamma); - cuComputePartGradGammaBeta<<>>( + cuComputePartGradGammaBeta<<>>( dY_data, X_data, M,N, @@ -1470,7 +1631,7 @@ void LayerNormBackwardKernelImplInternal( const dim3 blocks3((N + threads3.x - 1) / threads3.x, 1, 1); const int nshared3 = threads3.x * threads3.y * sizeof(T_ACC); - cuComputeGradGammaBeta<<>>( + cuComputeGradGammaBeta<<>>( part_grad_gamma.template data_ptr(), part_grad_beta.template data_ptr(), part_size, @@ -1480,7 +1641,7 @@ void LayerNormBackwardKernelImplInternal( C10_CUDA_KERNEL_LAUNCH_CHECK(); } #else - LaunchGammaBetaBackwardCUDAKernel( + LaunchGammaBetaBackwardCUDAKernel( dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); #endif } @@ -1508,8 +1669,29 @@ void LayerNormBackwardKernelImpl( }); } +void RMSNormBackwardKernelImpl( + const Tensor& dY, + const Tensor& X, + const Tensor& rstd, + const Tensor& gamma, + int64_t M, + int64_t N, + Tensor* dX, + Tensor* dgamma) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + X.scalar_type(), + "LayerNormBackwardKernelImpl", + [&]() { + LayerNormBackwardKernelImplInternal( + dY.contiguous(), X, rstd, rstd, gamma, M, N, dX, dgamma, dgamma); + }); +} + } // namespace + std::tuple layer_norm_cuda( const Tensor& input, IntArrayRef normalized_shape, @@ -1638,6 +1820,113 @@ std::tuple layer_norm_backward_cuda( return std::make_tuple(std::move(dX), std::move(dgamma), std::move(dbeta)); } +/* RMSNorm is implemented by reusing layer_norm's kernels */ +std::tuple _fused_rms_norm_cuda( + const Tensor& input, + IntArrayRef normalized_shape, + const std::optional& weight_opt /* optional */, + std::optional eps){ + + c10::MaybeOwned weight_maybe_owned = + at::borrow_from_optional_tensor(weight_opt); + const Tensor& weight = *weight_maybe_owned; + auto M_N = _check_layer_norm_inputs(input, normalized_shape, weight, weight); + auto M = M_N.first; + auto N = M_N.second; + auto X = input.expect_contiguous(); + auto gamma = weight.expect_contiguous(); + + auto acc_type = at::toAccumulateType(input.scalar_type(), /*is_cuda=*/true); + double eps_val; + if (acc_type == at::ScalarType::Float) { + eps_val = eps.value_or(std::numeric_limits::epsilon()); + } else { + eps_val = eps.value_or(std::numeric_limits::epsilon()); + } + + Tensor Y = at::native::empty_like( + *X, + std::nullopt /* dtype */, + std::nullopt /* layout */, + std::nullopt /* device */, + std::nullopt /* pin_memory */, + LEGACY_CONTIGUOUS_MEMORY_FORMAT); + Tensor rstd = at::empty({M}, X->options().dtype(acc_type)); + + if (M > 0) { + RmsNormKernelImpl(*X, *gamma, M, N, eps_val, &Y, &rstd); + } + + const auto input_shape = input.sizes(); + const size_t axis = input.dim() - normalized_shape.size(); + + std::vector stat_shape; + for (const auto idx: c10::irange(axis)) { + stat_shape.push_back(input_shape[idx]); + } + for ([[maybe_unused]] const auto idx : c10::irange(axis, input.dim())) { + stat_shape.push_back(1); + } + + rstd = rstd.view(stat_shape); + + return std::make_tuple(std::move(Y), std::move(rstd)); +} + + +std::tuple _fused_rms_norm_backward_cuda( + const Tensor& dY, + const Tensor& input, + IntArrayRef normalized_shape, + const Tensor& rstd, + const std::optional& weight_opt /* optional */, + std::array grad_input_mask) { + + c10::MaybeOwned weight_maybe_owned = + at::borrow_from_optional_tensor(weight_opt); + const Tensor& weight = *weight_maybe_owned; + + auto M_N = _check_layer_norm_inputs(input, normalized_shape, weight, weight); + auto M = M_N.first; + auto N = M_N.second; + auto X = input.expect_contiguous(); + auto gamma = weight.expect_contiguous(); + + Tensor dX; + Tensor dgamma; + if (grad_input_mask[0]) { + dX = at::native::empty_like( + *X, + std::nullopt /* dtype */, + std::nullopt /* layout */, + std::nullopt /* device */, + std::nullopt /* pin_memory */, + LEGACY_CONTIGUOUS_MEMORY_FORMAT); + } + if (grad_input_mask[1]) { + dgamma = M > 0 ? at::native::empty_like( + *gamma, + std::nullopt /* dtype */, + std::nullopt /* layout */, + std::nullopt /* device */, + std::nullopt /* pin_memory */, + LEGACY_CONTIGUOUS_MEMORY_FORMAT) + : at::native::zeros_like( + *gamma, + std::nullopt /* dtype */, + std::nullopt /* layout */, + std::nullopt /* device */, + std::nullopt /* pin_memory */, + LEGACY_CONTIGUOUS_MEMORY_FORMAT); + } + + if (M > 0 && N > 0) { + RMSNormBackwardKernelImpl( + dY, *X, rstd, *gamma, M, N, &dX, &dgamma); + } + return std::make_tuple(std::move(dX), std::move(dgamma)); +} + REGISTER_DISPATCH(LayerNormKernel, &LayerNormKernelImpl) REGISTER_DISPATCH(LayerNormBackwardKernel, &LayerNormBackwardKernelImpl) diff --git a/aten/src/ATen/native/layer_norm.cpp b/aten/src/ATen/native/layer_norm.cpp index da6bb5fec39e8..950aa99a9aabe 100644 --- a/aten/src/ATen/native/layer_norm.cpp +++ b/aten/src/ATen/native/layer_norm.cpp @@ -261,30 +261,11 @@ std::tuple math_native_layer_norm( return outputs; } -Tensor rms_norm_symint( +std::tuple rms_norm_composite( const Tensor& input, - c10::SymIntArrayRef normalized_shape, + IntArrayRef normalized_shape, const std::optional& weight_opt /* optional */, std::optional eps) { - // See [Note: hacky wrapper removal for optional tensor] - c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); - const Tensor& weight = *weight_maybe_owned; - _check_rms_norm_inputs_symint(input, normalized_shape, weight); - -#ifdef USE_MPS - if (input.device().type() == DeviceType::MPS && weight_opt.has_value()) { - const Tensor weight = weight_opt.value(); - const bool any_nested = input.is_nested() || weight.is_nested(); - const bool any_inputs_require_grad = input.requires_grad() || weight.requires_grad(); - const bool is_input_fp = isFloatingType(input.scalar_type()); - const bool is_weight_fp = isFloatingType(weight.scalar_type()); - - if (!(GradMode::is_enabled() && any_inputs_require_grad) && !any_nested && is_input_fp && is_weight_fp) { - auto eps_val = eps.value_or(std::numeric_limits::epsilon()); - return at::_fused_rms_norm(input.contiguous(), normalized_shape.size(), weight.contiguous(), eps_val); - } - } -#endif std::vector dims_to_reduce; for (const auto i : c10::irange(normalized_shape.size())) { @@ -321,10 +302,67 @@ Tensor rms_norm_symint( upcasted_result = upcasted_result.mul(weight_opt.value()); } - return upcasted_result; + // if nested do not make contiguous + if(input.is_nested() || (weight_opt.has_value() && weight_opt.value().is_nested())){ + return std::make_tuple(upcasted_result, rqrst_input); + } + + if(input.suggest_memory_format() == c10::MemoryFormat::ChannelsLast || input.suggest_memory_format() == c10::MemoryFormat::ChannelsLast3d){ + return std::make_tuple(upcasted_result, rqrst_input); + } + + return std::make_tuple(upcasted_result.contiguous(), rqrst_input.contiguous()); }); + return std::make_tuple( + std::get<0>(result).type_as(input), // Cast normalized result to original input type + std::get<1>(result) // rsqrt_val + ); +} + + +Tensor rms_norm_symint( + const Tensor& input, + c10::SymIntArrayRef normalized_shape, + const std::optional& weight_opt /* optional */, + const std::optional eps) { + + c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); + const Tensor& weight = *weight_maybe_owned; + _check_rms_norm_inputs_symint(input, normalized_shape, weight); + + // composite fallback for channels last + if(input.suggest_memory_format() == c10::MemoryFormat::ChannelsLast || input.suggest_memory_format() == c10::MemoryFormat::ChannelsLast3d){ + return std::get<0>(rms_norm_composite(input, IntArrayRef(reinterpret_cast(normalized_shape.data()), normalized_shape.size()), weight_opt, eps)); + } - return result.type_as(input); + // composite fallback for complex datatypes + if(input.is_complex()){ + return std::get<0>(rms_norm_composite(input, IntArrayRef(reinterpret_cast(normalized_shape.data()), normalized_shape.size()), weight_opt, eps)); + } + if (weight_opt.has_value() && weight_opt.value().defined() && weight_opt.value().dtype() != input.dtype()) { + TORCH_WARN_ONCE( + "Mismatch dtype between input and module: input dtype = ", input.dtype(), + ", module dtype = ", weight_opt.value().dtype(), ", Can not dispatch to fused implementation" + ); + return std::get<0>(rms_norm_composite(input, IntArrayRef(reinterpret_cast(normalized_shape.data()), normalized_shape.size()), weight_opt, eps)); + } + + #ifdef USE_MPS + if (input.device().type() == DeviceType::MPS && weight_opt.has_value()) { + const Tensor weight = weight_opt.value(); + const bool any_inputs_require_grad = input.requires_grad() || weight.requires_grad(); + + if (!(GradMode::is_enabled() && any_inputs_require_grad)) { + return std::get<0>(at::_fused_rms_norm(input.contiguous(), IntArrayRef(reinterpret_cast(normalized_shape.data()), normalized_shape.size()), weight_opt, eps)); + } + } + + if (input.device().type() == DeviceType::MPS){ + return std::get<0>(rms_norm_composite(input, IntArrayRef(reinterpret_cast(normalized_shape.data()), normalized_shape.size()), weight_opt, eps)); + } + #endif + return std::get<0>(at::_fused_rms_norm(input, IntArrayRef(reinterpret_cast(normalized_shape.data()), normalized_shape.size()), weight_opt, eps)); } + } // namespace at::native diff --git a/aten/src/ATen/native/layer_norm.h b/aten/src/ATen/native/layer_norm.h index 0181f35fd6ed4..0debe942dd0a6 100644 --- a/aten/src/ATen/native/layer_norm.h +++ b/aten/src/ATen/native/layer_norm.h @@ -106,6 +106,12 @@ void layer_norm_cpu_out( int64_t M, int64_t N); +std::tuple rms_norm_composite( + const Tensor& input, + IntArrayRef normalized_shape, + const std::optional& weight_opt /* optional */, + std::optional eps); + Tensor rms_norm_symint( const Tensor& input, c10::SymIntArrayRef normalized_shape, diff --git a/aten/src/ATen/native/mps/operations/RMSNorm.mm b/aten/src/ATen/native/mps/operations/RMSNorm.mm index 71128297d5bfc..7948b5acd8e93 100644 --- a/aten/src/ATen/native/mps/operations/RMSNorm.mm +++ b/aten/src/ATen/native/mps/operations/RMSNorm.mm @@ -19,7 +19,14 @@ #include #endif -Tensor _fused_rms_norm_mps(const Tensor& input, const int64_t normalized_ndim, const Tensor& weight, const double eps) { +std::tuple _fused_rms_norm_mps(const Tensor& input, + IntArrayRef normalized_shape, + const std::optional& weight_opt, + const std::optional eps) { + const Tensor weight = weight_opt.value().contiguous(); + const int64_t normalized_ndim = normalized_shape.size(); + auto eps_val = eps.value_or(std::numeric_limits::epsilon()); + TORCH_CHECK(input.is_contiguous() && weight.is_contiguous(), "Expected contiguous input and weight tensors"); auto output = at::empty_like(input); const auto input_shape = input.sizes(); @@ -41,7 +48,7 @@ Tensor _fused_rms_norm_mps(const Tensor& input, const int64_t normalized_ndim, c const std::string kernel = fmt::format("{}_{}", name, scalarToMetalTypeString(output)); id rms_norm_pso = lib.getPipelineStateForFunc(kernel); [computeEncoder setComputePipelineState:rms_norm_pso]; - mtl_setArgs(computeEncoder, input, weight, output, eps, N, 1); + mtl_setArgs(computeEncoder, input, weight, output, eps_val, N, 1); const auto maxThreadsPerGroup = static_cast([rms_norm_pso maxTotalThreadsPerThreadgroup]); size_t threadgroup_size = maxThreadsPerGroup; @@ -58,7 +65,7 @@ Tensor _fused_rms_norm_mps(const Tensor& input, const int64_t normalized_ndim, c } }); - return output; + return std::make_tuple(output, Tensor()); } } // namespace at::native diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index e0dc1b616013e..4778aee27f423 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -3314,9 +3314,15 @@ dispatch: CompositeImplicitAutograd: rms_norm_symint -- func: _fused_rms_norm(Tensor input, int normalized_shape_ndim, Tensor weight, float eps) -> Tensor +- func: _fused_rms_norm(Tensor input, int[] normalized_shape, Tensor? weight, float? eps) -> (Tensor, Tensor) dispatch: + CUDA: _fused_rms_norm_cuda MPS: _fused_rms_norm_mps + CompositeImplicitAutograd: rms_norm_composite + +- func: _fused_rms_norm_backward(Tensor grad_out, Tensor input, int[] normalized_shape, Tensor rstd, Tensor? weight, bool[2] output_mask) -> (Tensor, Tensor) + dispatch: + CUDA: _fused_rms_norm_backward_cuda - func: nan_to_num(Tensor self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor variants: function, method diff --git a/test/distributed/tensor/test_math_ops.py b/test/distributed/tensor/test_math_ops.py index e13e0c0266b8b..93ce80f18ee15 100644 --- a/test/distributed/tensor/test_math_ops.py +++ b/test/distributed/tensor/test_math_ops.py @@ -572,6 +572,104 @@ def forward(self, tokens): f"{len(subtest_fails)}/{len(subtest_cfgs)} subtests failed: {pformat(subtest_fails)}" ) + @with_comms + def test_rms_norm_bwd(self): + device_mesh = self.build_device_mesh() + + # NLP example from pytorch docs + batch, sentence_length, embedding_dim = 20, 5, 10 + norm_shape_idx_list = list(range(3)) + shard_dims = [0] # non-first dimensional sharding is not supported + elementwise_affine_list = [False, True] + test_config_list = list( + itertools.product(shard_dims, norm_shape_idx_list, elementwise_affine_list) + ) + + # normalized shape is a torch.Size object + for shard_dim, norm_idx, elementwise_affine in test_config_list: + x = torch.rand( + batch, + sentence_length, + embedding_dim, + device=self.device_type, + requires_grad=True, + ) + normalized_shape = x.shape[norm_idx:] + rms_norm = torch.nn.RMSNorm( + normalized_shape, + elementwise_affine=elementwise_affine, + device=self.device_type, + ) + rms_norm_local = copy.deepcopy(rms_norm).to(self.device_type) + + def _replicate_fn(name, module, device_mesh): + for name, param in module.named_parameters(): + if name == "weight": + param_dist = torch.nn.Parameter( + distribute_tensor(param, device_mesh, [Replicate()]) + ) + module.register_parameter(name, param_dist) + + rms_norm_dist = distribute_module(rms_norm, device_mesh, _replicate_fn) + + if elementwise_affine: + self.assertEqual( + rms_norm_local.weight, rms_norm_dist.weight.full_tensor() + ) + + x_local = x.detach().clone().requires_grad_(True) + x_dist = distribute_tensor(x, device_mesh, [Shard(shard_dim)]) + self.assertEqual(x_local, x_dist.full_tensor()) + + y_local = rms_norm_local(x_local) + # make sure that backward rms norm does not introduce extra collectives + comm_mode = CommDebugMode() + with comm_mode: + y_dist = rms_norm_dist(x_dist) + y_dist.sum().backward() + + # TODO: forward pass is sharding strategy is generated from composite, hence 1 more collective than layer_norm + # see: https://github.com/pytorch/pytorch/pull/158716#issuecomment-3096012679 + expected_fwd_comm = 0 if shard_dim < norm_idx else 2 + + self.assertEqual( + sum(comm_mode.comm_module_counts["Global"]["forward"].values()), + expected_fwd_comm, + f"comm count={comm_mode.get_total_counts()}, " + f"shard_dim={shard_dim}, norm_shape={normalized_shape}, elem_affine={elementwise_affine}", + ) + + self.assertEqual(y_local, y_dist.full_tensor()) + + # backward step + y_local.sum().backward() + + expected_bwd_comm = 0 if shard_dim < norm_idx else 1 + + self.assertEqual( + sum(comm_mode.comm_module_counts["Global"]["backward"].values()), + expected_bwd_comm, + f"comm count={comm_mode.get_total_counts()}, " + f"shard_dim={shard_dim}, norm_shape={normalized_shape}, elem_affine={elementwise_affine}", + ) + + if elementwise_affine: + # if input is sharded on any outer dimension, the gradient of weight + # should be Partial + dim_map = x_dist._spec.dim_map + outer_dims = range(norm_idx) + needs_reduction = any(dim_map[d] >= 0 for d in outer_dims) + self.assertEqual( + is_tensor_partial(rms_norm_dist.weight.grad._spec), + needs_reduction, + ) + self.assertEqual( + rms_norm_local.weight.grad, + rms_norm_dist.weight.grad.full_tensor(), + ) + + self.assertEqual(x_local.grad, x_dist.grad.full_tensor()) + @with_comms def test_topk(self): device_mesh = self.build_device_mesh() diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index 042959c22cd4a..a590713ad0f83 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -374,7 +374,6 @@ aten::_fused_adamw_.tensor_lr aten::_fused_moving_avg_obs_fq_helper aten::_fused_moving_avg_obs_fq_helper.out aten::_fused_moving_avg_obs_fq_helper_functional -aten::_fused_rms_norm aten::_fused_sdp_choice aten::_fused_sgd aten::_fused_sgd.out diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py index d6cf2df4343ff..5a962dfa57c05 100644 --- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py +++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py @@ -139,6 +139,8 @@ # These ops are defined in torch/csrc/distributed/c10d/Ops.cpp # TODO: add back restriction when c10d ops can be exported ("c10d::.*", datetime.date(9999, 1, 1)), + # Previously MPS_only did not support backward + ("aten::_fused_rms_norm", datetime.date(2025, 12, 30)), ] ALLOW_LIST_COMPILED = [ diff --git a/test/test_decomp.py b/test/test_decomp.py index 5d641e32e422e..dcd6e69af997c 100644 --- a/test/test_decomp.py +++ b/test/test_decomp.py @@ -15,7 +15,7 @@ from torch._export.utils import _is_cia_op from torch._ops import DispatchKey from torch.testing import make_tensor -from torch.testing._internal.common_cuda import tf32_off +from torch.testing._internal.common_cuda import SM70OrLater, tf32_off from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, onlyCPU, @@ -1226,6 +1226,33 @@ def f(x, w, b): for o_ref, o in zip(out_ref, out): self.assertEqual(o_ref.dtype, o.dtype) + @onlyCUDA + @unittest.skipIf(not SM70OrLater, "triton") + def test_rms_norm_decomp_cuda(self, device): + @torch.compile + def rms_norm_sinh(a, b, c): + output = torch.nn.functional.rms_norm(a, b, c) + return torch.sinh(output) + + normalized_shape_arg = (3, 3, 3) + input_tensor = torch.randn(3, 3, 3, device=device, requires_grad=True) + weight_tensor = torch.randn(3, 3, 3, device=device, requires_grad=True) + + def forward_pass_fn(): + return rms_norm_sinh(input_tensor, normalized_shape_arg, weight_tensor) + + model_output, generated_codes = torch._inductor.utils.run_fw_bw_and_get_code( + forward_pass_fn + ) + + # check RMSNorm was fused with sinh + self.assertTrue( + "triton_per_fused_add_mean_mul_pow_rsqrt_sinh" in generated_codes[0] + ) + self.assertTrue( + "triton_per_fused__fused_rms_norm_backward_cosh_mul" in generated_codes[1] + ) + instantiate_device_type_tests(DecompOneOffTests, globals()) diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index e2419aab268b1..f0349c2484b61 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -1267,6 +1267,11 @@ mean: not_implemented("native_layer_norm_backward mean") rstd: not_implemented("native_layer_norm_backward rstd") +- name: _fused_rms_norm(Tensor input, int[] normalized_shape, Tensor? weight, float? eps) -> (Tensor, Tensor) + input, weight: "GradMode::is_enabled() || grads[1].defined() ? infinitely_differentiable_native_rms_norm_backward(grads[0], grads[1], input, normalized_shape, result1, weight, grad_input_mask) : (grads[0].defined() ? _fused_rms_norm_backward(grads[0], input, normalized_shape, result1, weight, grad_input_mask) : std::tuple())" + result0: rms_norm_jvp(input_p, input_t, weight_p, weight_t, result1, normalized_shape) + result1: rms_norm_rstd_jvp(input_p, input_t, result1, normalized_shape) + - name: native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor) input, weight, bias: "GradMode::is_enabled() || grads[1].defined() || grads[2].defined() ? infinitely_differentiable_native_group_norm_backward(grads[0], grads[1], grads[2], input, result1, result2, weight, N, C, HxW, group, eps, grad_input_mask) : (grads[0].defined() ? native_group_norm_backward_symint(grads[0].device().is_xpu() ? grads[0] : grads[0].contiguous(grads[0].device().is_cpu() ? input.suggest_memory_format() : c10::MemoryFormat::Contiguous), input.device().is_xpu() ? input : input.contiguous(input.device().is_cpu() ? input.suggest_memory_format() : c10::MemoryFormat::Contiguous), result1, result2, weight, N, C, HxW, group, grad_input_mask) : std::tuple())" result0: group_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, result1, result2, group) diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py index abb94b109cc0c..8e9796d2f7c1b 100644 --- a/torch/_decomp/__init__.py +++ b/torch/_decomp/__init__.py @@ -418,6 +418,7 @@ def _core_aten_decompositions_post_autograd() -> dict[ aten.native_dropout_backward, aten.native_group_norm_backward, aten.native_layer_norm_backward, + aten._fused_rms_norm_backward, aten.new_empty, aten.new_full, aten.new_ones, diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index f93a0bf84fb4b..832928ebf8aee 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -1743,6 +1743,81 @@ def native_layer_norm_backward_out( return grad_input +@register_decomposition(aten._fused_rms_norm_backward.default) +def _fused_rms_norm_backward( + grad_out: Tensor, + input: Tensor, + normalized_shape: list[int], + rstd: Tensor, + weight: Optional[Tensor], + output_mask: list[bool], +) -> tuple[Optional[Tensor], Optional[Tensor]]: + input_shape = input.shape + input_ndim = input.dim() + computation_dtype = utils.get_computation_dtype(input.dtype) + + grad_out_cast = grad_out.to( + computation_dtype, memory_format=torch.contiguous_format + ) + input_cast = input.to(computation_dtype, memory_format=torch.contiguous_format) + weight_cast = ( + weight.to(computation_dtype, memory_format=torch.contiguous_format) + if weight is not None + else None + ) + assert grad_out_cast is not None + + axis = input_ndim - len(normalized_shape) + inner_dims = input_shape[axis:] + outer_dims = input_shape[:axis] + inner_dim_indices: list[int] = [] + outer_dim_indices: list[int] = [] + for i in range(input_ndim): + if i >= axis: + inner_dim_indices.append(i) + else: + outer_dim_indices.append(i) + + N = prod(inner_dims) # type: ignore[arg-type] + M = prod(outer_dims) # type: ignore[arg-type] + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + if guard_size_oblivious(M <= 0) or guard_size_oblivious(N <= 0): + return ( + input.new_zeros(input_shape) if output_mask[0] else None, + input.new_zeros(input_shape[axis:]) if output_mask[1] else None, + ) + + rstd = _unsqueeze_to_dim(rstd, input_cast.dim()) # type: ignore[union-attr] + if weight_cast is not None: + grad_x_hat = grad_out_cast * weight_cast + else: + grad_x_hat = grad_out_cast + + d_input: Optional[Tensor] = None + d_weight: Optional[Tensor] = None + + x_hat = input_cast * rstd + + if output_mask[0]: + sum_val = torch.sum(x_hat * grad_x_hat, dim=inner_dim_indices, keepdim=True) + d_input = (grad_x_hat - (x_hat / N) * sum_val) * rstd + + if output_mask[1] and weight_cast is not None: + d_weight_full_shape = grad_out_cast * x_hat + if len(outer_dim_indices) > 0: + d_weight = torch.sum( + d_weight_full_shape, dim=outer_dim_indices, keepdim=False + ) + else: + d_weight = d_weight_full_shape + + return ( + _maybe_cast(d_input, input.dtype), + _maybe_cast(d_weight, input.dtype), + ) + + def native_batch_norm_helper( input: Tensor, weight: Optional[Tensor], diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index 908a980cfee9c..8e13d4267edb5 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -5023,6 +5023,103 @@ std::tuple layer_norm_double_backward( return std::tuple{gI, gG, ggO}; } +std::tuple infinitely_differentiable_native_rms_norm_backward( + const Tensor& dY, + const Tensor& drstd, + const Tensor& input, + IntArrayRef normalized_shape, + const Tensor& rstd, + const std::optional& weight_opt, + std::array grad_input_mask) { + c10::MaybeOwned weight_maybe_owned = + at::borrow_from_optional_tensor(weight_opt); + const Tensor& weight = *weight_maybe_owned; + + const auto input_shape = input.sizes(); + const auto input_ndim = input.dim(); + const int normalized_ndim = normalized_shape.size(); + const int axis = input_ndim - normalized_ndim; + + int64_t N_rms = 1; + for (int i = 0; i < normalized_ndim; ++i) { + N_rms *= input_shape[axis + i]; + } + + Tensor dX; + Tensor dgamma; + + std::vector rstd_view_shape = rstd.sizes().vec(); + for (int i = 0; + i < std::max(static_cast(normalized_ndim - rstd.dim()), 0); + ++i) { + rstd_view_shape.push_back(1); + } + Tensor rstd_broadcast = rstd.view(rstd_view_shape); + Tensor rstd_pow3 = rstd_broadcast.pow(3); + Tensor grad_x_hat; + + if (dY.defined()) { + if (weight.defined()) { + grad_x_hat = dY * weight; + } else { + grad_x_hat = dY; + } + } + + if (grad_input_mask[0]) { + Tensor dX_from_dY_path; + Tensor dX_from_drstd_path; + + std::vector inner_sum_dims; + inner_sum_dims.reserve(normalized_ndim); + for (int i = 0; i < normalized_ndim; ++i) { + inner_sum_dims.push_back(axis + i); + } + + if (dY.defined() && grad_x_hat.defined()) { + Tensor sum_input_times_grad_x_hat = + sum(input * grad_x_hat, inner_sum_dims, /*keepdim=*/true); + dX_from_dY_path = rstd_broadcast * grad_x_hat - + (input * rstd_pow3 / static_cast(N_rms)) * + sum_input_times_grad_x_hat; + } + + if (drstd.defined()) { + Tensor drstd_broadcast = drstd.view(rstd_view_shape); + dX_from_drstd_path = + -(input * rstd_pow3 / static_cast(N_rms)) * drstd_broadcast; + } + + if (dX_from_dY_path.defined() && dX_from_drstd_path.defined()) { + dX = dX_from_dY_path + dX_from_drstd_path; + } else if (dX_from_dY_path.defined()) { + dX = dX_from_dY_path; + } else if (dX_from_drstd_path.defined()) { + dX = dX_from_drstd_path; + } + } + + if (grad_input_mask[1] && weight.defined()) { + if (dY.defined()) { + Tensor x_hat = input * rstd_broadcast; + Tensor dgamma_full_shape = dY * x_hat; + + if (axis > 0) { + std::vector outer_sum_dims; + outer_sum_dims.reserve(axis); + for (int i = 0; i < axis; ++i) { + outer_sum_dims.push_back(i); + } + dgamma = sum(dgamma_full_shape, outer_sum_dims, /*keepdim=*/false); + } else { + dgamma = dgamma_full_shape; + } + } + } + + return std::make_tuple(dX, dgamma); +} + std::tuple infinitely_differentiable_native_group_norm_backward( const Tensor& dY, @@ -6377,6 +6474,98 @@ Tensor layer_norm_jvp( bias_t.defined() ? bias_t.view(view_size_affine) : bias_t); } +Tensor rms_norm_jvp( + const Tensor& input_p, + const Tensor& input_t, + const Tensor& weight_p, + const Tensor& weight_t, + const Tensor& saved_rstd, + IntArrayRef normalized_shape) { + auto dims = std::vector{}; + auto view_size = input_t.sizes().vec(); + auto view_size_affine = input_t.sizes().vec(); + + int64_t numel = 1; + for (const auto i : c10::irange(view_size.size())) { + if (i < view_size.size() - normalized_shape.size()) { + view_size_affine[i] = 1; + } else { + numel *= input_t.size(static_cast(i)); + view_size[i] = 1; + dims.push_back(static_cast(i)); + } + } + + auto rstd_p = saved_rstd.view(view_size); + + Tensor rstd_t; + if (areAnyTensorSubclassLike({input_t, input_p, rstd_p}) || + input_t._is_zerotensor()) { + rstd_t = -rstd_p.pow(3) * (input_t) * (input_p); + } else { + rstd_t = input_t * input_p; + rstd_t *= -rstd_p.pow(3); + } + rstd_t = rstd_t.sum(dims, true); + rstd_t /= numel; + + Tensor result_t; + if (areAnyTensorSubclassLike({input_t, input_p, rstd_p}) || + input_t._is_zerotensor()) { + result_t = (input_t)*rstd_p + (input_p)*rstd_t; + } else { + result_t = input_t * rstd_p; + auto temp = input_p * rstd_t; + result_t += temp; + } + + std::optional result_p = std::nullopt; + if (weight_p.defined()) { + result_p = std::optional(input_p * rstd_p); + } + + return _affine_jvp( + result_p, + result_t, + weight_p.defined() ? weight_p.view(view_size_affine) : weight_p, + weight_t.defined() ? weight_t.view(view_size_affine) : weight_t, + Tensor()); +} + +Tensor rms_norm_rstd_jvp( + const Tensor& input_p, + const Tensor& input_t, + const Tensor& saved_rstd, + IntArrayRef normalized_shape) { + auto dims = std::vector{}; + auto view_size = input_t.sizes().vec(); + auto view_size_affine = input_t.sizes().vec(); + + int64_t numel = 1; + for (const auto i : c10::irange(view_size.size())) { + if (i < view_size.size() - normalized_shape.size()) { + view_size_affine[i] = 1; + } else { + numel *= input_t.size(static_cast(i)); + view_size[i] = 1; + dims.push_back(static_cast(i)); + } + } + + auto rstd_p = saved_rstd.view(view_size); + Tensor rstd_t; + if (areAnyTensorSubclassLike({input_t, input_p, rstd_p}) || + input_t._is_zerotensor()) { + rstd_t = -rstd_p.pow(3) * (input_t) * (input_p); + } else { + rstd_t = input_t * input_p; + rstd_t *= -rstd_p.pow(3); + } + rstd_t = rstd_t.sum(dims, true); + rstd_t /= numel; + return rstd_t; +} + Tensor group_norm_jvp( const Tensor& input_p, const Tensor& input_t, diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h index 0b659973ec345..96864e165a95a 100644 --- a/torch/csrc/autograd/FunctionsManual.h +++ b/torch/csrc/autograd/FunctionsManual.h @@ -826,6 +826,15 @@ std::tuple layer_norm_double_backward( c10::SymIntArrayRef normalized_shape, std::array output_mask); +std::tuple infinitely_differentiable_native_rms_norm_backward( + const Tensor& dY, + const Tensor& drstd, + const Tensor& input, + IntArrayRef normalized_shape, + const Tensor& rstd, + const std::optional& weight_opt, + std::array grad_input_mask); + std::tuple householder_product_backward( const Tensor& grad, const Tensor& result, @@ -965,6 +974,20 @@ Tensor layer_norm_jvp( const Tensor& saved_invstd, c10::SymIntArrayRef normalized_shape); +Tensor rms_norm_jvp( + const Tensor& input_p, + const Tensor& input_t, + const Tensor& weight_p, + const Tensor& weight_t, + const Tensor& saved_rstd, + IntArrayRef normalized_shape); + +Tensor rms_norm_rstd_jvp( + const Tensor& input_p, + const Tensor& input_t, + const Tensor& saved_rstd, + IntArrayRef normalized_shape); + Tensor group_norm_jvp( const Tensor& input_p, const Tensor& input_t, diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h index 2aa09cb802ecd..aced2b2f539de 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h @@ -29,6 +29,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__fft_c2c(AtenTensorHandle self, AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__fft_r2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t onesided, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__fused_moving_avg_obs_fq_helper(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__fused_rms_norm(AtenTensorHandle input, const int64_t* normalized_shape, int64_t normalized_shape_len_, AtenTensorHandle* weight, double* eps, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__histogramdd_from_bin_cts(AtenTensorHandle self, const int64_t* bins, int64_t bins_len_, const double** range, int64_t range_len_, AtenTensorHandle* weight, int32_t density, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__int_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__pdist_backward(AtenTensorHandle grad, AtenTensorHandle self, double p, AtenTensorHandle pdist, AtenTensorHandle* ret0); diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h index e0607f984b3d0..92d30ded855f8 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h @@ -32,6 +32,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__flash_attention_backward(AtenT AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__flash_attention_forward(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* cum_seq_q, AtenTensorHandle* cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, int64_t* window_size_left, int64_t* window_size_right, AtenTensorHandle* seqused_k, AtenTensorHandle* alibi_slopes, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__fused_moving_avg_obs_fq_helper(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__fused_rms_norm(AtenTensorHandle input, const int64_t* normalized_shape, int64_t normalized_shape_len_, AtenTensorHandle* weight, double* eps, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__int_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__pdist_backward(AtenTensorHandle grad, AtenTensorHandle self, double p, AtenTensorHandle pdist, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__pdist_forward(AtenTensorHandle self, double p, AtenTensorHandle* ret0); diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_mps.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_mps.h index a5d654c518840..c76ee685c25da 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_mps.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_mps.h @@ -18,7 +18,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__efficientzerotensor(const int64 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fft_c2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t forward, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fft_r2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t onesided, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5); -AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fused_rms_norm(AtenTensorHandle input, int64_t normalized_shape_ndim, AtenTensorHandle weight, double eps, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fused_rms_norm(AtenTensorHandle input, const int64_t* normalized_shape, int64_t normalized_shape_len_, AtenTensorHandle* weight, double* eps, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__histogramdd_from_bin_cts(AtenTensorHandle self, const int64_t* bins, int64_t bins_len_, const double** range, int64_t range_len_, AtenTensorHandle* weight, int32_t density, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_mask, double dropout_p, int32_t is_causal, AtenTensorHandle* dropout_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8); diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h index 243bfb5fc87aa..6fc51bd0c8f8d 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h @@ -13,6 +13,7 @@ extern "C" { AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__addmm_activation(AtenTensorHandle self, AtenTensorHandle mat1, AtenTensorHandle mat2, double beta, double alpha, int32_t use_gelu, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__fused_rms_norm(AtenTensorHandle input, const int64_t* normalized_shape, int64_t normalized_shape_len_, AtenTensorHandle* weight, double* eps, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0); diff --git a/torch/distributed/tensor/_ops/_math_ops.py b/torch/distributed/tensor/_ops/_math_ops.py index 59ca7aed9bdf8..9e875936c2643 100644 --- a/torch/distributed/tensor/_ops/_math_ops.py +++ b/torch/distributed/tensor/_ops/_math_ops.py @@ -904,34 +904,48 @@ def layer_norm_strategy(op_schema: OpSchema) -> OpStrategy: return output_strategy -@register_op_strategy( - [aten.native_layer_norm_backward.default], - schema_info=RuntimeSchemaInfo(2), -) -def layer_norm_bwd_strategy(op_schema: OpSchema) -> OpStrategy: +def _common_norm_backward_strategy( + op_schema: OpSchema, + rms_norm: bool = False, +) -> OpStrategy: + """Common backward strategy logic for layer_norm and rms_norm.""" # backward op does not need to validate the mesh since forward op has already done it mesh = op_schema.get_mesh_from_args(validate=False) - # args must be: grad_out, input, normalized_shape, mean, rstd, - # weight, bias, output_mask. For None weight and bias, their - # corresponding objects will be None as well. - - assert len(op_schema.args_schema) == 8 - ( - grad_out_strategy, - input_strategy, - normalized_shape, - mean_strategy, - rstd_strategy, - weight_strategy, - bias_strategy, - output_mask, - ) = op_schema.args_schema + if not rms_norm: + # layer_norm args: grad_out, input, normalized_shape, mean, rstd, + # weight, bias, output_mask. For None weight and bias, their + # corresponding objects will be None as well. + assert len(op_schema.args_schema) == 8 + ( + grad_out_strategy, + input_strategy, + normalized_shape, + mean_strategy, + rstd_strategy, + weight_strategy, + bias_strategy, + output_mask, + ) = op_schema.args_schema + else: + # rms_norm args: grad_out, input, normalized_shape, rstd, + assert len(op_schema.args_schema) == 6 + ( + grad_out_strategy, + input_strategy, + normalized_shape, + rstd_strategy, + weight_strategy, + output_mask, + ) = op_schema.args_schema + mean_strategy = None + bias_strategy = None assert isinstance(grad_out_strategy, OpStrategy) assert isinstance(input_strategy, OpStrategy) - assert isinstance(mean_strategy, OpStrategy) assert isinstance(rstd_strategy, OpStrategy) + if mean_strategy is not None: + assert isinstance(mean_strategy, OpStrategy) assert isinstance(normalized_shape, (int, Sequence, torch.Size)) normalized_size = normalize_to_torch_size(normalized_shape) @@ -939,9 +953,12 @@ def layer_norm_bwd_strategy(op_schema: OpSchema) -> OpStrategy: axis = input_ndim - len(normalized_size) outer_dims = list(range(axis)) - assert isinstance(output_mask, list) and len(output_mask) == 3 + if not rms_norm: + assert isinstance(output_mask, list) and len(output_mask) == 3 + else: + assert isinstance(output_mask, list) and len(output_mask) == 2 - # output triple: (d_input, d_weight, d_bias) + # output tuple: (d_input, d_weight[, d_bias]) out_tuple_strategy = OpStrategy([]) for idx, input_placement_strategy in enumerate(input_strategy.strategies): # args for OpSpec @@ -982,10 +999,14 @@ def layer_norm_bwd_strategy(op_schema: OpSchema) -> OpStrategy: generate_redistribute_costs(input_strategy, input_target_spec) ) - # arg: mean, rstd - mean_src_spec = mean_strategy.strategies[idx].output_spec - input_specs_list.append(mean_src_spec) - redistribute_costs.append([0.0 for _ in mean_strategy.strategies]) + # arg: mean + if not rms_norm: + assert mean_strategy is not None # mypy fix + mean_src_spec = mean_strategy.strategies[idx].output_spec + input_specs_list.append(mean_src_spec) + redistribute_costs.append([0.0 for _ in mean_strategy.strategies]) + + # arg: rstd rstd_src_spec = rstd_strategy.strategies[idx].output_spec input_specs_list.append(rstd_src_spec) redistribute_costs.append([0.0 for _ in rstd_strategy.strategies]) @@ -1001,6 +1022,7 @@ def _add_target_input_spec(strategy) -> DTensorSpec: # arg: weight # d_weight = sum(grad_out * (input - mean) / rstd, outer_dim, keepdim=False) + # For RMS norm, mean is 0, so it's just: sum(grad_out * input / rstd, outer_dim, keepdim=False) if weight_strategy is not None: weight_src_spec = _add_target_input_spec(weight_strategy) # TODO: now d_weight spec follows input spec w/ a reduction. @@ -1020,36 +1042,39 @@ def _add_target_input_spec(strategy) -> DTensorSpec: ) output_specs_list.append(weight_out_spec if output_mask[1] else None) else: - assert output_mask[1] is False, ( - "output_mask[1] should not be `True` while weight argument is `None` in native_layer_norm_backward." - ) + if not rms_norm: + error_msg = "output_mask[1] should not be `True` while weight argument is `None` in native_layer_norm_backward." + else: + error_msg = "output_mask[1] should not be `True` while weight argument is `None` in _fused_rms_norm_backward." + assert output_mask[1] is False, error_msg output_specs_list.append(None) # arg: bias # d_bias = sum(grad_out, outer_dim, keepdim=False) - if bias_strategy is not None: - bias_src_spec = _add_target_input_spec(bias_strategy) - # d_bias spec follows a reduction over grad_out - inp_placements = _replicate_dims_start_at( - grad_out_target_spec.placements, axis - ) - reduce_dims_map = _infer_reduce_dims_map( - outer_dims, grad_out_target_spec.ndim, False - ) - out_placements = map_placements_after_reduction( - inp_placements, outer_dims, reduce_dims_map, "sum" - ) - bias_out_spec = DTensorSpec( - mesh=mesh, - placements=out_placements, - tensor_meta=bias_src_spec.tensor_meta, - ) - output_specs_list.append(bias_out_spec if output_mask[2] else None) - else: - assert output_mask[2] is False, ( - "output_mask[2] should not be `True` while bias argument is `None` in native_layer_norm_backward." - ) - output_specs_list.append(None) + if not rms_norm: + if bias_strategy is not None: + bias_src_spec = _add_target_input_spec(bias_strategy) + # d_bias spec follows a reduction over grad_out + inp_placements = _replicate_dims_start_at( + grad_out_target_spec.placements, axis + ) + reduce_dims_map = _infer_reduce_dims_map( + outer_dims, grad_out_target_spec.ndim, False + ) + out_placements = map_placements_after_reduction( + inp_placements, outer_dims, reduce_dims_map, "sum" + ) + bias_out_spec = DTensorSpec( + mesh=mesh, + placements=out_placements, + tensor_meta=bias_src_spec.tensor_meta, + ) + output_specs_list.append(bias_out_spec if output_mask[2] else None) + else: + assert output_mask[2] is False, ( + "output_mask[2] should not be `True` while bias argument is `None` in native_layer_norm_backward." + ) + output_specs_list.append(None) out_tuple_strategy.strategies.append( OpSpec( @@ -1062,6 +1087,22 @@ def _add_target_input_spec(strategy) -> DTensorSpec: return out_tuple_strategy +@register_op_strategy( + [aten.native_layer_norm_backward.default], + schema_info=RuntimeSchemaInfo(2), +) +def layer_norm_bwd_strategy(op_schema: OpSchema) -> OpStrategy: + return _common_norm_backward_strategy(op_schema) + + +@register_op_strategy( + [aten._fused_rms_norm_backward.default], + schema_info=RuntimeSchemaInfo(2), +) +def fused_rms_norm_bwd_strategy(op_schema: OpSchema) -> OpStrategy: + return _common_norm_backward_strategy(op_schema, rms_norm=True) + + @register_op_strategy( [aten.topk.default], schema_info=RuntimeSchemaInfo(2), diff --git a/torch/overrides.py b/torch/overrides.py index cb67931fab691..046171ef6c5c6 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -820,6 +820,7 @@ def get_testing_overrides() -> dict[Callable, Callable]: torch._native_batch_norm_legit: lambda input, weight, bias, training, momentum, eps: -1, torch.native_dropout: lambda input, p, train: -1, torch.native_layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05: -1, + torch._fused_rms_norm: lambda input, normalized_shape, weight=None, eps=1e-05: -1, torch.native_group_norm: lambda input, weight, bias, N, C, HxW, group, eps: -1, torch.native_norm: lambda input, p=2, dim=None, keepdim=False, dtype=None: -1, torch.native_channel_shuffle: lambda input, groups: -1, From fc5a404eb1e28fdb39f5b91d6378699d3b042749 Mon Sep 17 00:00:00 2001 From: Yahaya Suleiman Date: Tue, 22 Jul 2025 22:45:28 +0000 Subject: [PATCH 410/457] [gtest][listing] fixing caffe2:verify_api_visibility - main (#158229) Summary: Remove the custom main from this test file Test Plan: https://www.internalfb.com/intern/testinfra/testrun/9570149303161031 Rollback Plan: Reviewed By: patskovn Differential Revision: D78015676 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158229 Approved by: https://github.com/Skylion007 --- aten/src/ATen/test/verify_api_visibility.cpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/test/verify_api_visibility.cpp b/aten/src/ATen/test/verify_api_visibility.cpp index 5878ed352e5b1..c6d2fcc6fb865 100644 --- a/aten/src/ATen/test/verify_api_visibility.cpp +++ b/aten/src/ATen/test/verify_api_visibility.cpp @@ -20,4 +20,8 @@ #error "CAFFE2_STATIC_LINK_CUDA should not be visible in public headers" #endif -auto main() -> int {} +#include + +TEST(VerifyApiVisibility, Test) { + ASSERT_EQ(1, 1); +} From badfebf29e46c3e41d7cf54a7a807865a90277b0 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 22 Jul 2025 23:04:16 +0000 Subject: [PATCH 411/457] Revert "[Inductor] Expose decomposeK knobs as envvars (#158745)" This reverts commit eac777c4f46b381106f2f2b78fe05b506f8c558c. Reverted https://github.com/pytorch/pytorch/pull/158745 on behalf of https://github.com/jeffdaily due to sorry but rocm CI is broken due to this PR ([comment](https://github.com/pytorch/pytorch/pull/158745#issuecomment-3105071170)) --- test/inductor/test_max_autotune.py | 55 +++++++----------------------- torch/_inductor/config.py | 14 ++------ torch/_inductor/utils.py | 33 ++++++++++-------- 3 files changed, 35 insertions(+), 67 deletions(-) diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py index a04017459fc69..6245b89f4eca3 100644 --- a/test/inductor/test_max_autotune.py +++ b/test/inductor/test_max_autotune.py @@ -50,12 +50,7 @@ aten = torch.ops.aten from torch._inductor.mock_cache import global_stats, PatchCaches, Stats from torch._inductor.test_case import run_tests, TestCase -from torch._inductor.utils import ( - fresh_cache, - get_k_splits, - run_and_get_code, - use_decompose_k_choice, -) +from torch._inductor.utils import fresh_cache, run_and_get_code from torch._inductor.virtualized import V from torch.fx.experimental.proxy_tensor import make_fx from torch.testing import FileCheck @@ -1503,7 +1498,6 @@ def misses(): self.assertEqual(hits(), 4) self.assertEqual(misses(), 4) - @fresh_cache() @skipIfXpu @unittest.skipIf( config.cpp_wrapper, "decompose_k not supported for cpp_wrapper yet" @@ -1512,42 +1506,19 @@ def misses(): max_autotune=True, max_autotune_gemm_backends="TRITON", autotune_fallback_to_aten=False, + disable_decompose_k=True, ) - @parametrize("num_decompose_k_splits", (0, 5, 20)) - @parametrize("decompose_k_threshold", (8, 16)) - def test_max_autotune_decompose_k_envvars( - self, num_decompose_k_splits, decompose_k_threshold - ): - shapes = [(32, 32, 32768), (32, 32, 256)] - for M, N, K in shapes: - get_k_splits.cache_clear() - use_decompose_k_choice.cache_clear() - a = torch.randn(M, K, dtype=torch.float16, device="cuda") - b = torch.randn(K, N, dtype=torch.float16, device="cuda") - - with config.patch( - { - "triton.num_decompose_k_splits": num_decompose_k_splits, - "triton.decompose_k_threshold": decompose_k_threshold, - } - ): - compiled_func = torch.compile(lambda a, b: a @ b) - _, code = run_and_get_code(compiled_func, a, b) - - decompose_count = 0 - for codegen in code: - if "benchmark_decompose_k_mm" in codegen: - decompose_count += 1 - - if ( - K // M < decompose_k_threshold - or K // N < decompose_k_threshold - or num_decompose_k_splits == 0 - ): - self.assertEqual(decompose_count, 0) - else: - self.assertTrue(decompose_count > 0) - self.assertTrue(decompose_count <= num_decompose_k_splits) + def test_max_autotune_disable_decompose_K(self): + M, N, K = (32, 32, 32768) + + a = torch.randn(M, K, dtype=torch.float16, device="cuda", requires_grad=True) + b = torch.randn(K, N, dtype=torch.float16, device="cuda", requires_grad=True) + + compiled_func = torch.compile(lambda a, b: a @ b) + out, code = run_and_get_code(compiled_func, a, b) + + for codegen in code: + FileCheck().check_not("decompose_k").run(codegen) @skipIfXpu @unittest.skipIf( diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index ae2ee6a574c73..2404f397ba54c 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -425,6 +425,9 @@ def prologue_fusion_enabled() -> bool: # enable slow autotuning passes to select gemm algorithms max_autotune_gemm = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE_GEMM") == "1" +# disable decomposek autotune choice for gemm +disable_decompose_k = os.environ.get("TORCHINDUCTOR_DISABLE_DECOMPOSE_K") == "1" + # Modifies the number of autotuning choices displayed, set to None for all autotune_num_choices_displayed: Optional[int] = 10 @@ -1342,17 +1345,6 @@ class triton: # Note: it may also need to be used with config.compile_threads = 1 disallow_failing_autotune_kernels_TESTING_ONLY = False - # specify number of splits to autotune on for decompose_k. 0 disables decompose_k - num_decompose_k_splits = int( - os.environ.get("TORCHINDUCTOR_NUM_DECOMPOSE_K_SPLITS", "10") - ) - - # specify minimum ratio of K to M AND N in order to autotune on decompose_k. 0 enables - # it as an autotuning choice for all matmuls - decompose_k_threshold = int( - os.environ.get("TORCHINDUCTOR_DECOMPOSE_K_THRESHOLD", "32") - ) - class aot_inductor: """ diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index aef81712d17eb..d95642b75f9d1 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -1665,15 +1665,20 @@ def _use_cutlass_for_op(op_name: str) -> bool: return op_name.upper() in [x.strip() for x in enabled_ops.split(",")] +decompose_k_threshold = 32 + +# To limit compile time +k_splits_limit = 5 + +# Hand-tuned +default_k_splits = [16, 32, 64, 128, 256] + _IntLike: TypeAlias = Union[int, sympy.Expr] -@functools.cache def use_decompose_k_choice(m: _IntLike, n: _IntLike, k: _IntLike) -> bool: from torch._inductor.virtualized import V - decompose_k_threshold = config.triton.decompose_k_threshold - return ( not torch.version.hip and V.graph.sizevars.statically_known_true( @@ -1684,21 +1689,15 @@ def use_decompose_k_choice(m: _IntLike, n: _IntLike, k: _IntLike) -> bool: ) and not V.graph.aot_mode # TODO: Support AOTI for decomposeK and not V.graph.cpp_wrapper + and not config.disable_decompose_k ) @functools.cache def get_k_splits(m: _IntLike, n: _IntLike, k: _IntLike) -> list[int]: - # To limit compile time - k_splits_limit = config.triton.num_decompose_k_splits - - # Hand-tuned - default_k_splits = [16, 32, 64, 128, 256] # If k is a sympy expression, we can't do any splitting if isinstance(k, sympy.Expr) and not k.is_number: return default_k_splits - elif k_splits_limit == 0: - return [] if (isinstance(m, sympy.Expr) and not m.is_number) or ( isinstance(n, sympy.Expr) and not n.is_number @@ -1738,10 +1737,15 @@ def get_k_splits(m: _IntLike, n: _IntLike, k: _IntLike) -> list[int]: if config.max_autotune_gemm_search_space == "EXHAUSTIVE": return pow_of_2_divisors + mul_of_32_divisors + rest_of_splits - - best_splits = pow_of_2_divisors + mul_of_32_divisors + rest_of_splits - # Otherwise, conform results to k_splits_limit - return best_splits[:k_splits_limit] + # If the # of power of 2 divisors are greater than k_splits_limit, return all + # This should be ok for compile time, all perfect squares between 128 and min(k / m, k / n) + # should never be a massive amount + if len(pow_of_2_divisors) >= k_splits_limit: + return pow_of_2_divisors + else: + best_splits = pow_of_2_divisors + mul_of_32_divisors + rest_of_splits + # Otherwise, conform results to k_splits_limit + return best_splits[:k_splits_limit] @functools.cache @@ -2016,6 +2020,7 @@ def call(self, *args: Any, **kwargs: Any) -> None: self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen() ) # Skip all the actual compiling. + nonlocal save_output_code save_output_code(wrapper_code.value) if kernel_code: save_output_code(kernel_code.value) From 6100ed457c9bf19dd80e0d53301c7bae691da8d3 Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Tue, 22 Jul 2025 23:19:38 +0000 Subject: [PATCH 412/457] [ROCm] Improve Type Safety of C10_WARP_SIZE (#158271) # Background The `C10_WARP_SIZE`, although always be `32` on CUDA platform, varies across different AMD GPUs. Therefore, to correctly refer this value, the host code must be a variable instead of a literal defined by macro, or a `constexpr int`. This PR may cause more compiler errors for third party code on AMD GPU, which is intentional. Having a fixed `C10_WARP_SIZE` value on host code for AMD GPU only defers compile time error to runtime. This PR is recommended to be included as part of Release Notes to describe an API change for whoever uses this macro. Users are recommended to use `C10_WARP_SIZE` directly, which adapts for various scenarios, or define a macro to use `C10_WARP_SIZE`. Assignment of this macro to symbols shared by host/device code causes problems on ROCM platform. (See the fix at `aten/src/ATen/native/cuda/layer_norm_kernel.cu` for a concrete example) # Behaviors * If compiling with HIPCC (i.e `defined(__HIPCC__)`): + Define `C10_WARP_SIZE` to be non-`constexpr` `at::cuda::warp_size()` for host-compilation pass (as compared to `static constexpr int C10_WARP_SIZE = 1;` set in 04bd7e6850e8efec77994963ffee87549555b9c3) + Define `C10_WARP_SIZE` to be a function returning `constexpr int` `64` for `__GFX9__`, and `32` otherwise, for device-compilation pass - `__GFX8__` is also 64 but we do not support any GFX8 GPU. * If not compiling with HIPCC: + Define `C10_WARP_SIZE` to be non-constexpr `at::cuda::warp_size()` # `constexpr` variant for host code For host-compilation cases where a `constexpr` value is needed for warp size (eg. launch bounds), use `C10_WARP_SIZE_STATIC`, which is defined as `64`. This macro follows the pre 04bd7e6850e8efec77994963ffee87549555b9c3 behavior of `C10_WARP_SIZE` Pull Request resolved: https://github.com/pytorch/pytorch/pull/158271 Approved by: https://github.com/jeffdaily Co-authored-by: Jithun Nair <37884920+jithunnair-amd@users.noreply.github.com> --- .../src/ATen/native/cuda/layer_norm_kernel.cu | 5 ++++ .../sparse/cuda/SparseCUDAApplyUtils.cuh | 4 +++ torch/headeronly/macros/Macros.h | 30 ++++++++++++++----- 3 files changed, 32 insertions(+), 7 deletions(-) diff --git a/aten/src/ATen/native/cuda/layer_norm_kernel.cu b/aten/src/ATen/native/cuda/layer_norm_kernel.cu index 082f4f0a1af4d..940680eb3682f 100644 --- a/aten/src/ATen/native/cuda/layer_norm_kernel.cu +++ b/aten/src/ATen/native/cuda/layer_norm_kernel.cu @@ -33,7 +33,12 @@ namespace at::native { namespace { constexpr int kCUDANumThreads = 256; +#ifdef USE_ROCM +// C10_WARP_SIZE is not constexpr for host code. +#define kWarpSize C10_WARP_SIZE +#else constexpr unsigned int kWarpSize = C10_WARP_SIZE; +#endif constexpr int vec_size = 4; //we could make it dependent on dtype, but that would lead to different results between float and low-p types // aligned vector generates vectorized load/store on CUDA (copy-pasted from MemoryAccess.cuh) diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDAApplyUtils.cuh b/aten/src/ATen/native/sparse/cuda/SparseCUDAApplyUtils.cuh index c9412d74e9cda..693ca536a3198 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseCUDAApplyUtils.cuh +++ b/aten/src/ATen/native/sparse/cuda/SparseCUDAApplyUtils.cuh @@ -242,7 +242,11 @@ __global__ void coalesceValuesKernel( // `if constexpr` when CUDA codes will be compiled under C++-17, see // gh-56055 for blockers. template +#ifdef USE_ROCM +C10_LAUNCH_BOUNDS_1(C10_WARP_SIZE_STATIC*4) +#else C10_LAUNCH_BOUNDS_1(C10_WARP_SIZE*4) +#endif __global__ void coalesceValuesKernel( int64_t *segment_offsets, int64_t *value_indices, bool *values, bool *newValues, diff --git a/torch/headeronly/macros/Macros.h b/torch/headeronly/macros/Macros.h index 0c02cce309dc8..1e07ab0446e8c 100644 --- a/torch/headeronly/macros/Macros.h +++ b/torch/headeronly/macros/Macros.h @@ -318,16 +318,32 @@ constexpr uint32_t CUDA_THREADS_PER_BLOCK_FALLBACK = 256; // depending on the target device, and then always set it to 64 for host code. // Host pass of HIP compiler needs C10_WARP_SIZE defined to _something_ so we // set it to something unreasonable to trigger obvious host code errors. -#if defined(__HIP_DEVICE_COMPILE__) + +namespace at::cuda { +TORCH_CUDA_CPP_API int warp_size(); +} +#ifdef __HIPCC__ +static inline int __host__ C10_WARP_SIZE_INTERNAL() { + return at::cuda::warp_size(); +} + +static inline constexpr int __device__ C10_WARP_SIZE_INTERNAL() { #if defined(__GFX9__) -static constexpr int C10_WARP_SIZE = 64; + return 64; #else // __GFX9__ -static constexpr int C10_WARP_SIZE = 32; + return 32; #endif // __GFX9__ -#else -static constexpr int C10_WARP_SIZE = 1; -#endif // __HIP_DEVICE_COMPILE__ -#else +} +#else // __HIPCC__ +static inline int C10_WARP_SIZE_INTERNAL() { + return at::cuda::warp_size(); +} +#endif // __HIPCC__ + +#define C10_WARP_SIZE (C10_WARP_SIZE_INTERNAL()) +#define C10_WARP_SIZE_STATIC 64 + +#else // defined(USE_ROCM) #define C10_WARP_SIZE 32 #endif From cab96b587944d324dde2528d4b1ec5819bc52ce9 Mon Sep 17 00:00:00 2001 From: Benjamin Glass Date: Tue, 22 Jul 2025 23:41:44 +0000 Subject: [PATCH 413/457] [tests] Reduce sizes of unnecessarily large tensors to reduce OOM flakes (#158456) Downsizes several tensors that were massively oversized to test the problem at hand, to reduce test flaking. Fixes #126867 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158456 Approved by: https://github.com/desertfire --- test/inductor/test_max_autotune.py | 24 ++++++++++-------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py index 6245b89f4eca3..096e924a47826 100644 --- a/test/inductor/test_max_autotune.py +++ b/test/inductor/test_max_autotune.py @@ -815,9 +815,9 @@ def test_non_contiguous_input_mm(self): Check https://github.com/pytorch/pytorch/issues/125437 for more details. """ x = rand_strided( - (50257, 32768), (1, 50304), dtype=torch.bfloat16, device=GPU_TYPE + (50257, 2048), (1, 50304), dtype=torch.bfloat16, device=GPU_TYPE ) - y = rand_strided((32768, 768), (768, 1), dtype=torch.bfloat16, device=GPU_TYPE) + y = rand_strided((2048, 768), (768, 1), dtype=torch.bfloat16, device=GPU_TYPE) @torch.compile(mode="max-autotune") def f(x, y): @@ -830,9 +830,9 @@ def f(x, y): def test_non_contiguous_input_addmm(self): b = torch.randn((768), dtype=torch.bfloat16, device=GPU_TYPE) x = rand_strided( - (50257, 32768), (1, 50304), dtype=torch.bfloat16, device=GPU_TYPE + (50257, 2048), (1, 50304), dtype=torch.bfloat16, device=GPU_TYPE ) - y = rand_strided((32768, 768), (768, 1), dtype=torch.bfloat16, device=GPU_TYPE) + y = rand_strided((2048, 768), (768, 1), dtype=torch.bfloat16, device=GPU_TYPE) @torch.compile(mode="max-autotune") def f(x, y): @@ -844,10 +844,10 @@ def f(x, y): def test_non_contiguous_input_bmm(self): x = rand_strided( - (1, 50257, 32768), (0, 1, 50304), dtype=torch.bfloat16, device=GPU_TYPE + (1, 50257, 2048), (0, 1, 50304), dtype=torch.bfloat16, device=GPU_TYPE ) y = rand_strided( - (1, 32768, 768), (0, 768, 1), dtype=torch.bfloat16, device=GPU_TYPE + (1, 2048, 768), (0, 768, 1), dtype=torch.bfloat16, device=GPU_TYPE ) @torch.compile(mode="max-autotune") @@ -861,16 +861,12 @@ def f(x, y): # TODO: fix accuracy failure of the triton template on XPU. # and enable this test case. @skipIfXpu - @unittest.skipIf( - os.getenv("TORCHINDUCTOR_CPP_WRAPPER", "0") == "1", - "OOM when running with TORCHINDUCTOR_CPP_WRAPPER https://github.com/pytorch/pytorch/issues/126867", - ) def test_non_contiguous_input_mm_plus_mm(self): - x1 = rand_strided((50257, 32768), (1, 50304), device=GPU_TYPE) - y1 = rand_strided((32768, 768), (768, 1), device=GPU_TYPE) + x1 = rand_strided((50257, 2048), (1, 50304), device=GPU_TYPE) + y1 = rand_strided((2048, 768), (768, 1), device=GPU_TYPE) - x2 = rand_strided((50257, 32768), (1, 50304), device=GPU_TYPE) - y2 = rand_strided((32768, 768), (768, 1), device=GPU_TYPE) + x2 = rand_strided((50257, 2048), (1, 50304), device=GPU_TYPE) + y2 = rand_strided((2048, 768), (768, 1), device=GPU_TYPE) @torch.compile(mode="max-autotune") def f(x1, y1, x2, y2): From d3f9107d682d2fb554d09f1f14e81850de793e7a Mon Sep 17 00:00:00 2001 From: albanD Date: Tue, 22 Jul 2025 23:58:55 +0000 Subject: [PATCH 414/457] Remove top limit for cpython version and fix lint appropriately. (#158853) As per title. Sorry for the churn in the main commit. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158853 Approved by: https://github.com/seemethere, https://github.com/Skylion007, https://github.com/jingsh, https://github.com/malfet, https://github.com/ZainRizvi --- pyproject.toml | 2 +- tools/linter/adapters/pyproject_linter.py | 106 +++++++++++----------- 2 files changed, 54 insertions(+), 54 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b41ae87621f0f..523fed351b5cc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ build-backend = "setuptools.build_meta" name = "torch" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" readme = "README.md" -requires-python = ">=3.9,<3.14" +requires-python = ">=3.9" # TODO: change to `license = "BSD-3-Clause"` and enable PEP 639 after pinning setuptools>=77 # FIXME: As of 2025.06.20, it is hard to ensure the minimum version of setuptools in our CI environment. # TOML-table-based license deprecated in setuptools>=77, and the deprecation warning will be changed diff --git a/tools/linter/adapters/pyproject_linter.py b/tools/linter/adapters/pyproject_linter.py index e3a61cc634238..5e046509319f2 100644 --- a/tools/linter/adapters/pyproject_linter.py +++ b/tools/linter/adapters/pyproject_linter.py @@ -128,16 +128,16 @@ def check_file(filename: str) -> list[LintMessage]: ), ) ] - if f"{python_major}.{large_minor}" in supported_python_versions: - return [ - format_error_message( - filename, - message=( - "'project.requires-python' must specify a maximum version, " - f"but found {requires_python!r}." - ), - ) - ] + # if f"{python_major}.{large_minor}" in supported_python_versions: + # return [ + # format_error_message( + # filename, + # message=( + # "'project.requires-python' must specify a maximum version, " + # f"but found {requires_python!r}." + # ), + # ) + # ] classifiers = project.get("classifiers") if not ( @@ -158,49 +158,49 @@ def check_file(filename: str) -> list[LintMessage]: ) ] - python_version_classifiers = [ - c - for c in classifiers - if ( - c.startswith("Programming Language :: Python :: ") - and not c.endswith((f":: {python_major}", f":: {python_major} :: Only")) - ) - ] - if python_version_classifiers: - python_version_classifier_set = set(python_version_classifiers) - supported_python_version_classifier_set = { - f"Programming Language :: Python :: {v}" - for v in supported_python_versions - } - if python_version_classifier_set != supported_python_version_classifier_set: - missing_classifiers = sorted( - supported_python_version_classifier_set - - python_version_classifier_set - ) - extra_classifiers = sorted( - python_version_classifier_set - - supported_python_version_classifier_set - ) - if missing_classifiers: - return [ - format_error_message( - filename, - message=( - "'project.classifiers' is missing the following classifier(s):\n" - + "\n".join(f" {c!r}" for c in missing_classifiers) - ), - ) - ] - if extra_classifiers: - return [ - format_error_message( - filename, - message=( - "'project.classifiers' contains extra classifier(s):\n" - + "\n".join(f" {c!r}" for c in extra_classifiers) - ), - ) - ] + # python_version_classifiers = [ + # c + # for c in classifiers + # if ( + # c.startswith("Programming Language :: Python :: ") + # and not c.endswith((f":: {python_major}", f":: {python_major} :: Only")) + # ) + # ] + # if python_version_classifiers: + # python_version_classifier_set = set(python_version_classifiers) + # supported_python_version_classifier_set = { + # f"Programming Language :: Python :: {v}" + # for v in supported_python_versions + # } + # if python_version_classifier_set != supported_python_version_classifier_set: + # missing_classifiers = sorted( + # supported_python_version_classifier_set + # - python_version_classifier_set + # ) + # extra_classifiers = sorted( + # python_version_classifier_set + # - supported_python_version_classifier_set + # ) + # if missing_classifiers: + # return [ + # format_error_message( + # filename, + # message=( + # "'project.classifiers' is missing the following classifier(s):\n" + # + "\n".join(f" {c!r}" for c in missing_classifiers) + # ), + # ) + # ] + # if extra_classifiers: + # return [ + # format_error_message( + # filename, + # message=( + # "'project.classifiers' contains extra classifier(s):\n" + # + "\n".join(f" {c!r}" for c in extra_classifiers) + # ), + # ) + # ] return [] From 3703dabe42493af642104945d27a1ef6c3a6cea6 Mon Sep 17 00:00:00 2001 From: "Nichols A. Romero" Date: Wed, 23 Jul 2025 00:31:53 +0000 Subject: [PATCH 415/457] [ROCm] delete un-needed workaround for tensor.item() (#158486) Deleting unused workaround per discussion here: https://github.com/pytorch/pytorch/pull/158165#discussion_r2207968880 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158486 Approved by: https://github.com/jeffdaily, https://github.com/houseroad --- aten/src/ATen/native/cuda/CUDAScalar.cu | 23 ----------------------- 1 file changed, 23 deletions(-) diff --git a/aten/src/ATen/native/cuda/CUDAScalar.cu b/aten/src/ATen/native/cuda/CUDAScalar.cu index 07ada8a0a5f72..0d34bd52f211a 100644 --- a/aten/src/ATen/native/cuda/CUDAScalar.cu +++ b/aten/src/ATen/native/cuda/CUDAScalar.cu @@ -11,25 +11,11 @@ #include -#if defined(USE_ROCM) -// TODO(lufang): Tensor.item() on AMD HIP is not synced in the Recsys models. -// This is just a short term workaround. Issue is tracked as FBA-388 on the AMD side. -namespace { - bool use_sync_mode() { - static const bool sync_mode = c10::utils::check_env("HIP_DOUBLE_SYNC_ON_LOCAL_SCALE_DENSE") == true; - return sync_mode; - } -} -#endif - namespace at::native { Scalar _local_scalar_dense_cuda(const Tensor& self) { Scalar r; TORCH_CHECK(self.numel() > 0, "_local_scalar_dense: Empty tensor not supported"); -#if defined(USE_ROCM) - if (!use_sync_mode()){ -#endif AT_DISPATCH_V2( self.scalar_type(), "_local_scalar_dense_cuda", AT_WRAP([&] { // Create pinned memory for the scalar value to avoid implicit @@ -46,15 +32,6 @@ Scalar _local_scalar_dense_cuda(const Tensor& self) { at::cuda::memcpy_and_sync((void *)value.const_data_ptr(), self.const_data_ptr(), sizeof(scalar_t), cudaMemcpyDeviceToHost, stream); r = Scalar(*value.const_data_ptr()); }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kHalf, kBool, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); -#if defined(USE_ROCM) - } else { - auto cpu_self = self.cpu(); - AT_DISPATCH_V2( - self.scalar_type(), "_local_scalar_dense_hip", AT_WRAP([&] { - r = Scalar(*cpu_self.const_data_ptr()); - }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kHalf, kBool, kBFloat16, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); - } -#endif return r; } From 39b54b78d73884c0f2daa2826f3d63c352cb5e39 Mon Sep 17 00:00:00 2001 From: Pian Pawakapan Date: Wed, 23 Jul 2025 00:34:15 +0000 Subject: [PATCH 416/457] [export] runtime asserts for while HOP subgraphs (#158467) Differential Revision: D78431075 For #158366 - Calls runtime asserts pass for HOP subgraphs (in reenter_make_fx) - For while_loop only (can be expanded), clones input tensors for subgraph tracing, so unbacked memos (item, nonzero, etc.) aren't reused Pull Request resolved: https://github.com/pytorch/pytorch/pull/158467 Approved by: https://github.com/ydwu4 --- test/export/test_export.py | 67 ++++++- test/functorch/test_control_flow.py | 190 ++++++++++---------- torch/_dynamo/variables/higher_order_ops.py | 28 ++- torch/_higher_order_ops/utils.py | 8 +- torch/_higher_order_ops/while_loop.py | 23 ++- torch/onnx/_internal/exporter/_fx_passes.py | 9 +- 6 files changed, 203 insertions(+), 122 deletions(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index 0f436d3af91f0..a28c560f8d8a1 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -1814,10 +1814,6 @@ def forward(self, x): ): export(M(), (torch.randn(2, 3),), strict=False) - @testing.expectedFailureTrainingIRToRunDecomp # Could not guard on data-dependent expression -u0 > 16 (unhinted: -u0 > 16) - @testing.expectedFailureTrainingIRToRunDecompNonStrict # Could not guard on data-dependent expression -u0 > 16 (unhinted: -u0 > 16) - @testing.expectedFailureRetraceability # Could not guard on data-dependent expression -u0 > 16 (unhinted: -u0 > 16) - @testing.expectedFailureRetraceabilityNonStrict # Could not guard on data-dependent expression -u0 > 16 (unhinted: -u0 > 16) @torch._dynamo.config.patch(capture_scalar_outputs=True) def test_while_loop_tensor_constant_idx(self): def while_loop_decomp(x, y0): @@ -7635,6 +7631,69 @@ def forward(self, inputs): ]: self.assertFalse(hasattr(tensor, attr)) + @testing.expectedFailureCppRuntime + def test_while_loop_index_assertions(self): + from torch._higher_order_ops import while_loop + + class Foo(torch.nn.Module): + def forward(self, x): + def cond_fn(idx, acc): + i = idx.item() + return i < x.size(0) + + def body_fn(idx, acc): + # this check_is_size call needs to be traced by this subgraph for the select call, + # it can't be in the cond graph, as that fires & fails right before loop termination. + i = idx.item() + torch._check_is_size(i, max=x.size(0) - 1) + return idx + 1, acc + x[i] + + acc = torch.zeros(x.size(1)) + n = torch.full((), 0, dtype=torch.int64) + _, out = while_loop(cond_fn, body_fn, [n, acc]) + return out + + x = torch.randn(8, 4) + ep = export(Foo(), (x,), strict=False) + self.assertTrue(torch.allclose(x.sum(dim=0), ep.module()(x))) + + @testing.expectedFailureCppRuntime + def test_while_loop_assert_separation(self): + from torch._higher_order_ops import while_loop + + class Bar(torch.nn.Module): + def forward(self, idx, x): + i = idx.item() + + def cond_fn(idx, x): + i = idx.item() + torch._check(i != 5) + return i <= 9 + + def body_fn(idx, x): + i = idx.item() + torch._check(i % 2 == 0) + return idx + 2, x + i + + return while_loop(cond_fn, body_fn, [idx, x + i]) + + inps = (torch.tensor([0]), torch.zeros(1)) + ep = export(Bar(), inps, strict=False) + i, out = ep.module()(*inps) + self.assertEqual(i, 10) + self.assertEqual(out.item(), 20) + + # check assertions are separate for each subgraph + with self.assertRaisesRegex( + RuntimeError, r"Runtime assertion failed for expression Ne\(u[\d]+, 5\).*" + ): + ep.graph_module.while_loop_cond_graph_0(torch.tensor([5]), torch.zeros(1)) + with self.assertRaisesRegex( + RuntimeError, + r"Runtime assertion failed for expression Eq\(PythonMod\(u[\d]+, 2\), 0\).*", + ): + ep.graph_module.while_loop_body_graph_0(torch.tensor([5]), torch.zeros(1)) + def test_constrain_decomp(self) -> None: class M(torch.nn.Module): def __init__(self) -> None: diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py index 9c65aeedd8d97..54ccd0f7fef25 100644 --- a/test/functorch/test_control_flow.py +++ b/test/functorch/test_control_flow.py @@ -5391,18 +5391,18 @@ def forward(self, L_iter_ : torch.Tensor, L_x_ : torch.Tensor, L_self_buffers_de self.assertExpectedInline( gm.cond_fn_0.code.strip(), """\ -def forward(self, l_iter_ : torch.Tensor, l_x_ : torch.Tensor, l_self_buffers_dec__cond_fn, l_self_modules_linear_parameters_bias__body_fn, l_self_modules_linear_parameters_weight__body_fn): - sub = l_iter_ - l_self_buffers_dec__cond_fn; l_iter_ = l_self_buffers_dec__cond_fn = None +def forward(self, child : torch.Tensor, child_1 : torch.Tensor, l_self_buffers_dec__cond_fn, l_self_modules_linear_parameters_bias__body_fn, l_self_modules_linear_parameters_weight__body_fn): + sub = child - l_self_buffers_dec__cond_fn; child = l_self_buffers_dec__cond_fn = None gt = sub > 0; sub = None return gt""", # noqa: B950 ) self.assertExpectedInline( gm.body_fn_0.code.strip(), """\ -def forward(self, l_iter_ : torch.Tensor, l_x_ : torch.Tensor, l_self_buffers_dec__cond_fn, l_self_modules_linear_parameters_bias__body_fn, l_self_modules_linear_parameters_weight__body_fn): - child = l_iter_ - 1; l_iter_ = None - child_1 = torch._C._nn.linear(l_x_, l_self_modules_linear_parameters_weight__body_fn, l_self_modules_linear_parameters_bias__body_fn); l_x_ = l_self_modules_linear_parameters_weight__body_fn = l_self_modules_linear_parameters_bias__body_fn = None - return (child, child_1)""", # noqa: B950 +def forward(self, child_2 : torch.Tensor, child_3 : torch.Tensor, l_self_buffers_dec__cond_fn, l_self_modules_linear_parameters_bias__body_fn, l_self_modules_linear_parameters_weight__body_fn): + child = child_2 - 1; child_2 = None + child_4 = torch._C._nn.linear(child_3, l_self_modules_linear_parameters_weight__body_fn, l_self_modules_linear_parameters_bias__body_fn); child_3 = l_self_modules_linear_parameters_weight__body_fn = l_self_modules_linear_parameters_bias__body_fn = None + return (child, child_4)""", # noqa: B950 ) else: self.assertExpectedInline( @@ -7668,12 +7668,7 @@ def forward(self, L_init_ : torch.Tensor, L_xs_ : torch.Tensor, L_add_closure_0_ self.assertEqual(compiled_out, exp_out) @skipIfTorchDynamo("Skip because we're testing export") - # TODO: we cannot turn on strict=True yet because torch._check for out_it > 0 is - # removed from the graph in dynamo and in non-strict export's graph capturing - # step, we re-run the traced graph module to get graph captured result. - # Since torch._check is removed from graph, we end up getting a data-dependent - # error when we call torch.ones(out_it * 2). - @parametrize("strict", [False]) + @parametrize("strict", [True, False]) @parametrize("dynamic", [True, False]) def test_while_loop_op_int_carry_export(self, strict, dynamic): m, args = WHILE_LOOP_TESTS["int_carry"] @@ -7716,8 +7711,9 @@ def forward(self, x): class while_loop_cond_graph_0(torch.nn.Module): def forward(self, it_1: "Sym(u0)", x_1: "f32[s77, 3]"): - sym_size_int: "Sym(s77)" = torch.ops.aten.sym_size.int(x_1, 0); x_1 = None - lt: "Sym(u0 < s77)" = it_1 < sym_size_int; it_1 = sym_size_int = None + sym_size_int_1: "Sym(s77)" = torch.ops.aten.sym_size.int(x_1, 0); x_1 = None + + lt: "Sym(u0 < s77)" = it_1 < sym_size_int_1; it_1 = sym_size_int_1 = None return lt class while_loop_body_graph_0(torch.nn.Module): @@ -7757,62 +7753,62 @@ def forward(self, s77: "Sym(s77)", s27: "Sym(s27)", L_x_: "f32[s77, s27]"): body_fn_0 = self.body_fn_0 while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (0, l_x_), (s27, s77)); cond_fn_0 = body_fn_0 = l_x_ = s27 = None - getitem_4: "Sym(u1)" = while_loop[0] + getitem_4: "Sym(u2)" = while_loop[0] - ge: "Sym(u1 >= 1)" = getitem_4 >= 1 - _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u1 >= 1 on node 'ge'"); ge = _assert_scalar_default = None + ge: "Sym(u2 >= 1)" = getitem_4 >= 1 + _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u2 >= 1 on node 'ge'"); ge = _assert_scalar_default = None - gt_1: "Sym(u1 > 0)" = getitem_4 > 0 - _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(gt_1, "Runtime assertion failed for expression 0 < u1 on node 'gt_1'"); gt_1 = _assert_scalar_default_1 = None + gt_1: "Sym(u2 > 0)" = getitem_4 > 0 + _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(gt_1, "Runtime assertion failed for expression 0 < u2 on node 'gt_1'"); gt_1 = _assert_scalar_default_1 = None out_x: "f32[s77, s27]" = while_loop[1]; while_loop = None - gt: "Sym(u1 > 0)" = getitem_4 > 0 + gt: "Sym(u2 > 0)" = getitem_4 > 0 _check = torch._check(gt); gt = _check = None - add: "Sym(u1 + 1)" = getitem_4 + 1 + add: "Sym(u2 + 1)" = getitem_4 + 1 add_1: "f32[s77, s27]" = getitem_4 + out_x; out_x = None - lt: "Sym(u1 < s77)" = getitem_4 < s77; s77 = None + lt: "Sym(u2 < s77)" = getitem_4 < s77; s77 = None - mul: "Sym(2*u1)" = getitem_4 * 2; getitem_4 = None - ones: "f32[2*u1]" = torch.ones(mul); mul = None + mul: "Sym(2*u2)" = getitem_4 * 2; getitem_4 = None + ones: "f32[2*u2]" = torch.ones(mul); mul = None return (add, add_1, lt, ones) class cond_fn_0(torch.nn.Module): - def forward(self, unbacked_symint: "Sym(u0)", l_x_: "f32[s77, s27]", s27: "Sym(s27)", s77: "Sym(s77)"): + def forward(self, unbacked_symint: "Sym(u0)", child: "f32[s77, s27]", s27: "Sym(s27)", s77: "Sym(s77)"): s27_1 = s27 s77_1 = s77 - size = l_x_.size(); l_x_ = None + size = child.size(); child = None getitem: "Sym(s77)" = size[0] getitem_1: "Sym(s27)" = size[1]; size = getitem_1 = None lt: "Sym(u0 < s77)" = unbacked_symint < getitem; unbacked_symint = getitem = None return lt class body_fn_0(torch.nn.Module): - def forward(self, unbacked_symint: "Sym(u0)", l_x_: "f32[s77, s27]", s27: "Sym(s27)", s77: "Sym(s77)"): + def forward(self, unbacked_symint_0: "Sym(u1)", child_1: "f32[s77, s27]", s27: "Sym(s27)", s77: "Sym(s77)"): s27_1 = s27 s77_1 = s77 - x_clone: "f32[s77, s27]" = l_x_.clone() + x_clone: "f32[s77, s27]" = child_1.clone() - ge: "Sym(u0 >= 0)" = unbacked_symint >= 0 + ge: "Sym(u1 >= 0)" = unbacked_symint_0 >= 0 _check = torch._check(ge); ge = _check = None - size = l_x_.size(); l_x_ = None + size = child_1.size(); child_1 = None getitem: "Sym(s77)" = size[0] getitem_1: "Sym(s27)" = size[1]; size = getitem_1 = None - lt: "Sym(u0 < s77)" = unbacked_symint < getitem; getitem = None + lt: "Sym(u1 < s77)" = unbacked_symint_0 < getitem; getitem = None _check_1 = torch._check(lt); lt = _check_1 = None - select: "f32[s27]" = x_clone.select(0, unbacked_symint) - select_1: "f32[s27]" = x_clone.select(0, unbacked_symint) - add: "f32[s27]" = select_1 + unbacked_symint; select_1 = None + select: "f32[s27]" = x_clone.select(0, unbacked_symint_0) + select_1: "f32[s27]" = x_clone.select(0, unbacked_symint_0) + add: "f32[s27]" = select_1 + unbacked_symint_0; select_1 = None copy_: "f32[s27]" = select.copy_(add); select = add = copy_ = None - add_1: "Sym(u0 + 1)" = unbacked_symint + 1; unbacked_symint = None + add_1: "Sym(u1 + 1)" = unbacked_symint_0 + 1; unbacked_symint_0 = None return (add_1, x_clone) """, # noqa: B950 ) @@ -7916,30 +7912,30 @@ def forward(self, L_t_: "f32[2, 3]"): sum_1: "f32[]" = l_t_.sum() to: "i64[]" = sum_1.to(torch.int64); sum_1 = None item: "Sym(u0)" = to.item(); to = None - child: "f32[2, 3]" = l_t_.sin() + sin: "f32[2, 3]" = l_t_.sin() cond_fn_0 = self.cond_fn_0 body_fn_0 = self.body_fn_0 - while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (2, 3, 1, 1, 1, 3, item, child), ()); cond_fn_0 = body_fn_0 = item = child = None - - getitem_8: "Sym(u8)" = while_loop[0] - getitem_9: "Sym(u9)" = while_loop[1] - getitem_10: "Sym(u10)" = while_loop[2] - getitem_11: "Sym(u11)" = while_loop[3] - getitem_12: "Sym(u12)" = while_loop[4] - getitem_13: "Sym(u13)" = while_loop[5] - getitem_14: "Sym(u14)" = while_loop[6] - - child_1: "f32[2, 3]" = while_loop[7]; while_loop = None - - add: "Sym(u8 + 1)" = getitem_8 + 1 - add_1: "Sym(u9 + 1)" = getitem_9 + 1 - add_2: "Sym(u10 + 1)" = getitem_10 + 1 - add_3: "Sym(u11 + 1)" = getitem_11 + 1 - add_4: "Sym(u12 + 1)" = getitem_12 + 1 - add_5: "Sym(u13 + 1)" = getitem_13 + 1 - add_6: "Sym(u14 + 1)" = getitem_14 + 1 - add_7: "f32[2, 3]" = child_1 + 1 + while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (2, 3, 1, 1, 1, 3, item, sin), ()); cond_fn_0 = body_fn_0 = item = sin = None + + getitem_8: "Sym(u15)" = while_loop[0] + getitem_9: "Sym(u16)" = while_loop[1] + getitem_10: "Sym(u17)" = while_loop[2] + getitem_11: "Sym(u18)" = while_loop[3] + getitem_12: "Sym(u19)" = while_loop[4] + getitem_13: "Sym(u20)" = while_loop[5] + getitem_14: "Sym(u21)" = while_loop[6] + + child: "f32[2, 3]" = while_loop[7]; while_loop = None + + add: "Sym(u15 + 1)" = getitem_8 + 1 + add_1: "Sym(u16 + 1)" = getitem_9 + 1 + add_2: "Sym(u17 + 1)" = getitem_10 + 1 + add_3: "Sym(u18 + 1)" = getitem_11 + 1 + add_4: "Sym(u19 + 1)" = getitem_12 + 1 + add_5: "Sym(u20 + 1)" = getitem_13 + 1 + add_6: "Sym(u21 + 1)" = getitem_14 + 1 + add_7: "f32[2, 3]" = child + 1 add_8: "f32[2, 3]" = getitem_8 + l_t_; getitem_8 = None add_9: "f32[2, 3]" = getitem_9 + l_t_; getitem_9 = None @@ -7948,7 +7944,7 @@ def forward(self, L_t_: "f32[2, 3]"): add_12: "f32[2, 3]" = getitem_12 + l_t_; getitem_12 = None add_13: "f32[2, 3]" = getitem_13 + l_t_; getitem_13 = None add_14: "f32[2, 3]" = getitem_14 + l_t_; getitem_14 = None - add_15: "f32[2, 3]" = child_1 + l_t_; child_1 = l_t_ = None + add_15: "f32[2, 3]" = child + l_t_; child = l_t_ = None return (add, add_1, add_2, add_3, add_4, add_5, add_6, add_7, add_8, add_9, add_10, add_11, add_12, add_13, add_14, add_15) class cond_fn_0(torch.nn.Module): @@ -7960,10 +7956,10 @@ def forward(self, unbacked_symint: "Sym(u1)", unbacked_symint_0: "Sym(u2)", unba return lt class body_fn_0(torch.nn.Module): - def forward(self, unbacked_symint: "Sym(u1)", unbacked_symint_0: "Sym(u2)", unbacked_symint_1: "Sym(u3)", unbacked_symint_2: "Sym(u4)", unbacked_symint_3: "Sym(u5)", unbacked_symint_4: "Sym(u6)", unbacked_symint_5: "Sym(u7)", child: "f32[2, 3]"): - add: "Sym(u7 + 1)" = unbacked_symint_5 + 1; unbacked_symint_5 = None - child_1: "f32[2, 3]" = child + 1; child = None - return (unbacked_symint_0, unbacked_symint_1, unbacked_symint_2, unbacked_symint_3, unbacked_symint, 0, add, child_1) + def forward(self, unbacked_symint_6: "Sym(u8)", unbacked_symint_7: "Sym(u9)", unbacked_symint_8: "Sym(u10)", unbacked_symint_9: "Sym(u11)", unbacked_symint_10: "Sym(u12)", unbacked_symint_11: "Sym(u13)", unbacked_symint_12: "Sym(u14)", child_1: "f32[2, 3]"): + add: "Sym(u14 + 1)" = unbacked_symint_12 + 1; unbacked_symint_12 = None + child: "f32[2, 3]" = child_1 + 1; child_1 = None + return (unbacked_symint_7, unbacked_symint_8, unbacked_symint_9, unbacked_symint_10, unbacked_symint_6, 0, add, child) """, # noqa: B950 ) @@ -7991,17 +7987,17 @@ def forward(self, x): while_loop_body_graph_0 = self.while_loop_body_graph_0 while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (sym_size_int_1, 3, 2, 2, 3, sin), ()); while_loop_cond_graph_0 = while_loop_body_graph_0 = sym_size_int_1 = sin = None - getitem_6: "Sym(u5)" = while_loop[0] - getitem_7: "Sym(u6)" = while_loop[1] - getitem_8: "Sym(u7)" = while_loop[2] - getitem_9: "Sym(u8)" = while_loop[3] - getitem_10: "Sym(u9)" = while_loop[4] + getitem_6: "Sym(u10)" = while_loop[0] + getitem_7: "Sym(u11)" = while_loop[1] + getitem_8: "Sym(u12)" = while_loop[2] + getitem_9: "Sym(u13)" = while_loop[3] + getitem_10: "Sym(u14)" = while_loop[4] getitem_5: "f32[s77, 3]" = while_loop[5]; while_loop = None - add: "Sym(u7 + 1)" = getitem_8 + 1 - add_1: "Sym(u8 + 1)" = getitem_9 + 1 - add_2: "Sym(u9 + 1)" = getitem_10 + 1 + add: "Sym(u12 + 1)" = getitem_8 + 1 + add_1: "Sym(u13 + 1)" = getitem_9 + 1 + add_2: "Sym(u14 + 1)" = getitem_10 + 1 add_3: "f32[s77, 3]" = torch.ops.aten.add.Tensor(getitem_5, getitem_8); getitem_8 = None add_4: "f32[s77, 3]" = torch.ops.aten.add.Tensor(getitem_5, getitem_9); getitem_9 = None @@ -8009,21 +8005,21 @@ def forward(self, x): return pytree.tree_unflatten((getitem_6, getitem_7, add, add_1, add_2, add_3, add_4, add_5, getitem_5), self._out_spec) class while_loop_cond_graph_0(torch.nn.Module): - def forward(self, arg0_1: "Sym(u15)", arg1_1: "Sym(u16)", arg2_1: "Sym(u17)", arg3_1: "Sym(u18)", arg4_1: "Sym(u19)", arg5_1: "f32[s77, 3]"): - mul: "Sym(u17*u18)" = arg2_1 * arg3_1; arg2_1 = arg3_1 = None - mul_1: "Sym(u17*u18*u19)" = mul * arg4_1; mul = arg4_1 = None - mul_2: "Sym(u15*u16)" = arg0_1 * arg1_1; arg0_1 = arg1_1 = None - lt: "Sym(u17*u18*u19 < u15*u16)" = mul_1 < mul_2; mul_1 = mul_2 = None + def forward(self, arg0_1: "Sym(u20)", arg1_1: "Sym(u21)", arg2_1: "Sym(u22)", arg3_1: "Sym(u23)", arg4_1: "Sym(u24)", arg5_1: "f32[s77, 3]"): + mul: "Sym(u22*u23)" = arg2_1 * arg3_1; arg2_1 = arg3_1 = None + mul_1: "Sym(u22*u23*u24)" = mul * arg4_1; mul = arg4_1 = None + mul_2: "Sym(u20*u21)" = arg0_1 * arg1_1; arg0_1 = arg1_1 = None + lt: "Sym(u22*u23*u24 < u20*u21)" = mul_1 < mul_2; mul_1 = mul_2 = None return lt class while_loop_body_graph_0(torch.nn.Module): - def forward(self, arg0_1: "Sym(u15)", arg1_1: "Sym(u16)", arg2_1: "Sym(u17)", arg3_1: "Sym(u18)", arg4_1: "Sym(u19)", arg5_1: "f32[s77, 3]"): - add: "Sym(u15 + 1)" = arg0_1 + 1; arg0_1 = None - add_1: "Sym(u16 + 1)" = arg1_1 + 1; arg1_1 = None + def forward(self, arg0_1: "Sym(u20)", arg1_1: "Sym(u21)", arg2_1: "Sym(u22)", arg3_1: "Sym(u23)", arg4_1: "Sym(u24)", arg5_1: "f32[s77, 3]"): + add: "Sym(u20 + 1)" = arg0_1 + 1; arg0_1 = None + add_1: "Sym(u21 + 1)" = arg1_1 + 1; arg1_1 = None - add_2: "Sym(u17 + 1)" = arg2_1 + 1; arg2_1 = None - add_3: "Sym(u18 + 1)" = arg3_1 + 1; arg3_1 = None - add_4: "Sym(u19 + 1)" = arg4_1 + 1; arg4_1 = None + add_2: "Sym(u22 + 1)" = arg2_1 + 1; arg2_1 = None + add_3: "Sym(u23 + 1)" = arg3_1 + 1; arg3_1 = None + add_4: "Sym(u24 + 1)" = arg4_1 + 1; arg4_1 = None add_5: "f32[s77, 3]" = torch.ops.aten.add.Tensor(arg5_1, 1); arg5_1 = None return (add, add_1, add_2, add_3, add_4, add_5) @@ -8057,17 +8053,17 @@ def forward(self, s77: "Sym(s77)", s27: "Sym(s27)", L_x_: "f32[s77, s27]"): body_fn_0 = self.body_fn_0 while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (s77, s27, 2, 2, 3, child), (s27, s77)); cond_fn_0 = body_fn_0 = s77 = s27 = child = None - getitem_10: "Sym(u5)" = while_loop[0] - getitem_11: "Sym(u6)" = while_loop[1] - getitem_12: "Sym(u7)" = while_loop[2] - getitem_13: "Sym(u8)" = while_loop[3] - getitem_14: "Sym(u9)" = while_loop[4] + getitem_10: "Sym(u10)" = while_loop[0] + getitem_11: "Sym(u11)" = while_loop[1] + getitem_12: "Sym(u12)" = while_loop[2] + getitem_13: "Sym(u13)" = while_loop[3] + getitem_14: "Sym(u14)" = while_loop[4] out_x: "f32[s77, s27]" = while_loop[5]; while_loop = None - add: "Sym(u7 + 1)" = getitem_12 + 1 - add_1: "Sym(u8 + 1)" = getitem_13 + 1 - add_2: "Sym(u9 + 1)" = getitem_14 + 1 + add: "Sym(u12 + 1)" = getitem_12 + 1 + add_1: "Sym(u13 + 1)" = getitem_13 + 1 + add_2: "Sym(u14 + 1)" = getitem_14 + 1 add_3: "f32[s77, s27]" = getitem_12 + out_x; getitem_12 = None add_4: "f32[s77, s27]" = getitem_13 + out_x; getitem_13 = None @@ -8075,7 +8071,7 @@ def forward(self, s77: "Sym(s77)", s27: "Sym(s27)", L_x_: "f32[s77, s27]"): return (getitem_10, getitem_11, add, add_1, add_2, add_3, add_4, add_5, out_x) class cond_fn_0(torch.nn.Module): - def forward(self, unbacked_symint: "Sym(u0)", unbacked_symint_0: "Sym(u1)", unbacked_symint_1: "Sym(u2)", unbacked_symint_2: "Sym(u3)", unbacked_symint_3: "Sym(u4)", child: "f32[s77, s27]", s27: "Sym(s27)", s77: "Sym(s77)"): + def forward(self, unbacked_symint: "Sym(u0)", unbacked_symint_0: "Sym(u1)", unbacked_symint_1: "Sym(u2)", unbacked_symint_2: "Sym(u3)", unbacked_symint_3: "Sym(u4)", child_1: "f32[s77, s27]", s27: "Sym(s27)", s77: "Sym(s77)"): s27_1 = s27 s77_1 = s77 @@ -8086,19 +8082,19 @@ def forward(self, unbacked_symint: "Sym(u0)", unbacked_symint_0: "Sym(u1)", unba return lt class body_fn_0(torch.nn.Module): - def forward(self, unbacked_symint: "Sym(u0)", unbacked_symint_0: "Sym(u1)", unbacked_symint_1: "Sym(u2)", unbacked_symint_2: "Sym(u3)", unbacked_symint_3: "Sym(u4)", child: "f32[s77, s27]", s27: "Sym(s27)", s77: "Sym(s77)"): + def forward(self, unbacked_symint_4: "Sym(u5)", unbacked_symint_5: "Sym(u6)", unbacked_symint_6: "Sym(u7)", unbacked_symint_7: "Sym(u8)", unbacked_symint_8: "Sym(u9)", child_2: "f32[s77, s27]", s27: "Sym(s27)", s77: "Sym(s77)"): s27_1 = s27 s77_1 = s77 - add: "Sym(u0 + 1)" = unbacked_symint + 1; unbacked_symint = None - add_1: "Sym(u1 + 1)" = unbacked_symint_0 + 1; unbacked_symint_0 = None + add: "Sym(u5 + 1)" = unbacked_symint_4 + 1; unbacked_symint_4 = None + add_1: "Sym(u6 + 1)" = unbacked_symint_5 + 1; unbacked_symint_5 = None - add_2: "Sym(u2 + 1)" = unbacked_symint_1 + 1; unbacked_symint_1 = None - add_3: "Sym(u3 + 1)" = unbacked_symint_2 + 1; unbacked_symint_2 = None - add_4: "Sym(u4 + 1)" = unbacked_symint_3 + 1; unbacked_symint_3 = None + add_2: "Sym(u7 + 1)" = unbacked_symint_6 + 1; unbacked_symint_6 = None + add_3: "Sym(u8 + 1)" = unbacked_symint_7 + 1; unbacked_symint_7 = None + add_4: "Sym(u9 + 1)" = unbacked_symint_8 + 1; unbacked_symint_8 = None - child_1: "f32[s77, s27]" = child + 1; child = None - return (add, add_1, add_2, add_3, add_4, child_1) + child: "f32[s77, s27]" = child_2 + 1; child_2 = None + return (add, add_1, add_2, add_3, add_4, child) """, # noqa: B950 ) diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py index fbef41574a4f7..b874cfaadbc46 100644 --- a/torch/_dynamo/variables/higher_order_ops.py +++ b/torch/_dynamo/variables/higher_order_ops.py @@ -1241,8 +1241,28 @@ def unspecialize_carried_inputs(tx, carry) -> VariableTracker: cloned_carry.proxy.node.meta["example_value"].constant = None return cloned_carry - new_operands_seq = [ - unspecialize_carried_inputs(tx, carry) for carry in operands_seq + # clone inputs across subgraphs, to avoid unbacked memoization in fake prop + cond_operands_seq = [ + unspecialize_carried_inputs( + tx, + ( + carry.call_method(tx, "clone", args=(), kwargs={}) + if isinstance(carry, TensorVariable) + else carry + ), + ) + for carry in operands_seq + ] + body_operands_seq = [ + unspecialize_carried_inputs( + tx, + ( + carry.call_method(tx, "clone", args=(), kwargs={}) + if isinstance(carry, TensorVariable) + else carry + ), + ) + for carry in operands_seq ] # create cond subgrpahs @@ -1253,7 +1273,7 @@ def unspecialize_carried_inputs(tx, carry) -> VariableTracker: ) = speculate_subgraph( tx, cond_fn, - new_operands_seq + additional_inputs_seq, + cond_operands_seq + additional_inputs_seq, {}, "while_loop", source_target=self.value, @@ -1318,7 +1338,7 @@ def unspecialize_carried_inputs(tx, carry) -> VariableTracker: ) = speculate_subgraph( tx, body_fn, - new_operands_seq + additional_inputs_seq, + body_operands_seq + additional_inputs_seq, {}, "while_loop", source_target=self.value, diff --git a/torch/_higher_order_ops/utils.py b/torch/_higher_order_ops/utils.py index 580d66551dd42..e00036a8c14ee 100644 --- a/torch/_higher_order_ops/utils.py +++ b/torch/_higher_order_ops/utils.py @@ -111,16 +111,22 @@ def _maybe_compile_and_run_fn(fn, *args): def reenter_make_fx(fn): + from torch._guards import detect_fake_mode from torch.fx.experimental.proxy_tensor import _CURRENT_MAKE_FX_TRACER + from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts @functools.wraps(fn) def wrapped(*args): assert _CURRENT_MAKE_FX_TRACER is not None, ( "Cannot reenter make_fx when we're not under a make_fx tracing session" ) - return _CURRENT_MAKE_FX_TRACER.trace_subgraph( + gm = _CURRENT_MAKE_FX_TRACER.trace_subgraph( _maybe_run_with_interpreter(fn), *args ) + if (fake_mode := detect_fake_mode()) and fake_mode.shape_env is not None: + insert_deferred_runtime_asserts(gm, fake_mode.shape_env, "reenter_make_fx") + gm.recompile() + return gm return wrapped diff --git a/torch/_higher_order_ops/while_loop.py b/torch/_higher_order_ops/while_loop.py index 68a8747ab4b82..16f4606256166 100644 --- a/torch/_higher_order_ops/while_loop.py +++ b/torch/_higher_order_ops/while_loop.py @@ -302,13 +302,10 @@ def _unspecialize_carried_inputs(x): # - The traced code would use the wrong constant value for all iterations # Solution: We clone the constant tensors and mark the cloned tensor as non-constant so they won't # be specialized to fixed values during tracing body_fn or cond_fn. - elif ( - isinstance(x, torch.Tensor) - and hasattr(x, "constant") - and x.constant is not None - ): + elif isinstance(x, torch.Tensor): x = x.clone() - x.constant = None + if hasattr(x, "constant") and x.constant is not None: + x.constant = None return x with disable_proxy_modes_tracing(): @@ -319,12 +316,14 @@ def _unspecialize_carried_inputs(x): carried_inputs, ) - cond_graph = reenter_make_fx(cond_fn)( - *unspecialized_carried_inputs, *additional_inputs - ) - body_graph = reenter_make_fx(body_fn)( - *unspecialized_carried_inputs, *additional_inputs - ) + def produce_graph(fn): + cloned_carried_inputs = pytree.tree_map_only( + torch.Tensor, lambda x: x.clone(), unspecialized_carried_inputs + ) + return reenter_make_fx(fn)(*cloned_carried_inputs, *additional_inputs) + + cond_graph = produce_graph(cond_fn) + body_graph = produce_graph(body_fn) next_name = None i = 0 diff --git a/torch/onnx/_internal/exporter/_fx_passes.py b/torch/onnx/_internal/exporter/_fx_passes.py index a14b25d7cda1e..98359f2ebaff1 100644 --- a/torch/onnx/_internal/exporter/_fx_passes.py +++ b/torch/onnx/_internal/exporter/_fx_passes.py @@ -37,8 +37,9 @@ def remove_assertion_nodes(graph_module: torch.fx.GraphModule) -> torch.fx.Graph torch.ops.aten._assert_scalar.default, torch.ops.aten._assert_tensor_metadata.default, } - for node in graph_module.graph.nodes: - if node.op == "call_function" and node.target in aten_assertion_targets: - graph_module.graph.erase_node(node) - graph_module.recompile() + for gm in graph_module.modules(): + for node in gm.graph.nodes: # type: ignore[union-attr] + if node.op == "call_function" and node.target in aten_assertion_targets: + gm.graph.erase_node(node) # type: ignore[operator, union-attr] + gm.recompile() # type: ignore[operator] return graph_module From 56d07d0bde5507cbd0b1298a372ba2e0d7b969d2 Mon Sep 17 00:00:00 2001 From: rzou Date: Thu, 17 Jul 2025 18:45:37 -0700 Subject: [PATCH 417/457] Add merge_rules category for Dynamo; add guilhermeleobas (#158620) Adds guilhermeleobas to merge_rules for Dynamo and functorch. Guilherme has done good work on both of these subsystems and I am tired of him approving my PRs and me not being able to merge them. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158620 Approved by: https://github.com/anijain2305 --- .github/merge_rules.yaml | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/.github/merge_rules.yaml b/.github/merge_rules.yaml index f87980ed8df33..17b5f49d9ed73 100644 --- a/.github/merge_rules.yaml +++ b/.github/merge_rules.yaml @@ -492,6 +492,19 @@ - srossross - chillee - zou3519 + - guilhermeleobas + mandatory_checks_name: + - EasyCLA + - Lint + - pull + +- name: Dynamo + patterns: + - torch/_dynamo/** + - torch/csrc/dynamo/** + - test/dynamo/** + approved_by: + - guilhermeleobas mandatory_checks_name: - EasyCLA - Lint From 096dc35d77643cdf33f30250a3572c446f7d714d Mon Sep 17 00:00:00 2001 From: angelayi Date: Tue, 22 Jul 2025 13:38:32 -0700 Subject: [PATCH 418/457] [aoti][mps] Fix update constants buffer (#158349) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158349 Approved by: https://github.com/malfet --- test/inductor/test_aot_inductor.py | 68 ++++++++++++++----- .../inductor/aoti_runtime/model_container.h | 27 ++++++-- torch/csrc/inductor/aoti_torch/c/shim_mps.h | 7 ++ torch/csrc/inductor/aoti_torch/shim_mps.mm | 15 ++++ 4 files changed, 93 insertions(+), 24 deletions(-) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index f71c27d92cf82..aa678bea15feb 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -5693,6 +5693,53 @@ def runner_call(*args, **kwargs): ) self.assertEqual(new_expected, new_output) + def test_update_constant_buffer_simple(self): + class Model(torch.nn.Module): + def __init__(self, device): + super().__init__() + self.weight = torch.randn((3, 3), device=device) + + def forward(self, a): + return a + self.weight + + model = Model(self.device) + a = torch.randn((3, 3), device=self.device) + example_inputs = (a,) + + with torch.no_grad(), config.patch({"always_keep_tensor_constants": True}): + so_path = AOTIRunnerUtil.legacy_compile( + model=model, + example_inputs=example_inputs, + ) + + runner = AOTIRunnerUtil.legacy_load_runner(self.device, so_path) + + # Let's check whether the model has correct constant name mapping. + expected_original_fqns = { + "L__self___weight": "L__self___weight", + } + self.assertEqual( + expected_original_fqns, runner.get_constant_names_to_original_fqns() + ) + + test_inputs = torch.randn((3, 3), device=self.device) + new_weight = torch.randn((3, 3), device=self.device) + model.weight = new_weight + attach_weights = {"L__self___weight": new_weight} + runner.update_constant_buffer(attach_weights, False, False, False) + expected = model(test_inputs) + + def runner_call(*args, **kwargs): + call_spec = runner.get_call_spec() # type: ignore[attr-defined] + out_spec = pytree.treespec_loads(call_spec[1]) + flat_inputs = pytree.tree_flatten((args, kwargs))[0] + flat_inputs = [x for x in flat_inputs if isinstance(x, torch.Tensor)] + flat_outputs = runner.run(flat_inputs) # type: ignore[attr-defined] + return pytree.tree_unflatten(flat_outputs, out_spec) + + output = runner_call(test_inputs) + self.assertEqual(expected, output) + def test_update_inactive_constant_buffer(self): class Model(torch.nn.Module): def __init__(self, n, k, device): @@ -6747,6 +6794,8 @@ def fail_gpu(suffixes: tuple[str, ...], is_skip=False): # MPS doesn't support float8 "test_fp8": fail_mps(), "test_fp8_view_of_param": fail_mps(), + # unsupported operator: aten._scaled_dot_product_attention_math_for_mps.default + "test_issue_140766": fail_mps(), # Compilation Error "test_fallback_kernel_with_symexpr_output": fail_mps(), "test_while_loop_with_mixed_device": fail_mps(), @@ -6770,29 +6819,12 @@ def fail_gpu(suffixes: tuple[str, ...], is_skip=False): "test_cond_non_tensor_predicates_dynamic_True": fail_mps(), "test_zero_grid_with_unbacked_symbols": fail_mps(), "test_reuse_kernel_dynamic": fail_mps(is_skip=True), - "test_while_loop_with_parameters": fail_mps(is_skip=True), "test_cond_with_parameters": fail_mps(is_skip=True), "test_cond_share_predicte": fail_mps(is_skip=True), - # SetStorage incorrect - "test_small_constant": fail_mps(is_skip=True), - "test_free_inactive_buffer": fail_mps(is_skip=True), - "test_extract_constants_map": fail_mps(is_skip=True), - "test_linear_freezing": fail_mps(is_skip=True), - "test_model_modified_weights": fail_mps(is_skip=True), # Error device may not be nil "test_zero_size_weight": fail_mps(is_skip=True), - # Constants update (segfault) - "test_update_inactive_constant_buffer": fail_mps(is_skip=True), - "test_update_constant_buffer": fail_mps(is_skip=True), - "test_so_without_weight": fail_mps(is_skip=True), - "test_constant_folding_with_update": fail_mps(is_skip=True), - "test_nested_tensor_from_jagged": fail_mps(is_skip=True), - "test_issue_140766": fail_mps(is_skip=True), - "test_buffer_mutation_and_force_mmap_weights": fail_mps(is_skip=True), + # RuntimeError: Cannot compare two tensors on different devices. Got: cpu and mps:0 "test_aoti_constant_tensor_name_collision": fail_mps(is_skip=True), - "test_large_mmaped_weights": fail_mps(is_skip=True), - "test_subclasses": fail_mps(is_skip=True), - "test_autotune_with_constant_folding": fail_mps(is_skip=True), # MPS doesn't support triton "test_autotuning_args_reuse": fail_mps(), "test_triton_autotuning": fail_mps(), diff --git a/torch/csrc/inductor/aoti_runtime/model_container.h b/torch/csrc/inductor/aoti_runtime/model_container.h index 416c186a3ae06..0bd12e841e39f 100644 --- a/torch/csrc/inductor/aoti_runtime/model_container.h +++ b/torch/csrc/inductor/aoti_runtime/model_container.h @@ -467,14 +467,34 @@ class AOTInductorModelContainer { constants_blob_ptr + constants_internal_offset_[idx]; void* user_constant_ptr; int64_t constant_size; + int64_t* stride; + int64_t offset; aoti_torch_get_data_ptr(tensor, &user_constant_ptr); aoti_torch_get_storage_size(tensor, &constant_size); + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides(tensor, &stride)); + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_get_storage_offset(tensor, &offset)); + auto dtype = models_[0]->constant_dtype(idx); + #ifdef USE_XPU sycl::queue* queue_ptr = nullptr; aoti_torch_get_current_sycl_queue((void**)&queue_ptr); queue_ptr ->memcpy(internal_constants_ptr, user_constant_ptr, constant_size) .wait(); +#elif USE_MPS + internal_constants_ptr = constants_blob_ptr; + aoti_torch_mps_copy_buffer( + user_constant_ptr, + constants_blob_ptr, + constant_size, + offset, + constants_internal_offset_[idx]); + // For mps tensors, all constants are stored in one buffer, with the + // offset being where the constant starts. So we want to change the + // constant tensor's offset to point to constants_internal_offset_[idx] + offset = constants_internal_offset_[idx] / + aoti_torch_dtype_element_size(dtype); #elif USE_CUDA AOTI_RUNTIME_CUDA_CHECK(cudaMemcpy( internal_constants_ptr, @@ -488,20 +508,15 @@ class AOTInductorModelContainer { // We extract stride and offset from provided Tensor since we do not // guarantee that the tensor is contiguous. AtenTensorHandle tensor_handle; - int64_t* stride; - int64_t offset; int device_type = models_[0]->get_device_type(); int device_idx = models_[0]->get_device_idx(); - AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides(tensor, &stride)); - AOTI_TORCH_ERROR_CODE_CHECK( - aoti_torch_get_storage_offset(tensor, &offset)); AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_create_tensor_from_blob( internal_constants_ptr, models_[0]->constant_ndim(idx), models_[0]->constant_shape(idx), stride, offset, - models_[0]->constant_dtype(idx), + dtype, device_type, device_idx, &tensor_handle)); diff --git a/torch/csrc/inductor/aoti_torch/c/shim_mps.h b/torch/csrc/inductor/aoti_torch/c/shim_mps.h index bd86885de13ca..08f1569927f00 100644 --- a/torch/csrc/inductor/aoti_torch/c/shim_mps.h +++ b/torch/csrc/inductor/aoti_torch/c/shim_mps.h @@ -32,6 +32,13 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_memcpy( size_t data_size, uint8_t* constants_start); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_copy_buffer( + void* src_buffer, + void* dst_buffer, + size_t data_size, + size_t src_offset, + size_t dst_offset); + #ifdef __cplusplus } // extern "C" #endif diff --git a/torch/csrc/inductor/aoti_torch/shim_mps.mm b/torch/csrc/inductor/aoti_torch/shim_mps.mm index 9f70331ffc0b9..1bf88839ecfe0 100644 --- a/torch/csrc/inductor/aoti_torch/shim_mps.mm +++ b/torch/csrc/inductor/aoti_torch/shim_mps.mm @@ -3,6 +3,8 @@ #include #include #include +#include +#include using namespace torch::aot_inductor; @@ -40,3 +42,16 @@ AOTITorchError aoti_torch_mps_free( memcpy(buffer_pointer + constant_offset, constants_start + bytes_read, data_size); }); } + +AOTITorchError +aoti_torch_mps_copy_buffer(void* src_buffer, void* dst_buffer, size_t data_size, size_t src_offset, size_t dst_offset) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + + auto src_mtl_buffer = (id)src_buffer; + auto dst_mtl_buffer = (id)dst_buffer; + + auto* stream = at::mps::getCurrentMPSStream(); + uint64_t profile_id = at::mps::getMPSProfiler().beginProfileCopy(src_mtl_buffer, dst_mtl_buffer, at::OptionalTensorRef(), at::OptionalTensorRef(), data_size, true); + stream->copy_and_sync(src_mtl_buffer, dst_mtl_buffer, data_size, src_offset, dst_offset, true, profile_id); + }); +} From 84058d1179aafe1bd3a25dd6ebde0876841189e4 Mon Sep 17 00:00:00 2001 From: angelayi Date: Tue, 22 Jul 2025 13:38:33 -0700 Subject: [PATCH 419/457] [aoti][mps] Fix cpu kernel generation (#158350) In the case where we have both mps and cpu code which can be inductor compiled, we need to case on the device -- this requires the device field to be correctly passed. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158350 Approved by: https://github.com/malfet ghstack dependencies: #158349 --- test/inductor/test_aot_inductor.py | 5 --- torch/_inductor/codegen/cpp_wrapper_mps.py | 37 +++++++++++++++++++--- torch/_inductor/codegen/mps.py | 2 +- torch/_inductor/codegen/wrapper.py | 4 ++- 4 files changed, 37 insertions(+), 11 deletions(-) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index aa678bea15feb..271b4f99bbfb7 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -6777,8 +6777,6 @@ def fail_gpu(suffixes: tuple[str, ...], is_skip=False): } MPS_TEST_FAILURES = { - # Expected supportedFloatingType(scalar_type) || scalar_type == kInt || scalar_type == kBool - "test_index_put_fallback": fail_mps(), # aten::_embedding_bag is not currently implemented for the MPS device. "test_embedding_bag": fail_mps(), # aten::_embedding_bag is not currently implemented for the MPS device. @@ -6800,12 +6798,9 @@ def fail_gpu(suffixes: tuple[str, ...], is_skip=False): "test_fallback_kernel_with_symexpr_output": fail_mps(), "test_while_loop_with_mixed_device": fail_mps(), "test_while_loop_nested": fail_mps(), - "test_assert_async": fail_mps(), "test_index_put_with_none_index": fail_mps(), "test_size_from_multi_ouptut": fail_mps(), "test_simple_embed_kernel_binary_False": fail_mps(), - "test_while_loop_with_mixed_device_dynamic_False": fail_mps(), - "test_while_loop_with_mixed_device_dynamic_True": fail_mps(), "test_simple_embed_cubin_False": fail_mps(is_skip=True), "test_simple_embed_cubin_True": fail_mps(is_skip=True), "test_simple_embed_kernel_binary_True": fail_mps(), diff --git a/torch/_inductor/codegen/cpp_wrapper_mps.py b/torch/_inductor/codegen/cpp_wrapper_mps.py index 0b87a0f037953..143141ec4f68a 100644 --- a/torch/_inductor/codegen/cpp_wrapper_mps.py +++ b/torch/_inductor/codegen/cpp_wrapper_mps.py @@ -7,11 +7,16 @@ from ..ir import GraphPartitionSignature from ..virtualized import V +from .cpp_wrapper_cpu import CppWrapperCpu from .cpp_wrapper_gpu import CppWrapperGpu from .wrapper import PythonWrapperCodegen class CppWrapperMps(CppWrapperGpu): + """ + Generates cpp wrapper for running on MPS and calls metal kernels + """ + def __init__(self) -> None: super().__init__() self._used_kernel_names: OrderedSet[str] = OrderedSet() @@ -29,8 +34,15 @@ def _generate_kernel_call_helper( self, kernel_name: str, call_args: list[str], - arg_types: Optional[list[type]] = None, - **kwargs: dict[str, Any], + *, + device: Optional[torch.device] = None, + triton: bool = True, + arg_types: Optional[tuple[Any, ...]] = None, + raw_keys: Optional[tuple[Any, ...]] = None, + raw_args: Optional[tuple[Any, ...]] = None, + triton_meta: Optional[dict[str, Any]] = None, + graph_name: str = "", + original_fxnode_name: Optional[str] = None, ) -> None: """ Generates MPS kernel call code. It should look something like: @@ -46,6 +58,23 @@ def _generate_kernel_call_helper( }); ``` """ + device = device or V.graph.get_current_device_or_throw() + if device.type == "cpu": + # Even in CppWrapperGpu, we may see cpp kernels + return CppWrapperCpu._generate_kernel_call_helper( + self, + kernel_name, + call_args, + device=device, + triton=triton, + arg_types=arg_types, + raw_keys=raw_keys, + raw_args=raw_args, + triton_meta=triton_meta, + ) + + assert device.type == "mps" + assert arg_types is not None new_args = [] @@ -81,9 +110,9 @@ def _generate_kernel_call_helper( "cpp", ) with debug_printer_manager: - self.writeline(self.wrap_kernel_call(kernel_name, new_args)) + self.writeline(self.wrap_mps_kernel_call(kernel_name, new_args)) - def wrap_kernel_call(self, name: str, call_args: list[str]) -> str: + def wrap_mps_kernel_call(self, name: str, call_args: list[str]) -> str: lib_name = name[: -len("_func")] calling_args = " ".join(call_args) diff --git a/torch/_inductor/codegen/mps.py b/torch/_inductor/codegen/mps.py index f8176c191fd48..7060f857828ea 100644 --- a/torch/_inductor/codegen/mps.py +++ b/torch/_inductor/codegen/mps.py @@ -929,7 +929,7 @@ def format_threads(threads: list[str], kwarg: str) -> str: wrapper.generate_kernel_call( name, args, - device=torch.device("cpu"), # TODO: Fix me, MPS does not expose streams now + device=torch.device("mps"), triton=False, arg_types=arg_types, ) diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 0b8ba86c3c185..683282fa9c5ad 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -2541,7 +2541,9 @@ def _generate_kernel_call_helper( original_fxnode_name=None, ): device = device or V.graph.get_current_device_or_throw() - if not (triton or device.type != "cpu"): + if not ( + triton or device.type not in ("cpu", "mps") + ): # TODO: Fix me, MPS does not expose streams now self.writeline(self.wrap_kernel_call(kernel_name, call_args)) return From cc372ad557446863f8422f1ca5f415bc78531fa6 Mon Sep 17 00:00:00 2001 From: angelayi Date: Tue, 22 Jul 2025 13:38:34 -0700 Subject: [PATCH 420/457] [aoti][mps] Improve tabbing in cpp generation (#158351) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158351 Approved by: https://github.com/desertfire, https://github.com/malfet ghstack dependencies: #158349, #158350 --- torch/_inductor/codegen/cpp_wrapper_mps.py | 38 ++++++++++------------ 1 file changed, 18 insertions(+), 20 deletions(-) diff --git a/torch/_inductor/codegen/cpp_wrapper_mps.py b/torch/_inductor/codegen/cpp_wrapper_mps.py index 143141ec4f68a..b953927f52be1 100644 --- a/torch/_inductor/codegen/cpp_wrapper_mps.py +++ b/torch/_inductor/codegen/cpp_wrapper_mps.py @@ -81,11 +81,11 @@ def _generate_kernel_call_helper( for idx, (arg, arg_type) in enumerate(zip(call_args[:-2], arg_types[:-2])): if isinstance(arg_type, torch.dtype): new_args.append( - f"aoti_torch_mps_set_arg_tensor({kernel_name}_handle, {idx}, {arg});\n" + f"aoti_torch_mps_set_arg_tensor({kernel_name}_handle, {idx}, {arg});" ) elif arg_type in (int, sympy.core.symbol.Symbol): new_args.append( - f"aoti_torch_mps_set_arg_int({kernel_name}_handle, {idx}, {arg});\n" + f"aoti_torch_mps_set_arg_int({kernel_name}_handle, {idx}, {arg});" ) else: raise NotImplementedError( @@ -110,28 +110,26 @@ def _generate_kernel_call_helper( "cpp", ) with debug_printer_manager: - self.writeline(self.wrap_mps_kernel_call(kernel_name, new_args)) - - def wrap_mps_kernel_call(self, name: str, call_args: list[str]) -> str: - lib_name = name[: -len("_func")] - calling_args = " ".join(call_args) - - kernel_call_str = "" + self.write_mps_kernel_call(kernel_name, new_args) + def write_mps_kernel_call(self, name: str, call_args: list[str]) -> None: # Only add handle definition if the kernel is not already used + lib_name = name[: -len("_func")] if name not in self._used_kernel_names: self._used_kernel_names.add(name) - kernel_call_str += f""" - auto {name} = {lib_name}.getKernelFunction("generated_kernel"); - auto {name}_handle = AOTIMetalKernelFunctionHandle({name}.get()); - """ - kernel_call_str += f""" - {name}->runCommandBlock([&] {{ - {name}->startEncoding(); - {calling_args} - }}); - """ - return kernel_call_str + + self.writeline( + f'auto {name} = {lib_name}.getKernelFunction("generated_kernel");' + ) + self.writeline( + f"auto {name}_handle = AOTIMetalKernelFunctionHandle({name}.get());" + ) + + self.writeline(f"{name}->runCommandBlock([&] {{") + self.writeline(f" {name}->startEncoding();") + for call_arg in call_args: + self.writeline(f" {call_arg}") + self.writeline("});") @staticmethod def get_device_include_path(device: str) -> str: From 91602a92548d1dd351979cdc6e778c505c32c2b9 Mon Sep 17 00:00:00 2001 From: albanD Date: Wed, 23 Jul 2025 01:21:25 +0000 Subject: [PATCH 421/457] Cleanup old caffe2 scripts (#158475) Testing on this one is grep based: if there were no reference to that script I can find, I deleted. We can easily add any of these back if needed! Pull Request resolved: https://github.com/pytorch/pytorch/pull/158475 Approved by: https://github.com/seemethere, https://github.com/huydhn, https://github.com/cyyever --- .github/workflows/pull.yml | 15 -- scripts/README.md | 39 ----- scripts/add_apache_header.sh | 1 - scripts/apache_header.txt | 15 -- scripts/apache_python.txt | 14 -- scripts/build_android.sh | 189 ----------------------- scripts/build_android_gradle.sh | 102 ------------ scripts/build_host_protoc.sh | 59 ------- scripts/build_ios.sh | 155 ------------------- scripts/build_local.sh | 82 ---------- scripts/build_mobile.sh | 107 ------------- scripts/build_pytorch_android.sh | 51 ------ scripts/build_raspbian.sh | 44 ------ scripts/build_tegra_x1.sh | 51 ------ scripts/build_tizen.sh | 118 -------------- scripts/build_windows.bat | 80 ---------- scripts/diagnose_protobuf.py | 92 ----------- scripts/fbcode-dev-setup/ccache_setup.sh | 92 ----------- scripts/get_python_cmake_flags.py | 24 --- scripts/remove_apache_header.sh | 13 -- scripts/temp.sh | 7 - scripts/xcode_build.rb | 76 --------- 22 files changed, 1426 deletions(-) delete mode 100755 scripts/add_apache_header.sh delete mode 100644 scripts/apache_header.txt delete mode 100644 scripts/apache_python.txt delete mode 100755 scripts/build_android.sh delete mode 100755 scripts/build_android_gradle.sh delete mode 100755 scripts/build_host_protoc.sh delete mode 100755 scripts/build_ios.sh delete mode 100755 scripts/build_local.sh delete mode 100755 scripts/build_mobile.sh delete mode 100755 scripts/build_pytorch_android.sh delete mode 100755 scripts/build_raspbian.sh delete mode 100755 scripts/build_tegra_x1.sh delete mode 100755 scripts/build_tizen.sh delete mode 100644 scripts/build_windows.bat delete mode 100644 scripts/diagnose_protobuf.py delete mode 100755 scripts/fbcode-dev-setup/ccache_setup.sh delete mode 100644 scripts/get_python_cmake_flags.py delete mode 100755 scripts/remove_apache_header.sh delete mode 100755 scripts/temp.sh delete mode 100644 scripts/xcode_build.rb diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index 59a7265173800..be0bdc527cc11 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -315,21 +315,6 @@ jobs: test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.test-matrix }} secrets: inherit - linux-jammy-py3-clang18-mobile-build: - name: linux-jammy-py3-clang18-mobile-build - uses: ./.github/workflows/_linux-build.yml - needs: get-label-type - with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-py3-clang12-mobile-build - docker-image-name: ci-image:pytorch-linux-jammy-py3-clang18-asan - build-generates-artifacts: false - test-matrix: | - { include: [ - { config: "default", shard: 1, num_shards: 1 }, - ]} - secrets: inherit - linux-jammy-cuda12_8-cudnn9-py3_9-clang12-build: name: linux-jammy-cuda12.8-cudnn9-py3.9-clang12 uses: ./.github/workflows/_linux-build.yml diff --git a/scripts/README.md b/scripts/README.md index a1c5ae5f93e67..367e7261f6a60 100644 --- a/scripts/README.md +++ b/scripts/README.md @@ -1,40 +1 @@ This directory contains the useful tools. - - -## build_android.sh -This script is to build PyTorch/Caffe2 library for Android. Take the following steps to start the build: - -- set ANDROID_NDK to the location of ndk - -```bash -export ANDROID_NDK=YOUR_NDK_PATH -``` - -- run build_android.sh -```bash -#in your PyTorch root directory -bash scripts/build_android.sh -``` -If succeeded, the libraries and headers would be generated to build_android/install directory. You can then copy these files from build_android/install to your Android project for further usage. - -You can also override the cmake flags via command line, e.g., following command will also compile the executable binary files: -```bash -bash scripts/build_android.sh -DBUILD_BINARY=ON -``` - -## build_ios.sh -This script is to build PyTorch/Caffe2 library for iOS, and can only be performed on macOS. Take the following steps to start the build: - -- Install Xcode from App Store, and configure "Command Line Tools" properly on Xcode. -- Install the dependencies: - -```bash -brew install cmake automake libtool -``` - -- run build_ios.sh -```bash -#in your PyTorch root directory -bash scripts/build_ios.sh -``` -If succeeded, the libraries and headers would be generated to build_ios/install directory. You can then copy these files to your Xcode project for further usage. diff --git a/scripts/add_apache_header.sh b/scripts/add_apache_header.sh deleted file mode 100755 index a29a059d2d033..0000000000000 --- a/scripts/add_apache_header.sh +++ /dev/null @@ -1 +0,0 @@ -cat apache_header.txt $1 > _add_apache_header.txt && mv _add_apache_header.txt $1 diff --git a/scripts/apache_header.txt b/scripts/apache_header.txt deleted file mode 100644 index b4eff258eb04d..0000000000000 --- a/scripts/apache_header.txt +++ /dev/null @@ -1,15 +0,0 @@ -/** - * Copyright (c) 2016-present, Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ diff --git a/scripts/apache_python.txt b/scripts/apache_python.txt deleted file mode 100644 index bc104d8845154..0000000000000 --- a/scripts/apache_python.txt +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright (c) 2016-present, Facebook, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -############################################################################## diff --git a/scripts/build_android.sh b/scripts/build_android.sh deleted file mode 100755 index 43f11b86828d4..0000000000000 --- a/scripts/build_android.sh +++ /dev/null @@ -1,189 +0,0 @@ -#!/bin/bash -############################################################################## -# Example command to build the android target. -############################################################################## -# -# This script shows how one can build a Caffe2 binary for the Android platform -# using android-cmake. A few notes: -# -# (1) This build also does a host build for protobuf. You will need autoconf -# to carry out this. If autoconf is not possible, you will need to provide -# a pre-built protoc binary that is the same version as the protobuf -# version under third_party. -# If you are building on Mac, you might need to install autotool and -# libtool. The easiest way is via homebrew: -# brew install automake -# brew install libtool -# (2) You will need to have android ndk installed. The current script assumes -# that you set ANDROID_NDK to the location of ndk. -# (3) The toolchain and the build target platform can be specified with the -# cmake arguments below. For more details, check out android-cmake's doc. - -set -e - -# Android specific flags -if [ -z "$ANDROID_ABI" ]; then - ANDROID_ABI="armeabi-v7a with NEON" -fi -ANDROID_NATIVE_API_LEVEL="21" -echo "Build with ANDROID_ABI[$ANDROID_ABI], ANDROID_NATIVE_API_LEVEL[$ANDROID_NATIVE_API_LEVEL]" - -CAFFE2_ROOT="$( cd "$(dirname "$0")"/.. ; pwd -P)" -if [ -z "$ANDROID_NDK" ]; then - echo "ANDROID_NDK not set; please set it to the Android NDK directory" - exit 1 -fi - -if [ ! -d "$ANDROID_NDK" ]; then - echo "ANDROID_NDK not a directory; did you install it under $ANDROID_NDK?" - exit 1 -fi - -if [ -z "$PYTHON" ]; then - PYTHON=python - PYTHON_VERSION_MAJOR=$($PYTHON -c 'import sys; print(sys.version_info[0])') - if [ "${PYTHON_VERSION_MAJOR}" -le 2 ]; then - echo "Default python executable is Python-2, trying to use python3 alias" - PYTHON=python3 - fi -fi - -ANDROID_NDK_PROPERTIES="$ANDROID_NDK/source.properties" -[ -f "$ANDROID_NDK_PROPERTIES" ] && ANDROID_NDK_VERSION=$(sed -n 's/^Pkg.Revision[^=]*= *\([0-9]*\)\..*$/\1/p' "$ANDROID_NDK_PROPERTIES") - -echo "Bash: $(/bin/bash --version | head -1)" -echo "Python: $($PYTHON -c 'import sys; print(sys.version)')" -echo "Caffe2 path: $CAFFE2_ROOT" -echo "Using Android NDK at $ANDROID_NDK" -echo "Android NDK version: $ANDROID_NDK_VERSION" - -CMAKE_ARGS=() - -# Build PyTorch mobile -CMAKE_ARGS+=("-DCMAKE_PREFIX_PATH=$($PYTHON -c 'import sysconfig; print(sysconfig.get_path("purelib"))')") -CMAKE_ARGS+=("-DPython_EXECUTABLE=$($PYTHON -c 'import sys; print(sys.executable)')") -CMAKE_ARGS+=("-DBUILD_CUSTOM_PROTOBUF=OFF") - -# custom build with selected ops -if [ -n "${SELECTED_OP_LIST}" ]; then - SELECTED_OP_LIST="$(cd $(dirname $SELECTED_OP_LIST); pwd -P)/$(basename $SELECTED_OP_LIST)" - echo "Choose SELECTED_OP_LIST file: $SELECTED_OP_LIST" - if [ ! -r ${SELECTED_OP_LIST} ]; then - echo "Error: SELECTED_OP_LIST file ${SELECTED_OP_LIST} not found." - exit 1 - fi - CMAKE_ARGS+=("-DSELECTED_OP_LIST=${SELECTED_OP_LIST}") -fi - -# If Ninja is installed, prefer it to Make -if [ -x "$(command -v ninja)" ]; then - CMAKE_ARGS+=("-GNinja") -fi - -# Use android-cmake to build Android project from CMake. -CMAKE_ARGS+=("-DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake") - -if [ -z "$BUILD_MOBILE_BENCHMARK" ]; then - BUILD_MOBILE_BENCHMARK=0 -fi - -if [ -z "$BUILD_MOBILE_TEST" ]; then - BUILD_MOBILE_TEST=0 -fi -# Don't build artifacts we don't need -CMAKE_ARGS+=("-DBUILD_TEST=OFF") -CMAKE_ARGS+=("-DBUILD_BINARY=OFF") - -# If there exists env variable and it equals to 0, build full jit interpreter. -# Default behavior is to build lite interpreter -# cmd: BUILD_LITE_INTERPRETER=0 ./scripts/build_android.sh -if [ "${BUILD_LITE_INTERPRETER}" == 0 ]; then - CMAKE_ARGS+=("-DBUILD_LITE_INTERPRETER=OFF") -else - CMAKE_ARGS+=("-DBUILD_LITE_INTERPRETER=ON") -fi -if [ "${TRACING_BASED}" == 1 ]; then - CMAKE_ARGS+=("-DTRACING_BASED=ON") -else - CMAKE_ARGS+=("-DTRACING_BASED=OFF") -fi -if [ "${USE_LIGHTWEIGHT_DISPATCH}" == 1 ]; then - CMAKE_ARGS+=("-DUSE_LIGHTWEIGHT_DISPATCH=ON") - CMAKE_ARGS+=("-DSTATIC_DISPATCH_BACKEND=CPU") -else - CMAKE_ARGS+=("-DUSE_LIGHTWEIGHT_DISPATCH=OFF") -fi - -CMAKE_ARGS+=("-DBUILD_MOBILE_BENCHMARK=$BUILD_MOBILE_BENCHMARK") -CMAKE_ARGS+=("-DBUILD_MOBILE_TEST=$BUILD_MOBILE_TEST") -CMAKE_ARGS+=("-DBUILD_PYTHON=OFF") -CMAKE_ARGS+=("-DBUILD_SHARED_LIBS=OFF") -if (( "${ANDROID_NDK_VERSION:-0}" < 18 )); then - CMAKE_ARGS+=("-DANDROID_TOOLCHAIN=gcc") -else - CMAKE_ARGS+=("-DANDROID_TOOLCHAIN=clang") -fi -# Disable unused dependencies -CMAKE_ARGS+=("-DUSE_CUDA=OFF") -CMAKE_ARGS+=("-DUSE_ITT=OFF") -CMAKE_ARGS+=("-DUSE_GFLAGS=OFF") -CMAKE_ARGS+=("-DUSE_OPENCV=OFF") -CMAKE_ARGS+=("-DUSE_MPI=OFF") -CMAKE_ARGS+=("-DUSE_OPENMP=OFF") -# Only toggle if VERBOSE=1 -if [ "${VERBOSE:-}" == '1' ]; then - CMAKE_ARGS+=("-DCMAKE_VERBOSE_MAKEFILE=1") -fi - -# Android specific flags -CMAKE_ARGS+=("-DANDROID_NDK=$ANDROID_NDK") -CMAKE_ARGS+=("-DANDROID_ABI=$ANDROID_ABI") -CMAKE_ARGS+=("-DANDROID_NATIVE_API_LEVEL=$ANDROID_NATIVE_API_LEVEL") -CMAKE_ARGS+=("-DANDROID_CPP_FEATURES=rtti exceptions") -if [ "${ANDROID_STL_SHARED:-}" == '1' ]; then - CMAKE_ARGS+=("-DANDROID_STL=c++_shared") -fi -if [ "${ANDROID_DEBUG_SYMBOLS:-}" == '1' ]; then - CMAKE_ARGS+=("-DANDROID_DEBUG_SYMBOLS=1") -fi - -if [ -n "${USE_VULKAN}" ]; then - CMAKE_ARGS+=("-DUSE_VULKAN=ON") - if [ -n "${USE_VULKAN_FP16_INFERENCE}" ]; then - CMAKE_ARGS+=("-DUSE_VULKAN_FP16_INFERENCE=ON") - fi - if [ -n "${USE_VULKAN_RELAXED_PRECISION}" ]; then - CMAKE_ARGS+=("-DUSE_VULKAN_RELAXED_PRECISION=ON") - fi -fi - -# Use-specified CMake arguments go last to allow overriding defaults -CMAKE_ARGS+=($@) - -# Patch pocketfft (as Android does not have aligned_alloc even if compiled with c++17 -if [ -f third_party/pocketfft/pocketfft_hdronly.h ]; then - sed -i -e "s/__cplusplus >= 201703L/0/" third_party/pocketfft/pocketfft_hdronly.h -fi - -# Now, actually build the Android target. -BUILD_ROOT=${BUILD_ROOT:-"$CAFFE2_ROOT/build_android"} -INSTALL_PREFIX=${BUILD_ROOT}/install -mkdir -p $BUILD_ROOT -cd $BUILD_ROOT -cmake "$CAFFE2_ROOT" \ - -DCMAKE_INSTALL_PREFIX=$INSTALL_PREFIX \ - -DCMAKE_BUILD_TYPE=Release \ - "${CMAKE_ARGS[@]}" - -# Cross-platform parallel build -if [ -z "$MAX_JOBS" ]; then - if [ "$(uname)" == 'Darwin' ]; then - MAX_JOBS=$(sysctl -n hw.ncpu) - else - MAX_JOBS=$(nproc) - fi -fi - -echo "Will install headers and libs to $INSTALL_PREFIX for further Android project usage." -cmake --build . --target install -- "-j${MAX_JOBS}" -echo "Installation completed, now you can copy the headers/libs from $INSTALL_PREFIX to your Android project directory." diff --git a/scripts/build_android_gradle.sh b/scripts/build_android_gradle.sh deleted file mode 100755 index fc27c5dd2516b..0000000000000 --- a/scripts/build_android_gradle.sh +++ /dev/null @@ -1,102 +0,0 @@ -#!/usr/bin/env bash -set -eux -o pipefail - -env -echo "BUILD_ENVIRONMENT:$BUILD_ENVIRONMENT" - -export ANDROID_NDK_HOME=/opt/ndk -export ANDROID_NDK=/opt/ndk -export ANDROID_HOME=/opt/android/sdk - -# Must be in sync with GRADLE_VERSION in docker image for android -# https://github.com/pietern/pytorch-dockerfiles/blob/master/build.sh#L155 -export GRADLE_VERSION=6.8.3 -export GRADLE_HOME=/opt/gradle/gradle-$GRADLE_VERSION -export GRADLE_PATH=$GRADLE_HOME/bin/gradle - -# touch gradle cache files to prevent expiration -while IFS= read -r -d '' file -do - touch "$file" || true -done < <(find /var/lib/jenkins/.gradle -type f -print0) - -# Patch pocketfft (as Android does not have aligned_alloc even if compiled with c++17 -if [ -f ~/workspace/third_party/pocketfft/pocketfft_hdronly.h ]; then - sed -i -e "s/__cplusplus >= 201703L/0/" ~/workspace/third_party/pocketfft/pocketfft_hdronly.h -fi - -export GRADLE_LOCAL_PROPERTIES=~/workspace/android/local.properties -rm -f $GRADLE_LOCAL_PROPERTIES -echo "sdk.dir=/opt/android/sdk" >> $GRADLE_LOCAL_PROPERTIES -echo "ndk.dir=/opt/ndk" >> $GRADLE_LOCAL_PROPERTIES -echo "cmake.dir=/usr/local" >> $GRADLE_LOCAL_PROPERTIES - -retry () { - $* || (sleep 1 && $*) || (sleep 2 && $*) || (sleep 4 && $*) || (sleep 8 && $*) -} - -# Run custom build script -if [[ "${BUILD_ENVIRONMENT}" == *-gradle-custom-build* ]]; then - # Install torch & torchvision - used to download & dump used ops from test model. - retry pip install torch torchvision --progress-bar off - - exec "$(dirname "${BASH_SOURCE[0]}")/../android/build_test_app_custom.sh" armeabi-v7a -fi - -# Run default build -BUILD_ANDROID_INCLUDE_DIR_x86=~/workspace/build_android/install/include -BUILD_ANDROID_LIB_DIR_x86=~/workspace/build_android/install/lib - -BUILD_ANDROID_INCLUDE_DIR_x86_64=~/workspace/build_android_install_x86_64/install/include -BUILD_ANDROID_LIB_DIR_x86_64=~/workspace/build_android_install_x86_64/install/lib - -BUILD_ANDROID_INCLUDE_DIR_arm_v7a=~/workspace/build_android_install_arm_v7a/install/include -BUILD_ANDROID_LIB_DIR_arm_v7a=~/workspace/build_android_install_arm_v7a/install/lib - -BUILD_ANDROID_INCLUDE_DIR_arm_v8a=~/workspace/build_android_install_arm_v8a/install/include -BUILD_ANDROID_LIB_DIR_arm_v8a=~/workspace/build_android_install_arm_v8a/install/lib - -PYTORCH_ANDROID_SRC_MAIN_DIR=~/workspace/android/pytorch_android/src/main - -JNI_INCLUDE_DIR=${PYTORCH_ANDROID_SRC_MAIN_DIR}/cpp/libtorch_include -mkdir -p $JNI_INCLUDE_DIR - -JNI_LIBS_DIR=${PYTORCH_ANDROID_SRC_MAIN_DIR}/jniLibs -mkdir -p $JNI_LIBS_DIR - -ln -s ${BUILD_ANDROID_INCLUDE_DIR_x86} ${JNI_INCLUDE_DIR}/x86 -ln -s ${BUILD_ANDROID_LIB_DIR_x86} ${JNI_LIBS_DIR}/x86 - -if [[ "${BUILD_ENVIRONMENT}" != *-gradle-build-only-x86_32* ]]; then -ln -s ${BUILD_ANDROID_INCLUDE_DIR_x86_64} ${JNI_INCLUDE_DIR}/x86_64 -ln -s ${BUILD_ANDROID_LIB_DIR_x86_64} ${JNI_LIBS_DIR}/x86_64 - -ln -s ${BUILD_ANDROID_INCLUDE_DIR_arm_v7a} ${JNI_INCLUDE_DIR}/armeabi-v7a -ln -s ${BUILD_ANDROID_LIB_DIR_arm_v7a} ${JNI_LIBS_DIR}/armeabi-v7a - -ln -s ${BUILD_ANDROID_INCLUDE_DIR_arm_v8a} ${JNI_INCLUDE_DIR}/arm64-v8a -ln -s ${BUILD_ANDROID_LIB_DIR_arm_v8a} ${JNI_LIBS_DIR}/arm64-v8a -fi - -GRADLE_PARAMS="-p android assembleRelease --debug --stacktrace" -if [[ "${BUILD_ENVIRONMENT}" == *-gradle-build-only-x86_32* ]]; then - GRADLE_PARAMS+=" -PABI_FILTERS=x86" -fi - -if [ -n "${GRADLE_OFFLINE:-}" ]; then - GRADLE_PARAMS+=" --offline" -fi - -$GRADLE_PATH $GRADLE_PARAMS - -find . -type f -name "*.a" -exec ls -lh {} \; - -while IFS= read -r -d '' file -do - echo - echo "$file" - ls -lah "$file" - zipinfo -l "$file" -done < <(find . -type f -name '*.aar' -print0) - -find . -type f -name *aar -print | xargs tar cfvz ~/workspace/android/artifacts.tgz diff --git a/scripts/build_host_protoc.sh b/scripts/build_host_protoc.sh deleted file mode 100755 index cd37db3b31713..0000000000000 --- a/scripts/build_host_protoc.sh +++ /dev/null @@ -1,59 +0,0 @@ -#!/bin/bash -############################################################################## -# Build script to build the protoc compiler for the host platform. -############################################################################## -# This script builds the protoc compiler for the host platform, which is needed -# for any cross-compilation as we will need to convert the protobuf source -# files to cc files. -# -# --other-flags accepts flags that should be passed to cmake. Optional. -# -# After the execution of the file, one should be able to find the host protoc -# binary at build_host_protoc/bin/protoc. - -set -e - -CAFFE2_ROOT="$( cd "$(dirname -- "$0")"/.. ; pwd -P)" -BUILD_ROOT=${BUILD_ROOT:-"$CAFFE2_ROOT/build_host_protoc"} -mkdir -p $BUILD_ROOT/build -cd $BUILD_ROOT/build - -CMAKE_ARGS=() -CMAKE_ARGS+=("-DCMAKE_INSTALL_PREFIX=$BUILD_ROOT") -CMAKE_ARGS+=("-Dprotobuf_BUILD_TESTS=OFF") - -# If Ninja is installed, prefer it to Make -if [ -x "$(command -v ninja)" ]; then - CMAKE_ARGS+=("-GNinja") -fi - -while true; do - case "$1" in - --other-flags) - shift; - CMAKE_ARGS+=("$@") - break ;; - "") - break ;; - *) - echo "Unknown option passed as argument: $1" - break ;; - esac -done - -# Use ccache if available (this path is where Homebrew installs ccache symlinks) -if [ "$(uname)" == 'Darwin' ] && [ -d /usr/local/opt/ccache/libexec ]; then - CMAKE_ARGS+=("-DCMAKE_C_COMPILER=/usr/local/opt/ccache/libexec/gcc") - CMAKE_ARGS+=("-DCMAKE_CXX_COMPILER=/usr/local/opt/ccache/libexec/g++") -fi - -cmake "$CAFFE2_ROOT/third_party/protobuf/cmake" ${CMAKE_ARGS[@]} - -if [ -z "$MAX_JOBS" ]; then - if [ "$(uname)" == 'Darwin' ]; then - MAX_JOBS=$(sysctl -n hw.ncpu) - else - MAX_JOBS=$(nproc) - fi -fi -cmake --build . -- "-j${MAX_JOBS}" install diff --git a/scripts/build_ios.sh b/scripts/build_ios.sh deleted file mode 100755 index ad16cb940dcb8..0000000000000 --- a/scripts/build_ios.sh +++ /dev/null @@ -1,155 +0,0 @@ -#!/bin/bash -xe -############################################################################## -# Example command to build the iOS target. -############################################################################## -# -# This script shows how one can build a Caffe2 binary for the iOS platform -# using ios-cmake. This is very similar to the android-cmake - see -# build_android.sh for more details. - -CAFFE2_ROOT="$( cd "$(dirname "$0")"/.. ; pwd -P)" - -if [ -z "$PYTHON" ]; then - PYTHON=python - PYTHON_VERSION_MAJOR=$($PYTHON -c 'import sys; print(sys.version_info[0])') - if [ "${PYTHON_VERSION_MAJOR}" -le 2 ]; then - echo "Default python executable is Python-2, trying to use python3 alias" - PYTHON=python3 - fi -fi - -echo "Bash: $(/bin/bash --version | head -1)" -echo "Python: $($PYTHON -c 'import sys; print(sys.version)')" -echo "Caffe2 path: $CAFFE2_ROOT" - -CMAKE_ARGS=() - -# Build PyTorch mobile -CMAKE_ARGS+=("-DCMAKE_PREFIX_PATH=$($PYTHON -c 'import sysconfig; print(sysconfig.get_path("purelib"))')") -CMAKE_ARGS+=("-DPython_EXECUTABLE=$($PYTHON -c 'import sys; print(sys.executable)')") -CMAKE_ARGS+=("-DBUILD_CUSTOM_PROTOBUF=OFF") - -# custom build with selected ops -if [ -n "${SELECTED_OP_LIST}" ]; then - SELECTED_OP_LIST="$(cd $(dirname $SELECTED_OP_LIST); pwd -P)/$(basename $SELECTED_OP_LIST)" - echo "Choose SELECTED_OP_LIST file: $SELECTED_OP_LIST" - if [ ! -r ${SELECTED_OP_LIST} ]; then - echo "Error: SELECTED_OP_LIST file ${SELECTED_OP_LIST} not found." - exit 1 - fi - CMAKE_ARGS+=("-DSELECTED_OP_LIST=${SELECTED_OP_LIST}") -fi - -# bitcode -if [ "${ENABLE_BITCODE:-}" == '1' ]; then - CMAKE_ARGS+=("-DCMAKE_C_FLAGS=-fembed-bitcode") - CMAKE_ARGS+=("-DCMAKE_CXX_FLAGS=-fembed-bitcode") -fi - -# Use ios-cmake to build iOS project from CMake. -# This projects sets CMAKE_C_COMPILER to /usr/bin/gcc and -# CMAKE_CXX_COMPILER to /usr/bin/g++. In order to use ccache (if it is available) we -# must override these variables via CMake arguments. -CMAKE_ARGS+=("-DCMAKE_TOOLCHAIN_FILE=$CAFFE2_ROOT/cmake/iOS.cmake") -if [ -n "${CCACHE_WRAPPER_PATH:-}"]; then - CCACHE_WRAPPER_PATH=/usr/local/opt/ccache/libexec -fi -if [ -d "$CCACHE_WRAPPER_PATH" ]; then - CMAKE_ARGS+=("-DCMAKE_C_COMPILER=$CCACHE_WRAPPER_PATH/gcc") - CMAKE_ARGS+=("-DCMAKE_CXX_COMPILER=$CCACHE_WRAPPER_PATH/g++") -fi - -# IOS_PLATFORM controls type of iOS platform (see ios-cmake) -if [ -n "${IOS_PLATFORM:-}" ]; then - CMAKE_ARGS+=("-DIOS_PLATFORM=${IOS_PLATFORM}") - if [ "${IOS_PLATFORM}" == "WATCHOS" ]; then - # enable bitcode by default for watchos - CMAKE_ARGS+=("-DCMAKE_C_FLAGS=-fembed-bitcode") - CMAKE_ARGS+=("-DCMAKE_CXX_FLAGS=-fembed-bitcode") - # disable the QNNPACK - CMAKE_ARGS+=("-DUSE_PYTORCH_QNNPACK=OFF") - fi -else - # IOS_PLATFORM is not set, default to OS, which builds iOS. - CMAKE_ARGS+=("-DIOS_PLATFORM=OS") -fi - -if [ -n "${IOS_ARCH:-}" ]; then - CMAKE_ARGS+=("-DIOS_ARCH=${IOS_ARCH}") -fi - -if [ "${BUILD_LITE_INTERPRETER}" == 0 ]; then - CMAKE_ARGS+=("-DBUILD_LITE_INTERPRETER=OFF") -else - CMAKE_ARGS+=("-DBUILD_LITE_INTERPRETER=ON") -fi -if [ "${TRACING_BASED}" == 1 ]; then - CMAKE_ARGS+=("-DTRACING_BASED=ON") -else - CMAKE_ARGS+=("-DTRACING_BASED=OFF") -fi -if [ "${USE_LIGHTWEIGHT_DISPATCH}" == 1 ]; then - CMAKE_ARGS+=("-DUSE_LIGHTWEIGHT_DISPATCH=ON") - CMAKE_ARGS+=("-DSTATIC_DISPATCH_BACKEND=CPU") -else - CMAKE_ARGS+=("-DUSE_LIGHTWEIGHT_DISPATCH=OFF") -fi - -CMAKE_ARGS+=("-DUSE_LITE_INTERPRETER_PROFILER=OFF") - -# Don't build binaries or tests (only the library) -CMAKE_ARGS+=("-DBUILD_TEST=OFF") -CMAKE_ARGS+=("-DBUILD_BINARY=OFF") -CMAKE_ARGS+=("-DBUILD_PYTHON=OFF") - -# Disable unused dependencies -CMAKE_ARGS+=("-DUSE_CUDA=OFF") -CMAKE_ARGS+=("-DUSE_ITT=OFF") -CMAKE_ARGS+=("-DUSE_GFLAGS=OFF") -CMAKE_ARGS+=("-DUSE_OPENCV=OFF") -CMAKE_ARGS+=("-DUSE_MPI=OFF") -CMAKE_ARGS+=("-DUSE_NUMPY=OFF") -CMAKE_ARGS+=("-DUSE_NNPACK=OFF") -CMAKE_ARGS+=("-DUSE_MKLDNN=OFF") - -# Metal -if [ "${USE_PYTORCH_METAL:-}" == "1" ]; then - CMAKE_ARGS+=("-DUSE_PYTORCH_METAL=ON") -fi - -# Core ML -if [ "${USE_COREML_DELEGATE}" == "1" ]; then - CMAKE_ARGS+=("-DUSE_COREML_DELEGATE=ON") -fi - -# pthreads -CMAKE_ARGS+=("-DCMAKE_THREAD_LIBS_INIT=-lpthread") -CMAKE_ARGS+=("-DCMAKE_HAVE_THREADS_LIBRARY=1") -CMAKE_ARGS+=("-DCMAKE_USE_PTHREADS_INIT=1") - -# Only toggle if VERBOSE=1 -if [ "${VERBOSE:-}" == '1' ]; then - CMAKE_ARGS+=("-DCMAKE_VERBOSE_MAKEFILE=1") -fi - -# enable ARC -CMAKE_ARGS+=("-DCMAKE_CXX_FLAGS=-fobjc-arc") - -# Now, actually build the iOS target. -BUILD_ROOT=${BUILD_ROOT:-"$CAFFE2_ROOT/build_ios"} -INSTALL_PREFIX=${BUILD_ROOT}/install -mkdir -p $BUILD_ROOT -cd $BUILD_ROOT -cmake "$CAFFE2_ROOT" \ - -DCMAKE_INSTALL_PREFIX=$INSTALL_PREFIX \ - -DCMAKE_BUILD_TYPE=MinSizeRel \ - -DBUILD_SHARED_LIBS=OFF \ - ${CMAKE_ARGS[@]} \ - $@ - -cmake --build . -- "-j$(sysctl -n hw.ncpu)" - -# copy headers and libs to install directory -echo "Will install headers and libs to $INSTALL_PREFIX for further Xcode project usage." -make install -echo "Installation completed, now you can copy the headers/libs from $INSTALL_PREFIX to your Xcode project directory." diff --git a/scripts/build_local.sh b/scripts/build_local.sh deleted file mode 100755 index b843671501256..0000000000000 --- a/scripts/build_local.sh +++ /dev/null @@ -1,82 +0,0 @@ -#!/bin/bash -# -############################################################################## -# Example command to build Caffe2 -############################################################################## -# - -set -ex - -CAFFE2_ROOT="$( cd "$(dirname "$0")"/.. ; pwd -P)" - -CMAKE_ARGS=() - -# If Ninja is installed, prefer it to Make -if [ -x "$(command -v ninja)" ]; then - CMAKE_ARGS+=("-GNinja") -fi - -# Use ccache if available (this path is where Homebrew installs ccache symlinks) -if [ "$(uname)" == 'Darwin' ]; then - if [ -n "${CCACHE_WRAPPER_PATH:-}"]; then - CCACHE_WRAPPER_PATH=/usr/local/opt/ccache/libexec - fi - if [ -d "$CCACHE_WRAPPER_PATH" ]; then - CMAKE_ARGS+=("-DCMAKE_C_COMPILER=$CCACHE_WRAPPER_PATH/gcc") - CMAKE_ARGS+=("-DCMAKE_CXX_COMPILER=$CCACHE_WRAPPER_PATH/g++") - fi -fi - -# Use special install script with Anaconda -if [ -n "${USE_ANACONDA}" ]; then - export SKIP_CONDA_TESTS=1 - export CONDA_INSTALL_LOCALLY=1 - "${ROOT_DIR}/scripts/build_anaconda.sh" "$@" -else - # Make sure that pyyaml is installed for the codegen of building Aten to work - if [[ -n "$(python -c 'import yaml' 2>&1)" ]]; then - echo "Installing pyyaml with pip at $(which pip)" - pip install --user pyyaml - fi - - # Make sure that typing is installed for the codegen of building Aten to work - if [[ -n "$(python -c 'import typing' 2>&1)" ]]; then - echo "Installing typing with pip at $(which pip)" - pip install --user typing - fi - - # Build protobuf compiler from third_party if configured to do so - if [ -n "${USE_HOST_PROTOC:-}" ]; then - echo "USE_HOST_PROTOC is set; building protoc before building Caffe2..." - "$CAFFE2_ROOT/scripts/build_host_protoc.sh" - CUSTOM_PROTOC_EXECUTABLE="$CAFFE2_ROOT/build_host_protoc/bin/protoc" - echo "Built protoc $("$CUSTOM_PROTOC_EXECUTABLE" --version)" - CMAKE_ARGS+=("-DCAFFE2_CUSTOM_PROTOC_EXECUTABLE=$CUSTOM_PROTOC_EXECUTABLE") - fi - - # We are going to build the target into build. - BUILD_ROOT=${BUILD_ROOT:-"$CAFFE2_ROOT/build"} - mkdir -p "$BUILD_ROOT" - cd "$BUILD_ROOT" - echo "Building Caffe2 in: $BUILD_ROOT" - - cmake "$CAFFE2_ROOT" \ - -DCMAKE_BUILD_TYPE=Release \ - "${CMAKE_ARGS[@]}" \ - "$@" - - # Determine the number of CPUs to build with. - # If the `CAFFE_MAKE_NCPUS` variable is not specified, use them all. - if [ -n "${MAX_JOBS}" ]; then - CAFFE_MAKE_NCPUS="$MAX_JOBS" - elif [ -n "${CAFFE_MAKE_NCPUS}" ]; then - CAFFE_MAKE_NCPUS="$CAFFE_MAKE_NCPUS" - elif [ "$(uname)" == 'Darwin' ]; then - CAFFE_MAKE_NCPUS="$(sysctl -n hw.ncpu)" - else - CAFFE_MAKE_NCPUS="$(nproc)" - fi - - # Now, actually build the target. - cmake --build . -- "-j$CAFFE_MAKE_NCPUS" -fi diff --git a/scripts/build_mobile.sh b/scripts/build_mobile.sh deleted file mode 100755 index 7b1995a61ebc7..0000000000000 --- a/scripts/build_mobile.sh +++ /dev/null @@ -1,107 +0,0 @@ -#!/bin/bash -############################################################################## -# Example command to build the mobile target. -############################################################################## -# -# This script shows how one can build a libtorch library optimized for mobile -# devices using host toolchain. - -set -e - -export BUILD_PYTORCH_MOBILE_WITH_HOST_TOOLCHAIN=1 -CAFFE2_ROOT="$( cd "$(dirname "$0")"/.. ; pwd -P)" - -echo "Bash: $(/bin/bash --version | head -1)" -echo "Caffe2 path: $CAFFE2_ROOT" - -CMAKE_ARGS=() -CMAKE_ARGS+=("-DCMAKE_PREFIX_PATH=$(python -c 'import sysconfig; print(sysconfig.get_path("purelib"))')") -CMAKE_ARGS+=("-DPython_EXECUTABLE=$(python -c 'import sys; print(sys.executable)')") -CMAKE_ARGS+=("-DBUILD_CUSTOM_PROTOBUF=OFF") -CMAKE_ARGS+=("-DBUILD_SHARED_LIBS=OFF") - -# custom build with selected ops -if [ -n "${SELECTED_OP_LIST}" ]; then - SELECTED_OP_LIST="$(cd $(dirname $SELECTED_OP_LIST); pwd -P)/$(basename $SELECTED_OP_LIST)" - echo "Choose SELECTED_OP_LIST file: $SELECTED_OP_LIST" - if [ ! -r ${SELECTED_OP_LIST} ]; then - echo "Error: SELECTED_OP_LIST file ${SELECTED_OP_LIST} not found." - exit 1 - fi - CMAKE_ARGS+=("-DSELECTED_OP_LIST=${SELECTED_OP_LIST}") -fi - -# If Ninja is installed, prefer it to Make -if [ -x "$(command -v ninja)" ]; then - CMAKE_ARGS+=("-GNinja") -fi - -# Don't build artifacts we don't need -CMAKE_ARGS+=("-DBUILD_TEST=OFF") -CMAKE_ARGS+=("-DBUILD_BINARY=OFF") - -# If there exists env variable and it equals to 1, build lite interpreter. -# Default behavior is to build full jit interpreter. -# cmd: BUILD_LITE_INTERPRETER=1 ./scripts/build_mobile.sh -if [ "x${BUILD_LITE_INTERPRETER}" == "x1" ]; then - CMAKE_ARGS+=("-DBUILD_LITE_INTERPRETER=ON") -else - CMAKE_ARGS+=("-DBUILD_LITE_INTERPRETER=OFF") -fi -if [ "x${TRACING_BASED}" == "x1" ]; then - CMAKE_ARGS+=("-DTRACING_BASED=ON") -else - CMAKE_ARGS+=("-DTRACING_BASED=OFF") -fi - -# Lightweight dispatch bypasses the PyTorch Dispatcher. -if [ "${USE_LIGHTWEIGHT_DISPATCH}" == 1 ]; then - CMAKE_ARGS+=("-DUSE_LIGHTWEIGHT_DISPATCH=ON") - CMAKE_ARGS+=("-DSTATIC_DISPATCH_BACKEND=CPU") -else - CMAKE_ARGS+=("-DUSE_LIGHTWEIGHT_DISPATCH=OFF") -fi - -# Disable unused dependencies -CMAKE_ARGS+=("-DUSE_ROCM=OFF") -CMAKE_ARGS+=("-DUSE_CUDA=OFF") -CMAKE_ARGS+=("-DUSE_ITT=OFF") -CMAKE_ARGS+=("-DUSE_GFLAGS=OFF") -CMAKE_ARGS+=("-DUSE_OPENCV=OFF") -CMAKE_ARGS+=("-DUSE_MPI=OFF") -CMAKE_ARGS+=("-DUSE_OPENMP=OFF") -CMAKE_ARGS+=("-DUSE_MKLDNN=OFF") -CMAKE_ARGS+=("-DUSE_NNPACK=OFF") -CMAKE_ARGS+=("-DUSE_NUMPY=OFF") -CMAKE_ARGS+=("-DUSE_BLAS=OFF") - -# Only toggle if VERBOSE=1 -if [ "${VERBOSE:-}" == '1' ]; then - CMAKE_ARGS+=("-DCMAKE_VERBOSE_MAKEFILE=1") -fi - -# Use-specified CMake arguments go last to allow overriding defaults -CMAKE_ARGS+=("$@") - -# Now, actually build the Android target. -BUILD_ROOT=${BUILD_ROOT:-"$CAFFE2_ROOT/build_mobile"} -INSTALL_PREFIX=${BUILD_ROOT}/install -mkdir -p $BUILD_ROOT -cd $BUILD_ROOT -cmake "$CAFFE2_ROOT" \ - -DCMAKE_INSTALL_PREFIX=$INSTALL_PREFIX \ - -DCMAKE_BUILD_TYPE=Release \ - "${CMAKE_ARGS[@]}" - -# Cross-platform parallel build -if [ -z "$MAX_JOBS" ]; then - if [ "$(uname)" == 'Darwin' ]; then - MAX_JOBS=$(sysctl -n hw.ncpu) - else - MAX_JOBS=$(nproc) - fi -fi - -echo "Will install headers and libs to $INSTALL_PREFIX for further project usage." -cmake --build . --target install -- "-j${MAX_JOBS}" -echo "Installation completed, now you can copy the headers/libs from $INSTALL_PREFIX to your project directory." diff --git a/scripts/build_pytorch_android.sh b/scripts/build_pytorch_android.sh deleted file mode 100755 index 7b80965e34b5c..0000000000000 --- a/scripts/build_pytorch_android.sh +++ /dev/null @@ -1,51 +0,0 @@ -#!/bin/bash -set -eux - -############################################################################## -# Master script to build PyTorch Android library with Java bindings. -############################################################################## -# Example usage: -# - Build default AARs: -# scripts/build_pytorch_android.sh -# -# - Build for specific ABI(s): -# scripts/build_pytorch_android.sh armeabi-v7a -# scripts/build_pytorch_android.sh arm64-v8a,x86,x86_64 -# -# Script's workflow: -# 1. Builds libtorch for android for specified android abisi (by default for all 4). -# Custom list of android abis can be specified as a bash argument as comma separated list. -# For example just for testing on android x86 emulator we need only x86 build. -# ./scripts/build_pytorch_android.sh x86 -# 2. Creates symbolic links to android/pytorch_android/src/main/jniLibs/${abi} for libtorch build output, -# android/pytorch_android/src/main/cpp/libtorch_include/${abi} for headers. -# 3. Runs pyotrch_android gradle build: -# gradle assembleRelease - -PYTORCH_DIR="$(cd $(dirname $0)/..; pwd -P)" -PYTORCH_ANDROID_DIR=$PYTORCH_DIR/android - -echo "PYTORCH_DIR:$PYTORCH_DIR" - -source "$PYTORCH_ANDROID_DIR/common.sh" - -check_android_sdk -check_gradle -parse_abis_list "$@" -build_android - -# To set proxy for gradle add following lines to ./gradle/gradle.properties: -# systemProp.http.proxyHost=... -# systemProp.http.proxyPort=8080 -# systemProp.https.proxyHost=... -# systemProp.https.proxyPort=8080 - -if [ "$CUSTOM_ABIS_LIST" = true ]; then - # Skipping clean task here as android gradle plugin 3.3.2 exteralNativeBuild has problems - # with it when abiFilters are specified. - $GRADLE_PATH -PABI_FILTERS=$ABIS_LIST -p $PYTORCH_ANDROID_DIR assembleRelease -else - $GRADLE_PATH -p $PYTORCH_ANDROID_DIR clean assembleRelease -fi - -find $PYTORCH_ANDROID_DIR -type f -name *aar | xargs ls -lah diff --git a/scripts/build_raspbian.sh b/scripts/build_raspbian.sh deleted file mode 100755 index b1fe85926219e..0000000000000 --- a/scripts/build_raspbian.sh +++ /dev/null @@ -1,44 +0,0 @@ -#!/bin/bash -############################################################################## -# Example command to build the Raspbian target. -############################################################################## -# -# This script shows how one can build a Caffe2 binary for raspbian. The build -# is essentially much similar to a host build, with one additional change -# which is to specify -mfpu=neon for optimized speed. - -CAFFE2_ROOT="$( cd "$(dirname -- "$0")"/.. ; pwd -P)" -echo "Caffe2 codebase root is: $CAFFE2_ROOT" -BUILD_ROOT=${BUILD_ROOT:-"$CAFFE2_ROOT/build"} -mkdir -p $BUILD_ROOT -echo "Build Caffe2 raspbian into: $BUILD_ROOT" - -# obtain dependencies. -echo "Installing dependencies." -sudo apt-get install \ - cmake \ - libgflags-dev \ - libgoogle-glog-dev \ - libprotobuf-dev \ - libpython-dev \ - python-pip \ - python-numpy \ - protobuf-compiler \ - python-protobuf -# python dependencies -sudo pip install hypothesis - -# Now, actually build the raspbian target. -echo "Building caffe2" -cd $BUILD_ROOT - -# Note: you can add more dependencies above if you need libraries such as -# leveldb, lmdb, etc. -cmake "$CAFFE2_ROOT" \ - -DCMAKE_VERBOSE_MAKEFILE=1 \ - -DCAFFE2_CPU_FLAGS="-mfpu=neon -mfloat-abi=hard" \ - || exit 1 - -# Note: while Raspberry pi has 4 cores, running too many builds in parallel may -# cause out of memory errors so we will simply run -j 2 only. -make -j 2 || exit 1 diff --git a/scripts/build_tegra_x1.sh b/scripts/build_tegra_x1.sh deleted file mode 100755 index 063e17dfe3514..0000000000000 --- a/scripts/build_tegra_x1.sh +++ /dev/null @@ -1,51 +0,0 @@ -#!/bin/bash -############################################################################## -# Example command to build Caffe2 on Tegra X1. -############################################################################## -# -# This script shows how one can build a Caffe2 binary for NVidia's TX1. -# The build script assumes that you have the most recent libraries installed -# via the JetPack toolkit available at -# https://developer.nvidia.com/embedded/jetpack -# and it assumes that we are starting from a fresh system after the jetpack -# installation. If you have already installed some of the dependencies, you -# may be able to skip quite a few of the apt-get installs. - -CAFFE2_ROOT="$( cd "$(dirname -- "$0")"/.. ; pwd -P)" -echo "Caffe2 codebase root is: $CAFFE2_ROOT" -BUILD_ROOT=${BUILD_ROOT:-"$CAFFE2_ROOT/build"} -mkdir -p $BUILD_ROOT -echo "Build Caffe2 raspbian into: $BUILD_ROOT" - -# obtain necessary dependencies -echo "Installing dependencies." -sudo apt-get install \ - cmake \ - libgflags-dev \ - libgoogle-glog-dev \ - libprotobuf-dev \ - protobuf-compiler - -# obtain optional dependencies that are usually useful to have. -echo "Installing optional dependencies." -sudo apt-get install \ - libpython-dev \ - python-numpy \ - python-pip \ - python-protobuf - -# Obtain python hypothesis, which Caffe2 uses for unit testing. Note that -# the one provided by apt-get is quite old so we install it via pip -sudo pip install hypothesis - -# Now, actually build the android target. -echo "Building caffe2" -cd $BUILD_ROOT - -# CUDA_USE_STATIC_CUDA_RUNTIME needs to be set to off so that opencv can be -# properly used. Otherwise, opencv will complain that opencv_dep_cudart cannot -# be found. -cmake "$CAFFE2_ROOT" -DCUDA_USE_STATIC_CUDA_RUNTIME=OFF \ - || exit 1 - -make -j 4 || exit 1 diff --git a/scripts/build_tizen.sh b/scripts/build_tizen.sh deleted file mode 100755 index 2262a2503c1d0..0000000000000 --- a/scripts/build_tizen.sh +++ /dev/null @@ -1,118 +0,0 @@ -#!/usr/bin/env bash -############################################################################## -# Example command to build the Tizen target (RPi3). -############################################################################## -# -# This script shows how one can build a Caffe2 binary for a Tizen device (RPi3). -# The build is essentially much similar to a host build, with one additional change -# which is to specify -mfpu=neon for optimized speed. - -setup_environment(){ -# The rootfs image for a Tizen target (RPi3)is located at the below webpage: -# https://cdn.download.tizen.org/archive/releases/milestone/tizen/4.0.m1/tizen-unified_20170529.1/images/ -# If you do not have a Tizen device, Please, run qemu-arm-static and chroot command. -# $ sudo chroot ~/tizen-rootfs qemu-arm-static /usr/bin/bash - -CAFFE2_ROOT="$( cd "$(dirname -- "$0")"/.. ; pwd -P)" -echo "Caffe2 codebase root is: $CAFFE2_ROOT" -BUILD_ROOT=${BUILD_ROOT:-"$CAFFE2_ROOT/build"} -mkdir -p $BUILD_ROOT -echo "Build Caffe2 Tizen into: $BUILD_ROOT" -} - -caffe2_lite_dep_packages(){ -# Obtain necessary dependencies -# You can set-up a rpm repository with zypper, yum, and dnf because Tizen -# software platform officially support rpm format such as Fedora, OpenSUSE. -# The official Tizen repository is as following: -# https://cdn.download.tizen.org/archive/releases/milestone/tizen/4.0.m1/ -echo "Installing dependencies." -sudo zypper install \ - make \ - strace \ - cmake \ - gcc* \ - binutils \ - glibc* \ - cpp \ - protobuf-devel \ - libstdc++* -} - -caffe2_lite_build(){ -# Now, actually build the android target. -echo "Building caffe2" -cd $BUILD_ROOT - -# Note: add more dependencies above if you need libraries such as leveldb, lmdb, etc. -# If you have to disable a specific package due to a package absence -# from https://git.tizen.org/cgit/, append -Dxxx_xxx=OFF option before executing cmake. -cmake .. \ - -DCMAKE_VERBOSE_MAKEFILE=1 \ - -DUSE_GFLAGS=OFF \ - -DUSE_GLOG=OFF -DUSE_NNPACK=OFF \ - -DRUN_HAVE_STD_REGEX=0 \ - -DRUN_HAVE_POSIX_REGEX=0 \ - -DHAVE_GNU_POSIX_REGEX=0 \ - -DUSE_MPI=OFF -DUSE_OPENMP=OFF \ - -DBUILD_PYTHON=OFF \ - -DUSE_GLOO=OFF \ - -DUSE_OPENCV=OFF \ - -DCAFFE2_CPU_FLAGS="-mfpu=neon -mfloat-abi=soft" \ - || exit 1 - -make -j`nproc` || exit 1 -} - -caffe2_full_dep_packages(){ -# Obtain necessary dependencies -# You can set-up a rpm repository with zypper, yum, and dnf because Tizen -# software platform officially support rpm format such as Fedora, OpenSUSE. -# The official Tizen repository is as following: -# https://cdn.download.tizen.org/archive/releases/milestone/tizen/4.0.m1/ -echo "Installing dependencies." -sudo zypper install \ - cmake \ - libgflags-dev \ - libgoogle-glog-dev \ - libprotobuf-dev \ - protobuf-compiler - -# Obtain optional dependencies that are usually useful to have. -echo "Installing optional dependencies." -sudo zypper install \ - libpython-dev \ - python-numpy \ - python-pip \ - python-protobuf - -# Obtain python hypothesis, which Caffe2 uses for unit testing. Note that -# the one provided by zypper is quite old so we install it via pip -sudo pip install hypothesis -} - -caffe2_full_build(){ -# Now, actually build the android target. -echo "Building caffe2" -cd $BUILD_ROOT - -# Note: add more dependencies above if you need libraries such as leveldb, lmdb, etc. -# If you have to disable a specific package due to a package absence -# from https://git.tizen.org/cgit/, append -Dxxx_xxx=OFF option before executing cmake. -cmake "$CAFFE2_ROOT" \ - -DCMAKE_VERBOSE_MAKEFILE=1 \ - -DUSE_CUDA=OFF \ - -DUSE_ITT=OFF \ - -DUSE_OPENCV=OFF \ - -DCAFFE2_CPU_FLAGS="-mfpu=neon -mfloat-abi=soft" \ - || exit 1 - -make -j`nproc` || exit 1 -} - -#### Main -# Setup a build environment to compile Caffe2 deeplearning framework in Tizen platform. -setup_environment -# There are two build options to support 'full' version and 'lite' version (by default). -caffe2_lite_dep_packages -caffe2_lite_build diff --git a/scripts/build_windows.bat b/scripts/build_windows.bat deleted file mode 100644 index 60bfebad08c01..0000000000000 --- a/scripts/build_windows.bat +++ /dev/null @@ -1,80 +0,0 @@ -:: ############################################################################# -:: Example command to build on Windows. -:: ############################################################################# - -:: This script shows how one can build a Caffe2 binary for windows. - -@echo off -setlocal - -SET ORIGINAL_DIR=%cd% -SET CAFFE2_ROOT=%~dp0%.. - -if NOT DEFINED BUILD_BINARY ( - set BUILD_BINARY=OFF -) - -if NOT DEFINED BUILD_SHARED_LIBS ( - :: On CI, we test with BUILD_SHARED_LIBS=OFF. - :: By default, it will be BUILD_SHARED_LIBS=ON. - if NOT DEFINED BUILD_ENVIRONMENT ( - set BUILD_SHARED_LIBS=OFF - ) -) - -if NOT DEFINED CAFFE2_STATIC_LINK_CUDA ( - set CAFFE2_STATIC_LINK_CUDA=OFF -) - -if NOT DEFINED CMAKE_BUILD_TYPE ( - set CMAKE_BUILD_TYPE=Release -) - -if NOT DEFINED ONNX_NAMESPACE ( - set ONNX_NAMESPACE=onnx_c2 -) - -if NOT DEFINED TORCH_CUDA_ARCH_LIST ( - set TORCH_CUDA_ARCH_LIST=5.0 -) - -if NOT DEFINED USE_CUDA ( - set USE_CUDA=OFF -) - -if NOT DEFINED USE_OBSERVERS ( - set USE_OBSERVERS=OFF -) - -if NOT DEFINED MSVC_Z7_OVERRIDE ( - set MSVC_Z7_OVERRIDE=OFF -) - -if NOT DEFINED CMAKE_GENERATOR ( - set CMAKE_GENERATOR=Ninja -) - -set CMAKE_VERBOSE_MAKEFILE=1 - -:: Install pyyaml for Aten codegen -pip install pyyaml ninja - -echo CAFFE2_ROOT=%CAFFE2_ROOT% -echo CMAKE_GENERATOR=%CMAKE_GENERATOR% -echo CMAKE_BUILD_TYPE=%CMAKE_BUILD_TYPE% - -:: Set up cmake. We will skip building the test files right now. -pushd %CAFFE2_ROOT% -python tools\build_libtorch.py || goto :label_error -popd - -echo "Caffe2 built successfully" -cd %ORIGINAL_DIR% -endlocal -exit /b 0 - -:label_error -echo "Caffe2 building failed" -cd %ORIGINAL_DIR% -endlocal -exit /b 1 diff --git a/scripts/diagnose_protobuf.py b/scripts/diagnose_protobuf.py deleted file mode 100644 index 65af4618228db..0000000000000 --- a/scripts/diagnose_protobuf.py +++ /dev/null @@ -1,92 +0,0 @@ -## @package diagnose_protobuf -# Module scripts.diagnose_protobuf -"""Diagnoses the current protobuf situation. - -Protocol buffer needs to be properly installed for Caffe2 to work, and -sometimes it is rather tricky. Specifically, we will need to have a -consistent version between C++ and python simultaneously. This is a -convenience script for one to quickly check if this is so on one's local -machine. - -Usage: - [set your environmental variables like PATH and PYTHONPATH] - python scripts/diagnose_protobuf.py -""" - -import os -import re -from subprocess import PIPE, Popen - - -# Get python protobuf version. -try: - import google.protobuf - - python_version = google.protobuf.__version__ - python_protobuf_installed = True -except ImportError: - print("DEBUG: cannot find python protobuf install.") - python_protobuf_installed = False - -if os.name == "nt": - protoc_name = "protoc.exe" -else: - protoc_name = "protoc" - -try: - p = Popen([protoc_name, "--version"], stdout=PIPE, stderr=PIPE) - out, err = p.communicate() -except: - print("DEBUG: did not find protoc binary.") - print("DEBUG: out: " + out) - print("DEBUG: err: " + err) - native_protobuf_installed = False -else: - if p.returncode: - print("DEBUG: protoc returned a non-zero return code.") - print("DEBUG: out: " + out) - print("DEBUG: err: " + err) - native_protobuf_installed = False - else: - tmp = re.search(r"\d\.\d\.\d", out) - if tmp: - native_version = tmp.group(0) - native_protobuf_installed = True - else: - print("DEBUG: cannot parse protoc version string.") - print("DEBUG: out: " + out) - native_protobuf_installed = False - -PYTHON_PROTOBUF_NOT_INSTALLED = """ -You have not installed python protobuf. Protobuf is needed to run caffe2. You -can install protobuf via pip or conda (if you are using anaconda python). -""" - -NATIVE_PROTOBUF_NOT_INSTALLED = """ -You have not installed the protoc binary. Protoc is needed to compile Caffe2 -protobuf source files. Depending on the platform you are on, you can install -protobuf via: - (1) Mac: using homebrew and do brew install protobuf. - (2) Linux: use apt and do apt-get install libprotobuf-dev - (3) Windows: install from source, or from the releases here: - https://github.com/google/protobuf/releases/ -""" - -VERSION_MISMATCH = f""" -Your python protobuf is of version {python_version} but your native protoc version is of -version {native_version}. This will cause the installation to produce incompatible -protobuf files. This is bad in general - consider installing the same version. -""" - -# Now, give actual recommendations -if not python_protobuf_installed: - print(PYTHON_PROTOBUF_NOT_INSTALLED) - -if not native_protobuf_installed: - print(NATIVE_PROTOBUF_NOT_INSTALLED) - -if python_protobuf_installed and native_protobuf_installed: - if python_version != native_version: - print(VERSION_MISMATCH) - else: - print("All looks good.") diff --git a/scripts/fbcode-dev-setup/ccache_setup.sh b/scripts/fbcode-dev-setup/ccache_setup.sh deleted file mode 100755 index cb461bee2dd27..0000000000000 --- a/scripts/fbcode-dev-setup/ccache_setup.sh +++ /dev/null @@ -1,92 +0,0 @@ -#!/bin/bash - -# This script installs CCache with CUDA support. -# Example usage: -# ./ccache_setup.sh --path /installed/folder - -set -e -shopt -s expand_aliases - -# Setup the proxy -alias with_proxy="HTTPS_PROXY=http://fwdproxy:8080 HTTP_PROXY=http://fwdproxy:8080 FTP_PROXY=http://fwdproxy:8080 https_proxy=http://fwdproxy:8080 http_proxy=http://fwdproxy:8080 ftp_proxy=http://fwdproxy:8080 http_no_proxy='*.facebook.com|*.tfbnw.net|*.fb.com'" - -# Parse options -path="$HOME/ccache" -force=false - -while [[ $# -gt 0 ]]; do - case "$1" in - --path) - shift - path="$1" - path=$(realpath "$path") - ;; - --force) # Force install - force=true - ;; - --help) - echo 'usage: ./ccache_setup.py --path /installed/folder [--force]' - exit 0 - ;; - *) - echo "Invalid option: $1" - exit 1 - ;; - esac - shift -done - -# Check whether you put nvcc in PATH -set +e -nvcc_path=$(which nvcc) -if [[ -z "$nvcc_path" ]]; then - nvcc_path="/usr/local/cuda/bin/nvcc" - export PATH="/usr/local/cuda/bin:$PATH" -fi -set -e -if [ ! -f "$nvcc_path" ] && ! $force; then - # shellcheck disable=SC2016 - echo 'nvcc is not detected in $PATH' - exit 1 -fi -echo "nvcc is detected at $nvcc_path" - -if [ -f "$CUDA_NVCC_EXECUTABLE" ] && [[ "$CUDA_NVCC_EXECUTABLE" == *"ccache"* ]]; then # Heuristic rule - if $CUDA_NVCC_EXECUTABLE --version; then - if ! $force; then - echo "CCache with nvcc support is already installed at $CUDA_NVCC_EXECUTABLE, please add --force" - exit 0 - fi - fi -fi - -# Installing CCache -echo "CCache will be installed at $path" -if [ -e "$path" ]; then - mv --backup=t -T "$path" "${path}.old" -fi - -with_proxy git clone https://github.com/colesbury/ccache.git "$path" -b ccbin -cd "$path" -./autogen.sh -./configure -make install prefix="$path" - -mkdir -p "$path/lib" -mkdir -p "$path/cuda" -ln -sf "$path/bin/ccache" "$path/lib/cc" -ln -sf "$path/bin/ccache" "$path/lib/c++" -ln -sf "$path/bin/ccache" "$path/lib/gcc" -ln -sf "$path/bin/ccache" "$path/lib/g++" -ln -sf "$path/bin/ccache" "$path/cuda/nvcc" -"$path/bin/ccache" -M 25Gi - -# Make sure the nvcc wrapped in CCache is runnable -"$path/cuda/nvcc" --version -echo 'Congrats! The CCache with nvcc support is installed!' -echo -e "Please add the following lines to your bash init script:\\n" -echo "################ Env Var for CCache with CUDA support ################" -# shellcheck disable=SC2016 -echo 'export PATH="'"$path"'/lib:$PATH"' -echo 'export CUDA_NVCC_EXECUTABLE="'"$path"'/cuda/nvcc"' -echo '######################################################################' diff --git a/scripts/get_python_cmake_flags.py b/scripts/get_python_cmake_flags.py deleted file mode 100644 index a49debcc884ad..0000000000000 --- a/scripts/get_python_cmake_flags.py +++ /dev/null @@ -1,24 +0,0 @@ -## @package get_python_cmake_flags -# Module scripts.get_python_cmake_flags -############################################################################## -# Use this script to find your preferred python installation. -############################################################################## -# -# You can use the following to build with your preferred version of python -# if your installation is not being properly detected by CMake. -# -# mkdir -p build && cd build -# cmake $(python ../scripts/get_python_cmake_flags.py) .. -# make -# - - -import sys -import sysconfig - - -flags = [ - f"-DPython_EXECUTABLE:FILEPATH={sys.executable}", -] - -print(" ".join(flags), end="") diff --git a/scripts/remove_apache_header.sh b/scripts/remove_apache_header.sh deleted file mode 100755 index 97980bfbb0ef6..0000000000000 --- a/scripts/remove_apache_header.sh +++ /dev/null @@ -1,13 +0,0 @@ -if [[ "$1" == *.py ]]; then - apache_header="apache_python.txt" -else - apache_header="apache_header.txt" -fi -apache_lines=$(wc -l < "${apache_header}") -apache_md5=$(cat "${apache_header}" | md5) -header_md5=$(head -n ${apache_lines} $1 | md5) -if [ "${header_md5}" == "${apache_md5}" ]; then - keep_lines=$(($(wc -l < $1) - ${apache_lines})) - tail -n ${keep_lines} $1 > _remove_apache_header.txt - mv _remove_apache_header.txt $1 -fi diff --git a/scripts/temp.sh b/scripts/temp.sh deleted file mode 100755 index 18eb2b4733816..0000000000000 --- a/scripts/temp.sh +++ /dev/null @@ -1,7 +0,0 @@ -find ../caffe2 -name "*.py" -exec ./remove_apache_header.sh {} \; -find ../caffe2 -name "*.h" -exec ./remove_apache_header.sh {} \; -find ../caffe2 -name "*.cc" -exec ./remove_apache_header.sh {} \; -find ../caffe2 -name "*.cpp" -exec ./remove_apache_header.sh {} \; -find ../caffe2 -name "*.cu" -exec ./remove_apache_header.sh {} \; -find ../caffe2 -name "*.mm" -exec ./remove_apache_header.sh {} \; -find ../caffe2 -name "*.m" -exec ./remove_apache_header.sh {} \; diff --git a/scripts/xcode_build.rb b/scripts/xcode_build.rb deleted file mode 100644 index 0734167bdda11..0000000000000 --- a/scripts/xcode_build.rb +++ /dev/null @@ -1,76 +0,0 @@ -require 'optparse' -require 'xcodeproj' - -options = {} -option_parser = OptionParser.new do |opts| - opts.banner = 'Tools for building PyTorch iOS framework on MacOS' - opts.on('-i', '--install_path ', 'path to the cmake install folder') { |value| - options[:install] = value - } - opts.on('-x', '--xcodeproj_path ', 'path to the XCode project file') { |value| - options[:xcodeproj] = value - } - opts.on('-p', '--platform ', 'platform for the current build, OS or SIMULATOR') { |value| - options[:platform] = value - } -end.parse! -puts options.inspect - -install_path = File.expand_path(options[:install]) -if not Dir.exist? (install_path) - raise "path don't exist:#{install_path}!" -end -xcodeproj_path = File.expand_path(options[:xcodeproj]) -if not File.exist? (xcodeproj_path) - raise "path don't exist:#{xcodeproj_path}!" -end - -project = Xcodeproj::Project.open(xcodeproj_path) -target = project.targets.first #TestApp -header_search_path = ['$(inherited)', "#{install_path}/include"] -libraries_search_path = ['$(inherited)', "#{install_path}/lib"] -other_linker_flags = ['$(inherited)', "-all_load"] - -target.build_configurations.each do |config| - config.build_settings['HEADER_SEARCH_PATHS'] = header_search_path - config.build_settings['LIBRARY_SEARCH_PATHS'] = libraries_search_path - config.build_settings['OTHER_LDFLAGS'] = other_linker_flags - config.build_settings['ENABLE_BITCODE'] = 'No' -end - -# link static libraries -target.frameworks_build_phases.clear -libs = ['libc10.a', 'libclog.a', 'libpthreadpool.a', 'libXNNPACK.a', 'libmicrokernels-prod.a', 'libeigen_blas.a', 'libcpuinfo.a', 'libpytorch_qnnpack.a', 'libtorch_cpu.a', 'libtorch.a', 'libkineto.a'] -for lib in libs do - path = "#{install_path}/lib/#{lib}" - if File.exist?(path) - libref = project.frameworks_group.new_file(path) - target.frameworks_build_phases.add_file_reference(libref) - end -end -# link system frameworks -frameworks = ['CoreML', 'Metal', 'MetalPerformanceShaders', 'Accelerate', 'UIKit'] -if frameworks - frameworks.each do |framework| - path = "System/Library/Frameworks/#{framework}.framework" - framework_ref = project.frameworks_group.new_reference(path) - framework_ref.name = "#{framework}.framework" - framework_ref.source_tree = 'SDKROOT' - target.frameworks_build_phases.add_file_reference(framework_ref) - end -end -project.save - -sdk = nil -arch = nil -if options[:platform] == 'SIMULATOR' - sdk = 'iphonesimulator' - arch = 'arm64' -elsif options[:platform] == 'OS' - sdk = 'iphoneos' - arch = 'arm64' -else - raise "unsupported platform #{options[:platform]}" -end - -exec "xcodebuild clean build -project #{xcodeproj_path} -alltargets -sdk #{sdk} -configuration Release -arch #{arch}" From 9df0f565972a8a034fd77d65aff2c53e6e9856d1 Mon Sep 17 00:00:00 2001 From: PaulZhang12 Date: Mon, 21 Jul 2025 06:26:57 -0700 Subject: [PATCH 422/457] Fix Triton GEMM templates with k=1 (#158650) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Thanks to @davidberard98 for much of the analysis here. For GEMMs of K=1, the hints, `tl.multiple_of` and `tl.max_contiguous` apply completely, as the indices to the loads are only dependent on `offs_m` and `offs_n`. For shapes like `(97x1), (1x97)`, this results in misaligned address errors, due to the fact that for all BLOCK_M and BLOCK_N sizes, the last tile is not a contiguous load. With K > 1 case, the hint is not as strict given the dependency on the k indices for the load as well. In the K=1 case, only `offs_m` and `offs_n` are used and broadcasted to the index shape. One can say these hints are "wrong", but in various cases in the hints being wrong, such as with the shape `9999x4, 4x9999`, there is a substantial performance improvement with the hint. For nice shapes with K=1, where M, N are a multiple 8 to where these hints are fine and there is no misaligned address, there is no performance regression observed on H100: Screenshot 2025-07-18 at 5 05 47 PM Pull Request resolved: https://github.com/pytorch/pytorch/pull/158650 Approved by: https://github.com/davidberard98 --- test/inductor/test_max_autotune.py | 19 +++++++++++++++++++ torch/_inductor/kernel/mm.py | 4 ++-- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py index 096e924a47826..43ed8eda83084 100644 --- a/test/inductor/test_max_autotune.py +++ b/test/inductor/test_max_autotune.py @@ -1551,6 +1551,25 @@ def f(a, b): if "benchmark_gpu" in counter: self.assertEqual(counters["inductor"][counter], 2) + @config.patch( + { + "max_autotune": True, + "max_autotune_gemm_backends": "TRITON", + } + ) + def test_mm_k_1(self): + def mm(x, y): + return x @ y + + for i in range(90, 100): + torch._dynamo.reset() + a = torch.randn((i, 1), device="cuda", dtype=torch.float32) + b = torch.randn((1, i), device="cuda", dtype=torch.float32) + compiled_f = torch.compile(mm) + + out, code = run_and_get_code(compiled_f, a, b) + torch.testing.assert_close(out, mm(a, b), atol=1e-2, rtol=1e-2) + class TestMaxAutotunePrecompile(TestCase): def test_precompilation_threads(self): diff --git a/torch/_inductor/kernel/mm.py b/torch/_inductor/kernel/mm.py index f1c77afd52fd2..951494d6c3d55 100644 --- a/torch/_inductor/kernel/mm.py +++ b/torch/_inductor/kernel/mm.py @@ -110,11 +110,11 @@ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - if ((stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1)) and M >= BLOCK_M: + if ((stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1)) and (M >= BLOCK_M and K > 1): offs_a_m = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) else: offs_a_m = rm % M - if ((stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1)) and N >= BLOCK_N: + if ((stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1)) and (N >= BLOCK_N and K > 1): offs_b_n = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) else: offs_b_n = rn % N From dec0d3101c4cb4165bcecd6971fc4ba8ce6dc6ab Mon Sep 17 00:00:00 2001 From: Pian Pawakapan Date: Wed, 23 Jul 2025 02:13:09 +0000 Subject: [PATCH 423/457] [export] fix unbacked range deserialization (#158681) Fixes https://github.com/pytorch/pytorch/issues/151809, by reading shape assertion nodes into ShapeEnv, and deferring instantiation of node example values, to be done node-by-node. Differential Revision: D78588406 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158681 Approved by: https://github.com/ydwu4, https://github.com/avikchaudhuri --- test/export/test_export.py | 12 ----- test/export/test_serialize.py | 54 ++++++++++++++++++++++ torch/_export/serde/serialize.py | 77 +++++++++++++++++++++++--------- 3 files changed, 110 insertions(+), 33 deletions(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index a28c560f8d8a1..497122e3cc755 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -932,7 +932,6 @@ def forward(self, x): ep = export(f, args, strict=False) self.assertEqual(ep.module()(*args), f(*args)) - @testing.expectedFailureCppSerDes # Cpp Ser/Der seems to fail parsing complicated guards def test_export_statically_known_true(self): class Foo(torch.nn.Module): def forward(self, x, y): @@ -1588,9 +1587,6 @@ def forward(self): ) self.assertEqual(m(*args), ep.module()(*args)) - @testing.expectedFailureCppSerDes # AssertionError: 0 not in VR[2, int_oo] - @testing.expectedFailureSerDer # AssertionError: 0 not in VR[2, int_oo] - @testing.expectedFailureSerDerNonStrict # AssertionError: 0 not in VR[2, int_oo] def test_cond_access_identical_symint_closure(self): class Example2(torch.nn.Module): def forward(self, x, trigger, target): @@ -5082,7 +5078,6 @@ def forward(self, x): # There should be nonzero view nodes in the graph self.assertTrue(view_count > 0) - @testing.expectedFailureCppSerDes # cpp Ser/Der not handling complicated symbols def test_solver_unsupported_sympy_function(self): # repro of https://github.com/pytorch/pytorch/issues/131897 @@ -8876,10 +8871,6 @@ def forward(self, x): inp = torch.randn(2) self.assertTrue(torch.allclose(ep.module()(inp), torch.nonzero(inp))) - # TODO(pianpwk) blocker: https://github.com/pytorch/pytorch/issues/151809 - @testing.expectedFailureSerDer - @testing.expectedFailureSerDerNonStrict - @testing.expectedFailureCppSerDes def test_redundant_asserts(self): class Foo(torch.nn.Module): def forward(self, x): @@ -13629,9 +13620,6 @@ def forward(self, x, y): ): ep.module()(torch.randn(10), torch.tensor(2)) - @testing.expectedFailureCppSerDes # TODO: When we deserialize we somehow hardcode sympy.lower to 2 - @testing.expectedFailureSerDerNonStrict - @testing.expectedFailureSerDer @torch.fx.experimental._config.patch(backed_size_oblivious=True) def test_baddbmm(self): class M(torch.nn.Module): diff --git a/test/export/test_serialize.py b/test/export/test_serialize.py index 226404737e26c..d174405dd8e06 100644 --- a/test/export/test_serialize.py +++ b/test/export/test_serialize.py @@ -1816,6 +1816,60 @@ def forward(self, x): self.assertTrue(node.meta["custom"]["quantization_tag"] == "foo") self.assertEqual(counter, 1) + def test_unbacked_range_serdes(self): + class Foo(torch.nn.Module): + def forward(self, x, y): + n = x.item() + torch._check_is_size(n, max=y.size(0) - 1) + return torch.empty(n), y[n] + + ep = torch.export.export( + Foo(), + (torch.tensor([5]), torch.randn(10)), + dynamic_shapes={ + "x": None, + "y": (Dim.DYNAMIC,), + }, + ) + buffer = io.BytesIO() + save(ep, buffer) + buffer.seek(0) + loaded_ep = load(buffer) + + # pre-serialize ep + pre_shape_env = torch._guards.detect_fake_mode( + [node.meta.get("val") for node in ep.graph.nodes] + ).shape_env + post_shape_env = torch._guards.detect_fake_mode( + [node.meta.get("val") for node in loaded_ep.graph.nodes] + ).shape_env + self.assertEqual(pre_shape_env.var_to_range, post_shape_env.var_to_range) + + def test_backed_size_oblivious_serdes(self): + class Foo(torch.nn.Module): + def forward(self, x, y, z): + return x + y + z.item() + + with torch.fx.experimental._config.patch(backed_size_oblivious=True): + ep = torch.export.export( + Foo(), + (torch.randn(1), torch.randn(1), torch.tensor([5])), + dynamic_shapes={ + "x": (Dim.DYNAMIC,), + "y": (Dim.DYNAMIC,), + "z": None, + }, + ) + buffer = io.BytesIO() + save(ep, buffer) + buffer.seek(0) + loaded_ep = load(buffer) + shape_env = torch._guards.detect_fake_mode( + [node.meta.get("val") for node in loaded_ep.graph.nodes] + ).shape_env + s0 = next(iter(ep.graph.nodes)).meta["val"].size(0) + self.assertEqual(shape_env.var_to_range[s0.node.expr].lower, 0) + if __name__ == "__main__": run_tests() diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index 710311d31f6e3..38ccbe287a870 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -224,6 +224,31 @@ class _SerializedProgram: example_inputs: bytes +class LazyMap(dict): + """ + Dictionary class for deferred instantiation of node metadata values. + Purpose is to avoid creation of symbolic-shape tensors before relevant shape guards are parsed. + """ + + def __init__(self): + self.map = {} + self.evaluated = set() + + def __setitem__(self, k, v): + self.map[k] = v + + def __getitem__(self, k): + out = self.map[k] + if k in self.evaluated: + return out + self.evaluated.add(k) + self.map[k] = out() + return self.map[k] + + def __repr__(self): + return self.map.__repr__() + + def deserialize_device(d: Device) -> torch.device: if d.index is None: return torch.device(type=d.type) # type: ignore[call-overload] @@ -1671,7 +1696,7 @@ class Result: def __init__(self) -> None: self.serialized_name_to_node: dict[str, torch.fx.Node] = {} - self.serialized_name_to_meta: dict[str, MetaType] = {} + self.serialized_name_to_meta: LazyMap = LazyMap() # str -> MetaType self.graph = torch.fx.Graph() self.module = torch.nn.Module() @@ -1687,7 +1712,7 @@ def save_graph_module(self) -> Iterator[None]: self.graph = torch.fx.Graph() self.module = torch.nn.Module() self.serialized_name_to_node = {} - self.serialized_name_to_meta = {} + self.serialized_name_to_meta = LazyMap() self.unbacked_symbols: set[sympy.Symbol] = set() try: yield @@ -1876,32 +1901,32 @@ def deserialize_graph(self, serialized_graph: Graph) -> torch.fx.Graph: # Handle the tensor metas. for name, tensor_value in serialized_graph.tensor_values.items(): log.debug("[deserialize_tensor_meta] %s (input): %s", name, tensor_value) - meta_val = self.deserialize_tensor_meta(tensor_value) - log.debug("[deserialize_tensor_meta] %s (output): %s", name, meta_val) - self.serialized_name_to_meta[name] = meta_val + self.serialized_name_to_meta[name] = ( + lambda v=tensor_value: self.deserialize_tensor_meta(v) + ) for name, sym_int_value in serialized_graph.sym_int_values.items(): log.debug("[deserialize_sym_int] %s (input): %s", name, sym_int_value) - int_val = self.deserialize_sym_int(sym_int_value) - log.debug("[deserialize_sym_int] %s (output): %s", name, int_val) - self.serialized_name_to_meta[name] = int_val + self.serialized_name_to_meta[name] = ( + lambda v=sym_int_value: self.deserialize_sym_int(v) + ) for name, sym_float_value in serialized_graph.sym_float_values.items(): log.debug("[deserialize_sym_float] %s (input): %s", name, sym_float_value) - float_val = self.deserialize_sym_float(sym_float_value) - log.debug("[deserialize_sym_float] %s (output): %s", name, float_val) - self.serialized_name_to_meta[name] = float_val + self.serialized_name_to_meta[name] = ( + lambda v=sym_float_value: self.deserialize_sym_float(v) + ) for name, sym_bool_value in serialized_graph.sym_bool_values.items(): log.debug("[deserialize_sym_bool] %s (input): %s", name, sym_bool_value) - bool_val = self.deserialize_sym_bool(sym_bool_value) - log.debug("[deserialize_sym_bool] %s (output): %s", name, bool_val) - self.serialized_name_to_meta[name] = bool_val + self.serialized_name_to_meta[name] = ( + lambda v=sym_bool_value: self.deserialize_sym_bool(v) + ) for name, script_obj_meta in serialized_graph.custom_obj_values.items(): log.debug("[deserialize_script_obj_meta] %s", script_obj_meta) - self.serialized_name_to_meta[name] = self.deserialize_script_obj_meta( - script_obj_meta + self.serialized_name_to_meta[name] = ( + lambda v=script_obj_meta: self.deserialize_script_obj_meta(v) ) log.debug("\n[deserialize graph nodes]") @@ -2080,13 +2105,25 @@ def _is_single_tensor_return(target) -> bool: fx_node.kwargs, fx_node.meta.get("val"), ) + + # handle ShapeEnv asserts + if target == torch.ops.aten._assert_scalar.default: + expr = fx_node.args[0].meta["val"] # type: ignore[union-attr] + if isinstance(expr, torch.SymBool): + self.shape_env.guard_or_defer_runtime_assert( + expr.node.expr, "", fx_node + ) + elif target == torch.ops.aten.sym_constrain_range_for_size.default: + sym = fx_node.args[0].meta["val"] # type: ignore[union-attr] + if isinstance(sym, torch.SymInt): + self.shape_env._constrain_range_for_size(sym.node.expr) + + # handle nn_module_stack; serialization throws away empty dicts if ( fx_node.op not in ["placeholder", "output"] and "nn_module_stack" not in fx_node.meta ): - fx_node.meta[ - "nn_module_stack" - ] = {} # serialization throws away empty dicts + fx_node.meta["nn_module_stack"] = {} def deserialize_input_spec(self, i: InputSpec) -> ep.InputSpec: log.debug("[deserialize_input_spec] %s", i) @@ -2263,8 +2300,6 @@ def deserialize( if symbol_name_to_range: for k, vr in symbol_name_to_range.items(): lower = vr.lower - if vr.upper >= 2: # max is >= 2, not sym bool range - lower = max(2, lower) self.symbol_name_to_range[k] = symbolic_shapes.ValueRanges( _int_to_sympy_int(lower, -int_oo), vr.upper ) From 2dccff7dcf56b0d168ebfd7ca08bdeca37273c56 Mon Sep 17 00:00:00 2001 From: Xu Han Date: Wed, 23 Jul 2025 02:24:35 +0000 Subject: [PATCH 424/457] [inductor] pass_fds not supported on Windows, skip them on Windows. (#158830) image Almost UTs are failed on `AssertionError: pass_fds not supported on Windows.`, let's skip them on Windows. TODO: I will also debug and confirm `pass_fds` on Windows. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158830 Approved by: https://github.com/jansel --- test/inductor/test_compile_subprocess.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/test/inductor/test_compile_subprocess.py b/test/inductor/test_compile_subprocess.py index 6eba88ecae970..04297c38bf299 100644 --- a/test/inductor/test_compile_subprocess.py +++ b/test/inductor/test_compile_subprocess.py @@ -18,7 +18,7 @@ from torch._inductor.compile_fx import _InProcessFxCompile, FxCompile, FxCompileMode from torch._inductor.graph import GraphLowering from torch._inductor.test_case import TestCase -from torch.testing._internal.common_utils import TEST_WITH_ASAN +from torch.testing._internal.common_utils import IS_CI, IS_WINDOWS, TEST_WITH_ASAN from torch.testing._internal.inductor_utils import ( GPU_TYPE, IS_BIG_GPU, @@ -29,6 +29,16 @@ ) +if IS_WINDOWS and IS_CI: + # TODO(xuhancn) : Debug and confirm pass_fds status on Windows. + sys.stderr.write( + "Almost UTs failed: pass_fds not supported on Windows, skip them on Windows.\n" + ) + if __name__ == "__main__": + sys.exit(0) + raise unittest.SkipTest("pass_fds not supported on Windows") + + # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) From f10e4430e272e40b5d7dbdc4bfa34e6d7a124aa5 Mon Sep 17 00:00:00 2001 From: Xu Han Date: Wed, 23 Jul 2025 02:58:21 +0000 Subject: [PATCH 425/457] [AOTI] normalize path and process model files. (#158705) Continued to https://github.com/pytorch/pytorch/pull/158702 , split `zip_filename_str` and real file path. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158705 Approved by: https://github.com/desertfire --- .../aoti_package/model_package_loader.cpp | 41 ++++++++++--------- 1 file changed, 22 insertions(+), 19 deletions(-) diff --git a/torch/csrc/inductor/aoti_package/model_package_loader.cpp b/torch/csrc/inductor/aoti_package/model_package_loader.cpp index ccaeaa9b775be..022e75e65b776 100644 --- a/torch/csrc/inductor/aoti_package/model_package_loader.cpp +++ b/torch/csrc/inductor/aoti_package/model_package_loader.cpp @@ -551,27 +551,31 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( std::string so_filename; std::string cpp_filename; std::vector obj_filenames; - std::string model_directory = file_prefix + "data" + k_separator + - "aotinductor" + k_separator + model_name; - std::string const_directory = - file_prefix + "data" + k_separator + "constants"; - - for (const std::string& filename_str : found_filenames) { + std::string model_directory = normalize_path_separator( + file_prefix + "data" + k_separator + "aotinductor" + k_separator + + model_name); + std::string const_directory = normalize_path_separator( + file_prefix + "data" + k_separator + "constants"); + + // zip_filename_str can't be normalize_path_separator, because it should be + // as index for mz_zip_reader_extract_file_to_file. + for (auto zip_filename_str : found_filenames) { + auto cur_filename = normalize_path_separator(zip_filename_str); // Only compile files in the specified model directory - if (c10::starts_with(filename_str, model_directory) || - c10::starts_with(filename_str, const_directory)) { + if (c10::starts_with(cur_filename, model_directory) || + c10::starts_with(cur_filename, const_directory)) { std::string output_path_str = temp_dir_; - if (c10::starts_with(filename_str, model_directory)) { + if (c10::starts_with(cur_filename, model_directory)) { output_path_str += k_separator; - output_path_str += filename_str; - } else { // startsWith(filename_str, const_directory) + output_path_str += cur_filename; + } else { // startsWith(zip_filename_str, const_directory) // Extract constants to the same directory as the rest of the files // to be consistent with internal implementation - size_t lastSlash = filename_str.find_last_of(k_separator); - std::string filename = filename_str; + size_t lastSlash = cur_filename.find_last_of(k_separator); + std::string filename = cur_filename; if (lastSlash != std::string::npos) { - filename = filename_str.substr(lastSlash + 1); + filename = cur_filename.substr(lastSlash + 1); } output_path_str.append(k_separator) .append(model_directory) @@ -579,10 +583,9 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( .append(filename); } - output_path_str = normalize_path_separator(output_path_str); - - LOG(INFO) << "Extract file: " << filename_str << " to " - << output_path_str; + std::string output_file_path = normalize_path_separator(output_path_str); + LOG(INFO) << "Extract file: " << zip_filename_str << " to " + << output_file_path; // Create the parent directory if it doesn't exist size_t parent_path_idx = output_path_str.find_last_of(k_separator); @@ -599,7 +602,7 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( } // Extracts file to the temp directory - zip_archive.extract_file(filename_str, output_path_str); + zip_archive.extract_file(zip_filename_str, output_path_str); // Save the file for bookkeeping size_t extension_idx = output_path_str.find_last_of('.'); From b87471e66fb989385483b074b5e5942e8fbbbd8d Mon Sep 17 00:00:00 2001 From: anwang Date: Tue, 22 Jul 2025 09:55:35 -0700 Subject: [PATCH 426/457] [MTIA Aten Backend] Migrate addcdiv.out / addcmul.out / eq.Tensor_out / eq.Scalar_out (#158748) # Context See the first PR https://github.com/pytorch/pytorch/pull/153670 # This diff Migrate addcdiv.out / addcmul.out / eq.Tensor_out / eq.Scalar_out to in-tree. Differential Revision: [D78568103](https://our.internmc.facebook.com/intern/diff/D78568103/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158748 Approved by: https://github.com/albanD, https://github.com/nautsimon --- aten/src/ATen/native/native_functions.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 4778aee27f423..0483b0606dde9 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -8968,7 +8968,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: eq_Scalar_out + CPU, CUDA, MTIA: eq_Scalar_out MPS: eq_scalar_out_mps QuantizedCPU: eq_out_quantized_cpu tags: pointwise @@ -8987,7 +8987,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: eq_Tensor_out + CPU, CUDA, MTIA: eq_Tensor_out MPS: eq_tensor_out_mps QuantizedCPU: eq_out_quantized_cpu tags: pointwise @@ -9380,7 +9380,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: addcmul_out + CPU, CUDA, MTIA: addcmul_out MPS: addcmul_out_mps tags: pointwise @@ -9401,7 +9401,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: addcdiv_out + CPU, CUDA, MTIA: addcdiv_out MPS: addcdiv_out_mps tags: pointwise From 42a69f7c2b11cc4c6a28424c6e0ea3ca8e9a0b5f Mon Sep 17 00:00:00 2001 From: anwang Date: Tue, 22 Jul 2025 09:55:37 -0700 Subject: [PATCH 427/457] [MTIA Aten Backend] Migrate addmm.out / baddbmm.out / bmm.out (#158749) # Context See the first PR https://github.com/pytorch/pytorch/pull/153670 # This diff Migrate addmm.out / baddbmm.out / bmm.out to in-tree. Differential Revision: [D78578483](https://our.internmc.facebook.com/intern/diff/D78578483/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158749 Approved by: https://github.com/albanD, https://github.com/nautsimon ghstack dependencies: #158748 --- aten/src/ATen/native/native_functions.yaml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 0483b0606dde9..f3f3e0d582e57 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -1067,6 +1067,7 @@ CUDA: baddbmm_out_cuda MPS: baddbmm_out_mps XPU: baddbmm_out_xpu + MTIA: baddbmm_out_mtia SparseCsrCUDA: baddbmm_out_sparse_csr_cuda - func: baddbmm.dtype(Tensor self, Tensor batch1, Tensor batch2, ScalarType out_dtype, *, Scalar beta=1, Scalar alpha=1) -> Tensor @@ -1376,6 +1377,7 @@ CUDA: bmm_out_cuda MPS: bmm_out_mps XPU: bmm_out_xpu + MTIA: bmm_out_mtia SparseCPU: bmm_out_sparse_cpu SparseCUDA: bmm_out_sparse_cuda SparseCsrCUDA: bmm_out_sparse_csr_cuda @@ -7065,6 +7067,7 @@ CUDA: addmm_out_cuda MPS: addmm_out_mps XPU: addmm_out_xpu + MTIA: addmm_out_mtia SparseCPU: addmm_out_sparse_dense_cpu SparseCUDA: addmm_out_sparse_dense_cuda SparseCsrCPU: addmm_out_sparse_compressed_cpu From f80f97d192253336940c67fd9bf6004ff8711088 Mon Sep 17 00:00:00 2001 From: PyTorch UpdateBot Date: Wed, 23 Jul 2025 04:39:47 +0000 Subject: [PATCH 428/457] [audio hash update] update the pinned audio hash (#158807) This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml). Update the pinned audio hash. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158807 Approved by: https://github.com/pytorchbot --- .github/ci_commit_pins/audio.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ci_commit_pins/audio.txt b/.github/ci_commit_pins/audio.txt index a2d5ddd38cec7..b49cbe79f9d71 100644 --- a/.github/ci_commit_pins/audio.txt +++ b/.github/ci_commit_pins/audio.txt @@ -1 +1 @@ -00b0c91db92c51a11356249262577b9fa26c18c5 +b6a3368a45aaafe05f1a6a9f10c68adc5e944d9e From be72bcf828b536e0d81359a37c0f150b69fce5d4 Mon Sep 17 00:00:00 2001 From: PyTorch UpdateBot Date: Wed, 23 Jul 2025 04:41:49 +0000 Subject: [PATCH 429/457] [vllm hash update] update the pinned vllm hash (#158806) This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml). Update the pinned vllm hash. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158806 Approved by: https://github.com/pytorchbot --- .github/ci_commit_pins/vllm.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ci_commit_pins/vllm.txt b/.github/ci_commit_pins/vllm.txt index 22adf465e471c..07270e9f557b1 100644 --- a/.github/ci_commit_pins/vllm.txt +++ b/.github/ci_commit_pins/vllm.txt @@ -1 +1 @@ -29d1ffc5b4c763ef76aff9e3f617fa60dd292418 +b77c7d327f2a463bb9ef8be36f30e920bc066502 From a6b7bea2448e03bd5c6e876f92de752c3a616646 Mon Sep 17 00:00:00 2001 From: Colin Peppler Date: Tue, 22 Jul 2025 13:34:32 -0700 Subject: [PATCH 430/457] [inductor] support linear & layer_norm unbacked (#155267) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What - Use `statically_known_true` over `guard_size_oblivious` in cases where we're checking an optimization path. Otherwise, it will DDE and we can't take the safe/slower path. - For broadcast checks, use `fallback=False` if we encounter a DDE. Typically, unbackeds would be ≥2 and that falls inline with size-oblivious reasoning (i.e. when `size_oblivious=True`). ### Example DDE ``` torch._inductor.exc.InductorError: LoweringException: GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq((u0//387), 1) (unhinted: Eq((u0//387), 1)). (Size-like symbols: u0) Caused by: (_inductor/lowering.py:488 in broadcast_symbolic_shapes) ``` ``` torch._inductor.exc.InductorError: LoweringException: GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq((u0//387), 1) (unhinted: Eq((u0//387), 1)). (Size-like symbols: u0) Caused by: (_inductor/ir.py:2797 in create) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/155267 Approved by: https://github.com/eellison --- test/inductor/test_unbacked_symints.py | 31 ++++++++++++++++++++++++++ torch/_decomp/decompositions.py | 4 ++-- torch/_inductor/ir.py | 4 ++-- torch/_inductor/lowering.py | 30 ++++++++----------------- 4 files changed, 44 insertions(+), 25 deletions(-) diff --git a/test/inductor/test_unbacked_symints.py b/test/inductor/test_unbacked_symints.py index 8e9df9e03c84f..cf132bea84a58 100644 --- a/test/inductor/test_unbacked_symints.py +++ b/test/inductor/test_unbacked_symints.py @@ -515,6 +515,37 @@ def fn(x): x = torch.tensor([1.0, 0.0, 1.0, 0.0], device=device) torch.compile(fn, fullgraph=True)(x) + @skipGPUIf(not HAS_GPU, "torch.compile for gpu requires triton") + @dynamo_config.patch({"capture_dynamic_output_shape_ops": True}) + def test_unbacked_linear_layer_norm_input(self, device): + class MyModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(387, 128, bias=True, device=device) + self.layer_norm1 = torch.nn.LayerNorm(387, device=device) + self.layer_norm2 = torch.nn.LayerNorm(128, device=device) + + def forward(self, x, mask): + masked_select = x.masked_select(mask) + view = masked_select.view(-1, 387) + + linear = self.linear(view) + layer_norm1 = self.layer_norm1(view) + layer_norm2 = self.layer_norm2(linear) + return linear, layer_norm1, layer_norm2 + + model = MyModel() + inputs = ( + torch.randn((256, 387), dtype=torch.float, device=device), + torch.randint( + low=0, high=2, size=(256, 1), dtype=torch.bool, device=device + ), + ) + + actual = torch.compile(model, fullgraph=True)(*inputs) + expected = model(*inputs) + torch.testing.assert_close(actual, expected) + instantiate_device_type_tests(TestUnbackedSymints, globals(), allow_xpu=True) diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index 832928ebf8aee..634c4b6b49545 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -1667,9 +1667,9 @@ def native_layer_norm_backward( N = prod(inner_dims) # type: ignore[arg-type] M = prod(outer_dims) # type: ignore[arg-type] - from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + from torch.fx.experimental.symbolic_shapes import statically_known_true - if guard_size_oblivious(M <= 0) or guard_size_oblivious(N <= 0): + if statically_known_true(M == 0) or statically_known_true(N == 0): return ( input.new_zeros(input_shape) if output_mask[0] else None, input.new_zeros(input_shape[axis:]) if output_mask[1] else None, diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 3ddfdc4be768c..a21b9c50938e4 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -2876,7 +2876,7 @@ def _normalize_size(x: IRNode, new_size: Sequence[_IntLike]) -> Sequence[_IntLik assert old_size[i] is not None new_size[i] = old_size[i] elif old_size[i] is None or V.graph.sizevars.shape_env.evaluate_expr( - sympy.Eq(old_size[i], 1), size_oblivious=True + sympy.Eq(old_size[i], 1), fallback_value=False ): pass else: @@ -2903,7 +2903,7 @@ def create(cls, x: IRNode, new_size: Sequence[_IntLike]) -> BaseView: new_stride.append( stride if not V.graph.sizevars.shape_env.evaluate_expr( - sympy.Eq(size, 1), size_oblivious=True + sympy.Eq(size, 1), fallback_value=False ) else sympy.S.Zero ) diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index c4c8f70003c60..503795bc513c1 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -489,11 +489,11 @@ def broadcast_symbolic_shapes(a, b): output = [] for x, y in itertools.zip_longest(reversed(a), reversed(b), fillvalue=sympy.S.One): if V.graph.sizevars.shape_env.evaluate_expr( - sympy.Eq(y, 1), size_oblivious=True + sympy.Eq(y, 1), fallback_value=False ): output.append(x) elif V.graph.sizevars.shape_env.evaluate_expr( - sympy.Eq(x, 1), size_oblivious=True + sympy.Eq(x, 1), fallback_value=False ): output.append(y) else: @@ -939,26 +939,14 @@ def broadcast_tensors(*inputs): outputs = [] for x in inputs: sizes = x.get_size() - if len(sizes) != len(target) or any( - ( - ( - V.graph.sizevars.shape_env.evaluate_expr( - sympy.Eq(a, 1), size_oblivious=True - ) - and not V.graph.sizevars.shape_env.evaluate_expr( - sympy.Eq(b, 1), size_oblivious=True - ) - ) - or ( - not V.graph.sizevars.shape_env.evaluate_expr( - sympy.Eq(a, 1), size_oblivious=True - ) - and V.graph.sizevars.shape_env.evaluate_expr( - sympy.Eq(b, 1), size_oblivious=True - ) - ) + + def is_length_one(size: sympy.Expr): + return V.graph.sizevars.shape_env.evaluate_expr( + sympy.Eq(size, 1), fallback_value=False ) - for a, b in zip(sizes, target) + + if len(sizes) != len(target) or any( + is_length_one(a) != is_length_one(b) for a, b in zip(sizes, target) ): x = expand(x, target) outputs.append(x) From 1d302eaee80e15d6d011749f70b3f18c2218ae84 Mon Sep 17 00:00:00 2001 From: Yang Wang Date: Wed, 23 Jul 2025 05:42:40 +0000 Subject: [PATCH 431/457] [vllm] add vllm test base docker image (#158755) # description Add base docker image for vllm. It seems like we use the base docker image for both pytorch build, and tests. Configure a base image for vllm against pytorch CI. # Others Added readme regarding how the base docker images are used, and how to add one, this also explain what is the right file to modify Pull Request resolved: https://github.com/pytorch/pytorch/pull/158755 Approved by: https://github.com/seemethere, https://github.com/huydhn --- .ci/docker/README.md | 102 ++++++++++++++++++++++++++++ .ci/docker/build.sh | 13 +++- .github/workflows/docker-builds.yml | 1 + 3 files changed, 115 insertions(+), 1 deletion(-) diff --git a/.ci/docker/README.md b/.ci/docker/README.md index 15779155933e1..0fd4ed7ca502c 100644 --- a/.ci/docker/README.md +++ b/.ci/docker/README.md @@ -36,3 +36,105 @@ See `build.sh` for valid build environments (it's the giant switch). # Set flags (see build.sh) and build image sudo bash -c 'TRITON=1 ./build.sh pytorch-linux-bionic-py3.8-gcc9 -t myimage:latest ``` + +## [Guidance] Adding a New Base Docker Image + +### Background + +The base Docker images in directory `.ci/docker/` are built by the `docker-builds.yml` workflow. Those images are used throughout the PyTorch CI/CD pipeline. You should only create or modify a base Docker image if you need specific environment changes or dependencies before building PyTorch on CI. + +1. **Automatic Rebuilding**: + - The Docker image building process is triggered automatically when changes are made to files in the `.ci/docker/*` directory + - This ensures all images stay up-to-date with the latest dependencies and configurations + +2. **Image Reuse in PyTorch Build Workflows** (example: linux-build): + - The images generated by `docker-builds.yml` are reused in `_linux-build.yml` through the `calculate-docker-image` step + - The `_linux-build.yml` workflow: + - Pulls the Docker image determined by the `calculate-docker-image` step + - Runs a Docker container with that image + - Executes `.ci/pytorch/build.sh` inside the container to build PyTorch + +3. **Usage in Test Workflows** (example: linux-test): + - The same Docker images are also used in `_linux-test.yml` for running tests + - The `_linux-test.yml` workflow follows a similar pattern: + - It uses the `calculate-docker-image` step to determine which Docker image to use + - It pulls the Docker image and runs a container with that image + - It installs the wheels from the artifacts generated by PyTorch build jobs + - It executes test scripts (like `.ci/pytorch/test.sh` or `.ci/pytorch/multigpu-test.sh`) inside the container + +### Understanding File Purposes + +#### `.ci/docker/build.sh` vs `.ci/pytorch/build.sh` +- **`.ci/docker/build.sh`**: + - Used for building base Docker images + - Executed by the `docker-builds.yml` workflow to pre-build Docker images for CI + - Contains configurations for different Docker build environments + +- **`.ci/pytorch/build.sh`**: + - Used for building PyTorch inside a Docker container + - Called by workflows like `_linux-build.yml` after the Docker container is started + - Builds PyTorch wheels and other artifacts + +#### `.ci/docker/ci_commit_pins/` vs `.github/ci_commit_pins` +- **`.ci/docker/ci_commit_pins/`**: + - Used for pinning dependency versions during base Docker image building + - Ensures consistent environments for building PyTorch + - Changes here trigger base Docker image rebuilds + +- **`.github/ci_commit_pins`**: + - Used for pinning dependency versions during PyTorch building and tests + - Ensures consistent dependencies for PyTorch across different builds + - Used by build scripts running inside Docker containers + +### Step-by-Step Guide for Adding a New Base Docker Image + +#### 1. Add Pinned Commits (If Applicable) + +We use pinned commits for build stability. The `nightly.yml` workflow checks and updates pinned commits for certain repository dependencies daily. + +If your new Docker image needs a library installed from a specific pinned commit or built from source: + +1. Add the repository you want to track in `nightly.yml` and `merge-rules.yml` +2. Add the initial pinned commit in `.ci/docker/ci_commit_pins/`. The text filename should match the one defined in step 1 + +#### 2. Configure the Base Docker Image +1. **Add new Base Docker image configuration** (if applicable): + + Add the configuration in `.ci/docker/build.sh`. For example: + ```bash + pytorch-linux-jammy-cuda12.8-cudnn9-py3.12-gcc11-new1) + CUDA_VERSION=12.8.1 + CUDNN_VERSION=9 + ANACONDA_PYTHON_VERSION=3.12 + GCC_VERSION=11 + VISION=yes + KATEX=yes + UCX_COMMIT=${_UCX_COMMIT} + UCC_COMMIT=${_UCC_COMMIT} + TRITON=yes + NEW_ARG_1=yes + ;; + ``` + +2. **Add build arguments to Docker build command**: + + If you're introducing a new argument to the Docker build, make sure to add it in the Docker build step in `.ci/docker/build.sh`: + ```bash + docker build \ + .... + --build-arg "NEW_ARG_1=${NEW_ARG_1}" + ``` + +3. **Update Dockerfile logic**: + + Update the Dockerfile to use the new argument. For example, in `ubuntu/Dockerfile`: + ```dockerfile + ARG NEW_ARG_1 + # Set up environment for NEW_ARG_1 + RUN if [ -n "${NEW_ARG_1}" ]; then bash ./do_something.sh; fi + ``` + +4. **Add the Docker configuration** in `.github/workflows/docker-builds.yml`: + + The `docker-builds.yml` workflow pre-builds the Docker images whenever changes occur in the `.ci/docker/` directory. This includes the + pinned commit updates. diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh index d6cba6659db7a..cf022d099326b 100755 --- a/.ci/docker/build.sh +++ b/.ci/docker/build.sh @@ -160,6 +160,17 @@ case "$tag" in UCC_COMMIT=${_UCC_COMMIT} TRITON=yes ;; + pytorch-linux-jammy-cuda12.8-cudnn9-py3.12-gcc11-vllm) + CUDA_VERSION=12.8.1 + CUDNN_VERSION=9 + ANACONDA_PYTHON_VERSION=3.12 + GCC_VERSION=11 + VISION=yes + KATEX=yes + UCX_COMMIT=${_UCX_COMMIT} + UCC_COMMIT=${_UCC_COMMIT} + TRITON=yes + ;; pytorch-linux-jammy-cuda12.6-cudnn9-py3-gcc9-inductor-benchmarks) CUDA_VERSION=12.6 CUDNN_VERSION=9 @@ -276,7 +287,7 @@ case "$tag" in NINJA_VERSION=1.9.0 TRITON=yes ;; - pytorch-linux-jammy-py3.9-gcc11-inductor-benchmarks) + pytorch-linux-jammy-py3.9-gcc11-inductor-benchmarks) ANACONDA_PYTHON_VERSION=3.9 GCC_VERSION=11 VISION=yes diff --git a/.github/workflows/docker-builds.yml b/.github/workflows/docker-builds.yml index 4678779443b98..255e36ebfffa7 100644 --- a/.github/workflows/docker-builds.yml +++ b/.github/workflows/docker-builds.yml @@ -50,6 +50,7 @@ jobs: runner: [linux.12xlarge] docker-image-name: [ pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11, + pytorch-linux-jammy-cuda12.8-cudnn9-py3.12-gcc11-vllm, pytorch-linux-jammy-cuda12.6-cudnn9-py3-gcc9-inductor-benchmarks, pytorch-linux-jammy-cuda12.6-cudnn9-py3.12-gcc9-inductor-benchmarks, pytorch-linux-jammy-cuda12.6-cudnn9-py3.13-gcc9-inductor-benchmarks, From 255a04baf11f2a999386632271f13ae4e7d3836d Mon Sep 17 00:00:00 2001 From: Ruben Rodriguez Buchillon Date: Wed, 23 Jul 2025 06:44:27 +0000 Subject: [PATCH 432/457] [pt2 event logging] send autotuning data for strides and hinted shapes (#158852) Summary: # Why capture relevant data for offline lookup table generation # What report the hinted sizes not just the symbolic sizes Test Plan: ``` buck2 run mode/opt scripts/coconutruben/torchmm:experiment 2>&1 | tee /tmp/epx040 ``` This only validates that this change does not break anything, as the schema is not on scuba yet (not actualized) Rollback Plan: Reviewed By: stashuk-olek Differential Revision: D77837548 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158852 Approved by: https://github.com/jingsh --- torch/_inductor/select_algorithm.py | 54 +++++++++++++++++++++-------- 1 file changed, 40 insertions(+), 14 deletions(-) diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index c316f0d4bc7ef..903d616bb91eb 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -2337,20 +2337,7 @@ def autotune(choices, hint_override: Optional[int] = None): f"{name}_template_autotuning", log_pt2_compile_event=True, dynamo_compile_column_us="compile_time_autotune_time_us", - metadata={ - "autotune_strides": ", ".join( - [str(n.get_stride()) for n in input_nodes] - ), - "autotune_dtypes": ", ".join( - [str(n.get_dtype()) for n in input_nodes] - ), - "autotune_shape": ", ".join( - ["x".join(map(str, n.get_size())) for n in input_nodes] - ), - "autotune_offset": ", ".join( - [str(n.get_layout().offset) for n in input_nodes] - ), - }, + metadata=_autotune_metadata(input_nodes), ): return benchmark(choices, hint_override=hint_override) @@ -3370,5 +3357,44 @@ def sympy_call(self, *args, **kwargs): return self.fn(*args, **kwargs, **self.kwargs_sym) +def _autotune_metadata(input_nodes): + """Helper function to extract autotune metadata from input nodes.""" + return { + "autotune_strides": ", ".join([str(n.get_stride()) for n in input_nodes]), + "autotune_dtypes": ", ".join([str(n.get_dtype()) for n in input_nodes]), + "autotune_shape": ", ".join( + ["x".join(map(str, n.get_size())) for n in input_nodes] + ), + "autotune_offset": ", ".join([str(n.get_layout().offset) for n in input_nodes]), + # TODO(coconutruben): replace this with taking KernelInputs as the + # argument, and extracting those out there directly + "autotune_strides_hinted": ", ".join( + [ + str( + V.graph.sizevars.size_hints( + n.get_stride(), + fallback=config.unbacked_symint_fallback, + ) + ) + for n in input_nodes + ] + ), + "autotune_shape_hinted": ", ".join( + [ + "x".join( + map( + str, + V.graph.sizevars.size_hints( + n.get_size(), + fallback=config.unbacked_symint_fallback, + ), + ) + ) + for n in input_nodes + ] + ), + } + + # ensure lowering is imported so that `extern_kernels.*` is populated from . import lowering # noqa: F401 From c665594c1edca9a507b0ec8b18ab74a0ecb65bc3 Mon Sep 17 00:00:00 2001 From: "Han, Xu" Date: Wed, 23 Jul 2025 08:00:14 +0000 Subject: [PATCH 433/457] [AOTI] fix extract file failed on Windows. (#158702) Changes: 1. rename zip index filename, and keep it out of normalize path. 2. normalize output path for extract file. Extract files successful: image Pull Request resolved: https://github.com/pytorch/pytorch/pull/158702 Approved by: https://github.com/angelayi --- .../aoti_package/model_package_loader.cpp | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/torch/csrc/inductor/aoti_package/model_package_loader.cpp b/torch/csrc/inductor/aoti_package/model_package_loader.cpp index 022e75e65b776..b835b1a00821e 100644 --- a/torch/csrc/inductor/aoti_package/model_package_loader.cpp +++ b/torch/csrc/inductor/aoti_package/model_package_loader.cpp @@ -559,7 +559,7 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( // zip_filename_str can't be normalize_path_separator, because it should be // as index for mz_zip_reader_extract_file_to_file. - for (auto zip_filename_str : found_filenames) { + for (auto const& zip_filename_str : found_filenames) { auto cur_filename = normalize_path_separator(zip_filename_str); // Only compile files in the specified model directory if (c10::starts_with(cur_filename, model_directory) || @@ -588,12 +588,12 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( << output_file_path; // Create the parent directory if it doesn't exist - size_t parent_path_idx = output_path_str.find_last_of(k_separator); + size_t parent_path_idx = output_file_path.find_last_of(k_separator); if (parent_path_idx == std::string::npos) { throw std::runtime_error( - "Failed to find parent path in " + output_path_str); + "Failed to find parent path in " + output_file_path); } - std::string parent_path = output_path_str.substr(0, parent_path_idx); + std::string parent_path = output_file_path.substr(0, parent_path_idx); if (!recursive_mkdir(parent_path)) { throw std::runtime_error(fmt::format( "Failed to create directory {}: {}", @@ -605,15 +605,15 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( zip_archive.extract_file(zip_filename_str, output_path_str); // Save the file for bookkeeping - size_t extension_idx = output_path_str.find_last_of('.'); + size_t extension_idx = output_file_path.find_last_of('.'); if (extension_idx != std::string::npos) { - std::string filename_extension = output_path_str.substr(extension_idx); + std::string filename_extension = output_file_path.substr(extension_idx); if (filename_extension == ".cpp") { - cpp_filename = output_path_str; + cpp_filename = output_file_path; } else if (filename_extension == object_file_ext()) { - obj_filenames.push_back(output_path_str); + obj_filenames.push_back(output_file_path); } else if (filename_extension == extension_file_ext()) { - so_filename = output_path_str; + so_filename = output_file_path; } } } From ee72338f0ca91df825306cb9f780b0274c07e9ae Mon Sep 17 00:00:00 2001 From: yuchengliu1 Date: Wed, 23 Jul 2025 13:19:11 +0000 Subject: [PATCH 434/457] [Inductor] MSVC use pointer when generating temporary array pointer (#158913) MSVC cannot implicitly convert a const iterator to a const pointer. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158913 Approved by: https://github.com/desertfire Co-authored-by: Xu Han --- torch/_inductor/codegen/cpp_wrapper_cpu.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 56d6f40dade81..fa880d35366ce 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -21,7 +21,7 @@ from torch.utils._ordered_set import OrderedSet from torch.utils._sympy.symbol import symbol_is_type, SymT -from .. import config, ir +from .. import config, cpp_builder, ir from ..utils import ( _align, aoti_model_name_from_config, @@ -119,7 +119,12 @@ def _generate_temporary_array_pointer( # e.g. const double** is possible, but not const double* const*. This means # that an array containing pointers must _already_ be properly const-qualified # by the c_type, and not add additional const-ness. - ptr_call = "data()" if force_mutable or c_type.endswith("*") else "cbegin()" + # MSVC does not support implicitly converting a const iterator to a const pointer. + ptr_call = ( + "data()" + if force_mutable or c_type.endswith("*") or cpp_builder.is_msvc_cl() + else "cbegin()" + ) return ( f"std::array<{c_type}, {len(elements)}>{{{', '.join(elements)}}}.{ptr_call}" ) From 57024913c409764f129d6a7792625f5b05462e31 Mon Sep 17 00:00:00 2001 From: Alexander Grund Date: Wed, 23 Jul 2025 13:31:17 +0000 Subject: [PATCH 435/457] Fix decorators skipping NCCL tests (#158846) Avoid failures caused by tests exiting via sys.exit instead of `unittest.skip` In particular it will not try to start the test (causing forks into subprocess) just to stop them (killing the subprocess) which is done in the test setup Using `unittest.skip` decorators avoids the starting of the test in the first place. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158846 Approved by: https://github.com/Skylion007 --- test/distributed/test_functional_api.py | 26 ++---- torch/testing/_internal/common_distributed.py | 80 ++++--------------- .../_shard/sharded_tensor/__init__.py | 6 +- .../distributed/_tensor/common_dtensor.py | 7 +- .../_internal/distributed/distributed_test.py | 18 ++--- 5 files changed, 37 insertions(+), 100 deletions(-) diff --git a/test/distributed/test_functional_api.py b/test/distributed/test_functional_api.py index 3b93e4d2b19ad..61f52b2dc60ab 100644 --- a/test/distributed/test_functional_api.py +++ b/test/distributed/test_functional_api.py @@ -13,6 +13,7 @@ from torch._inductor.utils import run_and_get_code from torch.testing import FileCheck from torch.testing._internal.common_device_type import instantiate_device_type_tests +from torch.testing._internal.common_distributed import exit_if_lt_x_gpu from torch.testing._internal.distributed.fake_pg import FakeStore from torch.testing._internal.inductor_utils import HAS_GPU @@ -25,7 +26,7 @@ DistributedTestBase, MultiThreadedTestCase, requires_nccl, - TEST_SKIPS, + skip_if_no_gpu, ) from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, @@ -476,26 +477,14 @@ def allred_mesh_dim(input): BACKEND = dist.Backend.HCCL -# allows you to check for multiple accelerator irrespective of device type -# to add new device types to this check simply follow the same format -# and append an elif with the conditional and appropriate device count function for your new device -def exit_if_lt_x_accelerators(x): - if TEST_CUDA: - if torch.cuda.device_count() < x: - sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code) - elif TEST_HPU: - if torch.hpu.device_count() < x: - sys.exit(TEST_SKIPS[f"multi-hpu-{x}"].exit_code) - - def with_comms(func=None): if func is None: return partial(with_comms) @wraps(func) def wrapper(self, *args, **kwargs): - if BACKEND == dist.Backend.NCCL and torch.cuda.device_count() < self.world_size: - sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code) + if BACKEND == dist.Backend.NCCL: + exit_if_lt_x_gpu(self.world_size) kwargs["device"] = DEVICE self.pg = self.create_pg(device=DEVICE) @@ -508,9 +497,9 @@ def wrapper(self, *args, **kwargs): class TestCollectivesWithDistributedBackend(DistributedTestBase): + @skip_if_no_gpu @with_comms() def test_all_gather_into_tensor_coalesced(self, device): - exit_if_lt_x_accelerators(self.world_size) tensors = [ torch.ones([4], device=device), torch.ones([4], device=device) + 1, @@ -582,9 +571,8 @@ def allreduce(t, pg): compiled_allreduce(torch.randn(8, device=device), self.pg) @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @skip_if_no_gpu def test_tracing_with_fakepg(self, device=DEVICE): - exit_if_lt_x_accelerators(self.world_size) - def allreduce(t, pg): return ft_c.all_reduce(t, "sum", pg) @@ -626,9 +614,9 @@ class TestDistributedBackendCollectivesWithWorldSize4( def world_size(self): return 4 + @skip_if_no_gpu @with_comms() def test_permute_tensor_with_sub_group(self, device): - exit_if_lt_x_accelerators(self.world_size) mesh_dim_names = ["dp", "tp"] mesh_2d = dt.init_device_mesh( diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index 9b311411e34a2..c13d7f3c42c14 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -118,14 +118,17 @@ def requires_ddp_rank(device): return device in DDP_RANK_DEVICES +def exit_if_lt_x_gpu(x): + if torch.cuda.device_count() < x: + sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code) + + def skip_if_no_gpu(func): """Skips if the world size exceeds the number of GPUs, ensuring that if the test is run, each rank has its own GPU via ``torch.cuda.device(rank)``.""" @wraps(func) def wrapper(*args, **kwargs): - if not (TEST_CUDA or TEST_HPU or TEST_XPU): - sys.exit(TEST_SKIPS["no_cuda"].exit_code) world_size = int(os.environ["WORLD_SIZE"]) if TEST_CUDA and torch.cuda.device_count() < world_size: sys.exit(TEST_SKIPS[f"multi-gpu-{world_size}"].exit_code) @@ -136,7 +139,9 @@ def wrapper(*args, **kwargs): return func(*args, **kwargs) - return wrapper + return unittest.skipUnless( + TEST_CUDA or TEST_HPU or TEST_XPU, TEST_SKIPS["no_cuda"].message + )(wrapper) # TODO (kwen2501): what is the purpose of this decorator? Tests with this @@ -168,33 +173,16 @@ def wrapper(*args, **kwargs): def require_n_gpus_for_nccl_backend(n, backend): - def decorator(func): - @wraps(func) - def wrapper(*args, **kwargs): - if backend == "nccl" and torch.cuda.device_count() < n: - sys.exit(TEST_SKIPS[f"multi-gpu-{n}"].exit_code) - else: - return func(*args, **kwargs) - - return wrapper - - return decorator + return skip_if_lt_x_gpu(n) if backend == "nccl" else unittest.skipIf(False, None) def import_transformers_or_skip(): - def decorator(func): - @wraps(func) - def wrapper(*args, **kwargs): - try: - from transformers import AutoModelForMaskedLM, BertConfig # noqa: F401 - - return func(*args, **kwargs) - except ImportError: - sys.exit(TEST_SKIPS["importerror"].exit_code) + try: + from transformers import AutoModelForMaskedLM, BertConfig # noqa: F401 - return wrapper - - return decorator + return unittest.skipIf(False) + except ImportError: + return unittest.skip(TEST_SKIPS["importerror"].message) def at_least_x_gpu(x): @@ -208,36 +196,7 @@ def at_least_x_gpu(x): def skip_if_lt_x_gpu(x): - def decorator(func): - @wraps(func) - def wrapper(*args, **kwargs): - if torch.cuda.is_available() and torch.cuda.device_count() >= x: - return func(*args, **kwargs) - if TEST_HPU and torch.hpu.device_count() >= x: - return func(*args, **kwargs) - if TEST_XPU and torch.xpu.device_count() >= x: - return func(*args, **kwargs) - sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code) - - return wrapper - - return decorator - - -# This decorator helps avoiding initializing cuda while testing other backends -def nccl_skip_if_lt_x_gpu(backend, x): - def decorator(func): - @wraps(func) - def wrapper(*args, **kwargs): - if backend != "nccl": - return func(*args, **kwargs) - if torch.cuda.is_available() and torch.cuda.device_count() >= x: - return func(*args, **kwargs) - sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code) - - return wrapper - - return decorator + return unittest.skipUnless(at_least_x_gpu(x), TEST_SKIPS[f"multi-gpu-{x}"].message) def verify_ddp_error_logged(model_DDP, err_substr): @@ -413,14 +372,7 @@ def requires_multicast_support(): def skip_if_rocm_multiprocess(func): """Skips a test for ROCm""" func.skip_if_rocm_multiprocess = True - - @wraps(func) - def wrapper(*args, **kwargs): - if not TEST_WITH_ROCM: - return func(*args, **kwargs) - sys.exit(TEST_SKIPS["skipIfRocm"].exit_code) - - return wrapper + return unittest.skipUnless(TEST_WITH_ROCM, TEST_SKIPS["skipIfRocm"].message)(func) def skip_if_win32(): diff --git a/torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py b/torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py index 60c744ac1a84c..8b52acdbeeb04 100644 --- a/torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py +++ b/torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py @@ -7,8 +7,8 @@ import torch.distributed as dist from torch.distributed import rpc from torch.testing._internal.common_distributed import ( + exit_if_lt_x_gpu, MultiProcessTestCase, - TEST_SKIPS, tp_transports, ) @@ -94,8 +94,8 @@ def with_comms(func=None, init_rpc=True, backend="nccl"): @wraps(func) def wrapper(self, *args, **kwargs): - if backend == "nccl" and torch.cuda.device_count() < self.world_size: - sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code) + if backend == "nccl": + exit_if_lt_x_gpu(self.world_size) self.init_comms(init_rpc=init_rpc, backend=backend) func(self, *args, **kwargs) self.destroy_comms(destroy_rpc=init_rpc) diff --git a/torch/testing/_internal/distributed/_tensor/common_dtensor.py b/torch/testing/_internal/distributed/_tensor/common_dtensor.py index 94bfead8a0c03..f84d326ae3bf6 100644 --- a/torch/testing/_internal/distributed/_tensor/common_dtensor.py +++ b/torch/testing/_internal/distributed/_tensor/common_dtensor.py @@ -3,7 +3,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates import itertools -import sys from collections.abc import Iterator, Sequence from dataclasses import dataclass from functools import partial, wraps @@ -31,11 +30,11 @@ SequenceParallel, ) from torch.testing._internal.common_distributed import ( + exit_if_lt_x_gpu, MultiProcessTestCase, MultiThreadedTestCase, run_subtests, skip_if_lt_x_gpu, - TEST_SKIPS, ) from torch.testing._internal.common_utils import TEST_CUDA, TEST_HPU, TEST_XPU from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec @@ -356,8 +355,8 @@ def build_device_mesh(self) -> DeviceMesh: return init_device_mesh(self.device_type, (self.world_size,)) def init_pg(self, eager_init) -> None: - if "nccl" in self.backend and torch.cuda.device_count() < self.world_size: - sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code) + if "nccl" in self.backend: + exit_if_lt_x_gpu(self.world_size) if self.backend not in [ "nccl", diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index 28b761a37d58c..c2ff09d9297f1 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -59,10 +59,10 @@ captured_output, cleanup_temp_dir, DistTestCases, + exit_if_lt_x_gpu, init_multigpu_helper, initialize_temp_directories, MultiProcessTestCase, - nccl_skip_if_lt_x_gpu, require_n_gpus_for_nccl_backend, requires_nccl_version, simple_sparse_reduce_tests, @@ -601,10 +601,8 @@ def _run(cls, rank, test_name, file_name, pipe, **kwargs): self.rank = rank self.file_name = file_name - if torch.cuda.is_available() and torch.cuda.device_count() < int( - self.world_size - ): - sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code) + if torch.cuda.is_available(): + exit_if_lt_x_gpu(int(self.world_size)) try: pg_timeout_seconds = CUSTOM_PG_TIMEOUT.get(test_name, default_pg_timeout) timeout = timedelta(seconds=pg_timeout_seconds) @@ -5336,7 +5334,7 @@ def step_model(model, input, target): BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo", "get_future is only supported on mpi, nccl and gloo", ) - @nccl_skip_if_lt_x_gpu(BACKEND, 2) + @require_n_gpus_for_nccl_backend(2, BACKEND) def test_accumulate_gradients_no_sync(self): """ Runs _test_accumulate_gradients_no_sync using default inputs @@ -5347,7 +5345,7 @@ def test_accumulate_gradients_no_sync(self): BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo", "get_future is only supported on mpi, nccl and gloo", ) - @nccl_skip_if_lt_x_gpu(BACKEND, 2) + @require_n_gpus_for_nccl_backend(2, BACKEND) def test_accumulate_gradients_no_sync_grad_is_view(self): """ Runs _test_accumulate_gradients_no_sync using default inputs @@ -5358,7 +5356,7 @@ def test_accumulate_gradients_no_sync_grad_is_view(self): BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo", "get_future is only supported on mpi, nccl and gloo", ) - @nccl_skip_if_lt_x_gpu(BACKEND, 2) + @require_n_gpus_for_nccl_backend(2, BACKEND) def test_accumulate_gradients_no_sync_allreduce_hook(self): """ Runs multiple iterations on _test_accumulate_gradients_no_sync @@ -5386,7 +5384,7 @@ def allreduce_hook( BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo", "get_future is only supported on mpi, nccl and gloo", ) - @nccl_skip_if_lt_x_gpu(BACKEND, 2) + @require_n_gpus_for_nccl_backend(2, BACKEND) def test_accumulate_gradients_no_sync_allreduce_with_then_hook(self): """ Runs multiple iterations on _test_accumulate_gradients_no_sync using allreduce @@ -5420,7 +5418,7 @@ def div(fut): BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo", "get_future is only supported on mpi, nccl and gloo", ) - @nccl_skip_if_lt_x_gpu(BACKEND, 2) + @require_n_gpus_for_nccl_backend(2, BACKEND) def test_get_future(self): def mult(fut): return [t * 3 for t in fut.wait()] From 5998cd4eaaf50d5a427f0b0ec14f2135e4a46723 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Tue, 22 Jul 2025 22:16:04 -0700 Subject: [PATCH 436/457] [MPS] Speedup torch.full for 1-byte types (#158874) By using [`fillBuffer:range:value:`](https://developer.apple.com/documentation/metal/mtlblitcommandencoder/fillbuffer:range:value:?language=objc) rather than MPSGraph op, which should be faster and also does not have INT_MAX limit Which in turn fixes `test_index_put_accumulate_large_tensor_mps` test Pull Request resolved: https://github.com/pytorch/pytorch/pull/158874 Approved by: https://github.com/dcci --- .../ATen/native/mps/operations/ConstantOps.mm | 33 +++++++++++-------- test/test_indexing.py | 2 -- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/aten/src/ATen/native/mps/operations/ConstantOps.mm b/aten/src/ATen/native/mps/operations/ConstantOps.mm index 644cb80c1e44f..e36ac4dc45246 100644 --- a/aten/src/ATen/native/mps/operations/ConstantOps.mm +++ b/aten/src/ATen/native/mps/operations/ConstantOps.mm @@ -62,15 +62,12 @@ return self; } -// returns false if tensor cannot be filled with fillBuffer() -static bool fill_mps_tensor_(Tensor& self, uint8_t value) { - if (self.is_contiguous()) { - MPSStream* stream = getCurrentMPSStream(); - auto storage_byte_offset = self.storage_offset() * self.itemsize(); - stream->fill(mps::getMTLBufferStorage(self), value, self.nbytes(), storage_byte_offset); - return true; - } - return false; +static Tensor& fill_mps_tensor_(Tensor& self, uint8_t value) { + TORCH_INTERNAL_ASSERT(self.is_contiguous()); + const auto stream = getCurrentMPSStream(); + auto storage_byte_offset = self.storage_offset() * self.itemsize(); + stream->fill(mps::getMTLBufferStorage(self), value, self.nbytes(), storage_byte_offset); + return self; } Tensor& fill_scalar_mps(Tensor& self, const Scalar& value) { @@ -89,8 +86,20 @@ static bool fill_mps_tensor_(Tensor& self, uint8_t value) { return self; } // check if it's possible to use fillBuffer() to fill the Tensor's storage - if (value.toDouble() == 0.0 && fill_mps_tensor_(self, 0) == true) - return self; + if (self.is_contiguous()) { + if (value.toDouble() == 0.0) { + return fill_mps_tensor_(self, 0); + } + if (self.scalar_type() == kBool) { + return fill_mps_tensor_(self, value.toBool()); + } + if (self.scalar_type() == kByte) { + return fill_mps_tensor_(self, value.toByte()); + } + if (self.scalar_type() == kChar) { + return fill_mps_tensor_(self, value.toChar()); + } + } return fill_scalar_mps_impl(self, value); } @@ -101,8 +110,6 @@ static bool fill_mps_tensor_(Tensor& self, uint8_t value) { value.dim(), " dimensions."); Scalar scalar_value = value.item(); - if (scalar_value.toDouble() == 0.0 && fill_mps_tensor_(self, 0) == true) - return self; return fill_scalar_mps(self, scalar_value); } diff --git a/test/test_indexing.py b/test/test_indexing.py index 37a12f00ab272..3870734f60d34 100644 --- a/test/test_indexing.py +++ b/test/test_indexing.py @@ -998,8 +998,6 @@ def test_byte_mask_accumulate(self, device): ) @serialTest(TEST_CUDA) def test_index_put_accumulate_large_tensor(self, device): - if device.startswith("mps"): - raise unittest.SkipTest("Crash with max number of dimentions") # This test is for tensors with number of elements >= INT_MAX (2^31 - 1). N = (1 << 31) + 5 dt = torch.int8 From d898d0d437bfdc0719e6c69d5005606c5e64fca8 Mon Sep 17 00:00:00 2001 From: James Wu Date: Tue, 22 Jul 2025 13:22:40 -0700 Subject: [PATCH 437/457] [Precompile] Various small bugfixes, add CachingPrecompile to torchbench (#158847) This PR addresses a few small bugfixes needed to make NanoGPT inference work, and also adds a new `--caching-precompile` argument to torchbench. With `--caching-precompile`, after every benchmark we save precompile artifacts to DynamoCache, allowing us to test caching precompile on all existing benchmarks. The following bugfixes are in this PR to make all of this work: - Fix global variables being pruned with DUPLICATE_INPUT guards. DUPLICATE_INPUT guards have additional vars from the second input, which we track with additional_local_vars, but we never tracked additional global variables. This fixes the issue. (See torch/_dynamo/guards.py changes) - Return None from PRecompileContext.serialize() if no new dynamo compiles occurred. There's no reason to save artifacts (i.e. autotuning artifacts, etc) if no dynamo_compile occurred, so we return None early. We may later want to support editing existing dynamo artifacts as a TODO, but that's upcoming. - log `dynamo_start` on CompilePackage.load: This is only needed so that tlparse doesn't ignore TORCH_TRACE logs generated when caching precompile hits. If there are no actual compiles, we never log a "dynamo_start" entry, which makes internal tlparse ignore the TORCH_TRACE file. ## Test Plan After this PR, the following now works: ``` TORCH_LOGS=dynamo tlp python benchmarks/dynamo/torchbench.py --only nanogpt --performance --inference --backend inductor --caching-precompile --warm-start-latency ``` tlparse result (internal): Cold Start (6 seconds): https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpAWe0zD/dedicated_log_torch_trace_vk9nkp4m.log/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000 Warm Start (~1 s): https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpAWe0zD/dedicated_log_torch_trace_5l4iwrpm.log/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000 The 1 second of warm start here can be improved: the costs here are mostly in starting up workers and triton and initializing CUDA, a lot of which should not be included in the compile time cost in real world scenarios where these are already loaded before training begins. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158847 Approved by: https://github.com/zhxchen17 --- benchmarks/dynamo/common.py | 38 ++++++++++++++++++++- torch/_dynamo/config.py | 2 +- torch/_dynamo/convert_frame.py | 51 ++++++++++++++++------------- torch/_dynamo/eval_frame.py | 3 +- torch/_dynamo/guards.py | 5 ++- torch/_dynamo/package.py | 3 +- torch/_dynamo/precompile_context.py | 3 ++ 7 files changed, 76 insertions(+), 29 deletions(-) diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index 900a93c552b46..69ed64d8489a6 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -3264,6 +3264,12 @@ def get_example_inputs(self): instead of deleting it and creating a new one.", ) + parser.add_argument( + "--caching-precompile", + action="store_true", + help="Enables caching precompile, serializing artifacts to DynamoCache between runs", + ) + group_latency = parser.add_mutually_exclusive_group() group_latency.add_argument( "--cold-start-latency", @@ -3414,6 +3420,29 @@ def get_example_inputs(self): return parser.parse_args(args) +def process_caching_precompile(): + """ + After every process_entry, save precompile artifacts to DynamoCache + """ + assert torch._dynamo.config.caching_precompile, ( + "Caching precompile should be enabled with --caching-precompile" + ) + from torch._dynamo.precompile_context import PrecompileContext + + # Serialize all callables, clear PrecompileContext + # TODO: put this under torch.compiler API once ready + serialized = PrecompileContext.serialize() + PrecompileContext.clear() + if serialized is not None: + artifacts, info = serialized + print( + f"Saving {len(info.precompile_dynamo_artifacts)} Precompile Artifact(s)..." + ) + results = PrecompileContext.deserialize(artifacts) + assert results is not None + PrecompileContext.populate_caches(results) + + def process_entry(rank, runner, original_dir, args): args.rank = rank with maybe_init_distributed( @@ -3422,7 +3451,10 @@ def process_entry(rank, runner, original_dir, args): world_size=args.world_size, port=args.distributed_master_port, ): - return run(runner, args, original_dir) + result = run(runner, args, original_dir) + if args.caching_precompile: + process_caching_precompile() + return result def maybe_fresh_cache(args): @@ -3458,6 +3490,10 @@ def main(runner, original_dir=None, args=None): ) with maybe_fresh_cache(args): + if args.caching_precompile: + os.environ["TORCH_CACHING_PRECOMPILE"] = "1" + torch._dynamo.config.caching_precompile = True + args.init_distributed = args.only and args.multiprocess if args.init_distributed: # NB: Do NOT query device count before CUDA initialization; we're diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index 7ef748b85f3e3..adfd2ab4f00e8 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -549,7 +549,7 @@ def default_debug_dir_root() -> str: # Experimental feature for running automatic caching precompile. # Enables automatic DynamoCache save/load -caching_precompile = False +caching_precompile = os.environ.get("TORCH_CACHING_PRECOMPILE", "0") == "1" # Enables the Compiled Autograd engine to trace autograd calls made under torch.compile(). # Note: AOTAutograd will still trace and partition an AOT backward graph local to that diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 149a1c400d99a..bba4d9c980869 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -225,6 +225,31 @@ def fx_forward_from_src_skip_result( return result +def log_dynamo_start(code: CodeType, skip: int = 0) -> None: + convert_frame_intern = structured.intern_string(__file__) + # Initialize the ChromiumEventLogger on start + torch._logging.trace_structured( + "dynamo_start", + lambda: { + "stack": list( + itertools.takewhile( + lambda f: f["filename"] != convert_frame_intern, + structured.from_traceback( + CapturedTraceback.extract(skip=4 + skip).summary() + ), + ) + ) + + [ + { + "line": code.co_firstlineno, + "name": code.co_name, + "filename": structured.intern_string(code.co_filename), + } + ] + }, + ) + + def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]: """ Context manager to: @@ -1135,28 +1160,7 @@ def format_func_info(code: CodeType) -> str: # # 2 extra here # torch/_logging/_internal.py:1064 in trace_structured # torch/_dynamo/convert_frame.py:780 in - convert_frame_intern = structured.intern_string(__file__) - # Initialize the ChromiumEventLogger on start - torch._logging.trace_structured( - "dynamo_start", - lambda: { - "stack": list( - itertools.takewhile( - lambda f: f["filename"] != convert_frame_intern, - structured.from_traceback( - CapturedTraceback.extract(skip=4 + skip).summary() - ), - ) - ) - + [ - { - "line": code.co_firstlineno, - "name": code.co_name, - "filename": structured.intern_string(code.co_filename), - } - ] - }, - ) + log_dynamo_start(code, skip) start_time_ns = time.time_ns() fail_type: Optional[str] = None fail_reason: Optional[str] = None @@ -1588,9 +1592,10 @@ def __call__( with compile_lock, _disable_current_modes(): # skip=1: skip this frame - return self._torchdynamo_orig_backend( + result = self._torchdynamo_orig_backend( frame, cache_entry, self.hooks, frame_state, skip=1 ) + return result def catch_errors_wrapper( diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index f47ca4185bed0..bfe6801fc4b5d 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -679,8 +679,7 @@ def get_compiler_config() -> Any: # If self._package is lazily initialized, we should check the dynamo cache now if config.caching_precompile: - assert self._package is not None - if not self._package.is_initialized(): + if self._package is not None and not self._package.is_initialized(): result = DynamoCache.load(fn) if result is None: # Create a fresh CompilePackage diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index d7fe7cc300455..7b1203bae265d 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -1969,6 +1969,8 @@ def DUPLICATE_INPUT(self, guard, source_b): if self.serialization_mode == "save": if name := get_local_source_name(source_b): self.check_fn_manager.additional_used_local_vars.add(name) + if name := get_global_source_name(source_b): + self.check_fn_manager.additional_used_global_vars.add(name) ref_a = self.arg_ref(guard) ref_b = self.arg_ref(source_b.name()) @@ -2848,6 +2850,7 @@ def __init__( self.guards_serialization_mode = guards_serialization_mode self.used_builtin_vars: OrderedSet[str] = OrderedSet() self.additional_used_local_vars: OrderedSet[str] = OrderedSet() + self.additional_used_global_vars: OrderedSet[str] = OrderedSet() if runtime_global_scope: assert self.guards_serialization_mode == "load" self.runtime_global_scope = runtime_global_scope @@ -3038,7 +3041,7 @@ def _ref(x): global_scope_state = { k: v for k, v in output_graph_guards_state.global_scope.items() - if k in used_global_vars + if k in used_global_vars or k in self.additional_used_global_vars } global_scope_state[builtins_dict_name] = { k: v diff --git a/torch/_dynamo/package.py b/torch/_dynamo/package.py index be750d41a1dc9..a466267035596 100644 --- a/torch/_dynamo/package.py +++ b/torch/_dynamo/package.py @@ -380,7 +380,7 @@ def install(self, backends: dict[_BackendId, Any]) -> None: 3. Install the precompiled cache entries to ExtraStates on the code object. """ from torch._C._dynamo.eval_frame import _load_precompile_entry - from torch._dynamo.convert_frame import get_compile_id + from torch._dynamo.convert_frame import get_compile_id, log_dynamo_start from torch._guards import compile_context, CompileContext from .output_graph import get_builtins_dict @@ -394,6 +394,7 @@ def install(self, backends: dict[_BackendId, Any]) -> None: # collapsed into 0/0, 1/0 on warm. increment_frame() compile_id = get_compile_id(frame_state={}) + log_dynamo_start(code) with ( compile_context(CompileContext(compile_id)), dynamo_timed( diff --git a/torch/_dynamo/precompile_context.py b/torch/_dynamo/precompile_context.py index 040f54ce70db2..31d858fe3fc33 100644 --- a/torch/_dynamo/precompile_context.py +++ b/torch/_dynamo/precompile_context.py @@ -141,6 +141,9 @@ def serialize_artifact_by_key(cls, key: str) -> Optional[CacheArtifact]: @classmethod def serialize(cls) -> Optional[tuple[bytes, CacheInfo]]: cls._save_artifacts_by_type() + # No need to serialize if there are no new dynamo compiles + if "precompile_dynamo" not in cls._new_cache_artifacts: + return None return super().serialize() @staticmethod From 2a60b8fc97cf4fbb408221b5e8cb0ad683f78b04 Mon Sep 17 00:00:00 2001 From: Zhengxu Chen Date: Wed, 23 Jul 2025 15:36:14 +0000 Subject: [PATCH 438/457] [export][ez] Fix packaging (#158855) Summary: as title, seems ytpo Test Plan: CI Rollback Plan: Differential Revision: D78758466 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158855 Approved by: https://github.com/henryoier --- torch/export/pt2_archive/_package.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/export/pt2_archive/_package.py b/torch/export/pt2_archive/_package.py index 83cae4836d9a5..9d3be9758a7c2 100644 --- a/torch/export/pt2_archive/_package.py +++ b/torch/export/pt2_archive/_package.py @@ -308,7 +308,7 @@ def _package_exported_programs( return if isinstance(exported_programs, ExportedProgram): - exported_programs = {"model", exported_programs} # type: ignore[assignment] + exported_programs = {"model": exported_programs} assert isinstance(exported_programs, dict) From 7d296d5c19750cecd82e2b95f6fb0f8dd918282e Mon Sep 17 00:00:00 2001 From: angelayi Date: Tue, 22 Jul 2025 13:38:35 -0700 Subject: [PATCH 439/457] [aoti][mps] Enable more tests (#158703) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158703 Approved by: https://github.com/malfet, https://github.com/desertfire ghstack dependencies: #158349, #158350, #158351 --- test/inductor/test_aot_inductor.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 271b4f99bbfb7..a96a11da93bd9 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -6794,17 +6794,14 @@ def fail_gpu(suffixes: tuple[str, ...], is_skip=False): "test_fp8_view_of_param": fail_mps(), # unsupported operator: aten._scaled_dot_product_attention_math_for_mps.default "test_issue_140766": fail_mps(), - # Compilation Error + # cannot initialize a parameter of type 'double' with an rvalue of type 'std::nullptr_t' "test_fallback_kernel_with_symexpr_output": fail_mps(), - "test_while_loop_with_mixed_device": fail_mps(), + # while-loop subgraph calls same kernel as outside. need to figure out how to + # either (1) tell outside to initialize a new kernel or (2) generate + # subgraph as a separate function, which would(?) cause (1) to happen automatically. "test_while_loop_nested": fail_mps(), + # correctness issue "test_index_put_with_none_index": fail_mps(), - "test_size_from_multi_ouptut": fail_mps(), - "test_simple_embed_kernel_binary_False": fail_mps(), - "test_simple_embed_cubin_False": fail_mps(is_skip=True), - "test_simple_embed_cubin_True": fail_mps(is_skip=True), - "test_simple_embed_kernel_binary_True": fail_mps(), - "test_missing_cubin": fail_mps(), # Dynamism "test_shifted_constraint_ranges": fail_mps(), "test_while_loop_with_sym_expr_cond_dynamic_True": fail_mps(), From d3d9bc1c312cb8415d504a7af5682e75a97d3541 Mon Sep 17 00:00:00 2001 From: Mwiza Kunda Date: Wed, 23 Jul 2025 15:56:06 +0000 Subject: [PATCH 440/457] [inductor] Allow backends to register their own custom config object (#158254) An out of tree backend can have its own configuration options that the user can enable to control inductor compilation. These config options need to be taken into account when calculating the key that is used to determine cache miss / hits. This PR allows out of tree backends to specify a custom config module that has the same type as `torch._inductor.config` that can be used to control codegen (in addition to the default config), and will be used when creating the cache key. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158254 Approved by: https://github.com/eellison --- test/inductor/custom_inductor_config.py | 15 +++++++ test/inductor/test_codecache.py | 50 +++++++++++++++++++++++ torch/_inductor/codecache.py | 8 ++++ torch/_inductor/codegen/common.py | 19 +++++++++ torch/testing/_internal/inductor_utils.py | 12 ++++-- 5 files changed, 101 insertions(+), 3 deletions(-) create mode 100644 test/inductor/custom_inductor_config.py diff --git a/test/inductor/custom_inductor_config.py b/test/inductor/custom_inductor_config.py new file mode 100644 index 0000000000000..e29430728f946 --- /dev/null +++ b/test/inductor/custom_inductor_config.py @@ -0,0 +1,15 @@ +# Owner(s): ["module: inductor"] + +# This module is used in test_codecache.py to verify the correctness +# of FXGraphHashDetails when a custom inductor backend registers its own +# config object + +import sys + +from torch.utils._config_module import install_config_module + + +enable_optimisation: bool = False + +# adds patch, save_config, etc +install_config_module(sys.modules[__name__]) diff --git a/test/inductor/test_codecache.py b/test/inductor/test_codecache.py index 51af64153500d..93545ed93cc3b 100644 --- a/test/inductor/test_codecache.py +++ b/test/inductor/test_codecache.py @@ -66,6 +66,12 @@ from torch.testing._internal.triton_utils import requires_cuda +try: + from . import custom_inductor_config +except ImportError: + import custom_inductor_config + + if HAS_TRITON: import triton # @manual @@ -2463,6 +2469,50 @@ def uuid(self) -> Optional[Union[bytes, str]]: pickler.dumps(details3), ) + def test_hash_custom_backend_config(self): + """ + Test cache correctness when a custom inductor codegen config + is installed + """ + with patch_inductor_backend( + "cpu", custom_backend_config=custom_inductor_config + ): + gm = torch.fx.GraphModule({}, torch.fx.Graph()) + pickler = FxGraphCachePickler(gm) + details1 = FxGraphHashDetails(None, [], {}, []) + details2 = FxGraphHashDetails(None, [], {}, []) + self.assertEqual(pickler.dumps(details1), pickler.dumps(details2)) + + custom_inductor_config.enable_optimisation = True + details3 = FxGraphHashDetails(None, [], {}, []) + self.assertNotEqual(pickler.dumps(details2), pickler.dumps(details3)) + + torch._dynamo.reset() + counters.clear() + + custom_inductor_config.enable_optimisation = False + x = torch.zeros(32) + y = torch.zeros(32) + compiled_fn = torch.compile(torch.add) + + compiled_fn(x, y) + self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1) + self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0) + torch._dynamo.reset() + counters.clear() + + compiled_fn(x, y) + self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0) + self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1) + torch._dynamo.reset() + counters.clear() + + # Changing the custom config should trigger a recompilation + custom_inductor_config.enable_optimisation = True + compiled_fn(x, y) + self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1) + self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0) + def test_bypass_unsupported(self): """ Test _reduce_unsupported diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index c8b23aded15c2..442d36e0d117e 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -52,6 +52,7 @@ from torch._dynamo.utils import CompileEventLogger, counters, dynamo_timed from torch._inductor import config, exc, metrics from torch._inductor.codegen.common import ( + custom_backend_codegen_configs, custom_backend_passes, init_backend_registration, ) @@ -854,6 +855,13 @@ def __init__( map(self._get_custom_pass_detail, custom_backend_passes.values()) ) + # Save custom inductor codegen configs + self.custom_backend_codegen_configs = { + device: custom_config.save_config_portable(ignore_private_configs=False) + for device, custom_config in custom_backend_codegen_configs.items() + if custom_config is not None + } + # This is mainly added to handle these two inductor configs, which are (unfortunately) # sometimes cache safe: # - _pre_fusion_custom_pass diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 828050d6da140..92ee9e28be74e 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -34,6 +34,7 @@ import torch.fx from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND from torch.utils import _pytree as pytree +from torch.utils._config_module import ConfigModule from torch.utils._ordered_set import OrderedSet from torch.utils._sympy.numbers import int_oo from torch.utils._sympy.printers import PythonPrinter as _PythonPrinter @@ -367,6 +368,7 @@ def cpp_global_scratch( device_op_overrides_dict: dict[str, DeviceOpOverrides] = {} custom_backend_passes: dict[str, Optional[CustomGraphModulePass]] = {} +custom_backend_codegen_configs: dict[str, Optional[ConfigModule]] = {} # The code generated by Inductor consists of two main parts: kernel code and wrapper code. @@ -396,11 +398,20 @@ def register_backend_for_device( device_wrapper_codegen: WrapperConstructor, device_cpp_wrapper_codegen: Optional[WrapperConstructor] = None, device_custom_pass: Optional[CustomGraphModulePass] = None, + device_custom_config: Optional[ConfigModule] = None, ) -> None: device_codegens[device] = DeviceCodegen( device_scheduling, device_wrapper_codegen, device_cpp_wrapper_codegen ) custom_backend_passes[device] = device_custom_pass + if device_custom_config: + assert ( + isinstance(device_custom_config, ConfigModule) + and device_custom_config is not config + ), ( + f"{device_custom_config=} cannot be the same as the default inductor config {config=}" + ) + custom_backend_codegen_configs[device] = device_custom_config class BackendFeature(Enum): @@ -463,6 +474,14 @@ def get_custom_backend_pass_for_device(device: str) -> Optional[CustomGraphModul return custom_backend_passes[device] if device in custom_backend_passes else None +def get_custom_backend_config_for_device(device: str) -> Optional[ConfigModule]: + return ( + custom_backend_codegen_configs[device] + if device in custom_backend_codegen_configs + else None + ) + + @functools.cache def init_backend_registration() -> None: from .cpp import CppScheduling diff --git a/torch/testing/_internal/inductor_utils.py b/torch/testing/_internal/inductor_utils.py index 91a4aaa5728a8..8a521d56f5f84 100644 --- a/torch/testing/_internal/inductor_utils.py +++ b/torch/testing/_internal/inductor_utils.py @@ -16,6 +16,7 @@ from torch._inductor.codecache import CppCodeCache from torch._inductor.custom_graph_pass import CustomGraphModulePass from torch._inductor.codegen.common import ( + get_custom_backend_config_for_device, get_custom_backend_pass_for_device, get_scheduling_for_device, get_wrapper_codegen_for_device, @@ -27,6 +28,7 @@ from torch._inductor.utils import GPU_TYPES, get_gpu_type, is_gpu from torch.utils._helion import has_helion from torch.utils._triton import has_triton +from torch.utils._config_module import ConfigModule from torch.testing._internal.common_device_type import ( get_desired_device_type_test_bases, ) @@ -308,7 +310,8 @@ def _quantize_rowwise(x: Tensor, float8_dtype: torch.dtype): def patch_inductor_backend( device: str, python_wrapper_codegen: PythonWrapperCodegen = None, - custom_pass: CustomGraphModulePass = None + custom_pass: CustomGraphModulePass = None, + custom_backend_config: ConfigModule = None ): """ Patch the inductor backend for a specific device. @@ -321,6 +324,7 @@ def patch_inductor_backend( original_python_wrapper = get_wrapper_codegen_for_device(device, False) original_cpp_wrapper = get_wrapper_codegen_for_device(device, True) original_custom_pass = get_custom_backend_pass_for_device(device) + original_custom_backend_config = get_custom_backend_config_for_device(device) try: # Register modified backend for the device @@ -329,7 +333,8 @@ def patch_inductor_backend( original_scheduling, python_wrapper_codegen if python_wrapper_codegen is not None else original_python_wrapper, original_cpp_wrapper, - custom_pass if custom_pass is not None else original_custom_pass + custom_pass if custom_pass is not None else original_custom_pass, + custom_backend_config if custom_backend_config is not None else original_custom_backend_config ) yield finally: @@ -339,5 +344,6 @@ def patch_inductor_backend( original_scheduling, original_python_wrapper, original_cpp_wrapper, - original_custom_pass + original_custom_pass, + original_custom_backend_config ) From 671e22a9513604bb2a7cc886218a89bad7f8b3e6 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Tue, 22 Jul 2025 14:32:04 -0300 Subject: [PATCH 441/457] [math] Raise exception in Dynamo if constant fold call fail (#156975) Pull Request resolved: https://github.com/pytorch/pytorch/pull/156975 Approved by: https://github.com/zou3519 --- ...math-IsCloseTests.test_negative_tolerances | 0 .../CPython313-test_math-MathTests.testAcos | 0 .../CPython313-test_math-MathTests.testAsin | 0 .../CPython313-test_math-MathTests.testAsinh | 0 .../CPython313-test_math-MathTests.testAtan | 0 .../CPython313-test_math-MathTests.testAtan2 | 0 ...Python313-test_math-MathTests.testCopysign | 0 .../CPython313-test_math-MathTests.testCosh | 0 ...CPython313-test_math-MathTests.testDegrees | 0 .../CPython313-test_math-MathTests.testExp | 0 .../CPython313-test_math-MathTests.testFabs | 0 ...est_math-MathTests.testFactorialHugeInputs | 0 .../CPython313-test_math-MathTests.testFmod | 0 .../CPython313-test_math-MathTests.testFrexp | 0 .../CPython313-test_math-MathTests.testLdexp | 0 .../CPython313-test_math-MathTests.testLog10 | 0 .../CPython313-test_math-MathTests.testLog1p | 0 .../CPython313-test_math-MathTests.testModf | 0 .../CPython313-test_math-MathTests.testPow | 0 ...CPython313-test_math-MathTests.testRadians | 0 .../CPython313-test_math-MathTests.testSin | 0 .../CPython313-test_math-MathTests.testSqrt | 0 .../CPython313-test_math-MathTests.testTanh | 0 ...hon313-test_math-MathTests.test_exceptions | 0 ...-test_math-MathTests.test_input_exceptions | 0 ...13-test_math-MathTests.test_math_dist_leak | 0 ...thon313-test_math-MathTests.test_nextafter | 0 torch/_dynamo/variables/torch.py | 21 ++++++++++++------- 28 files changed, 14 insertions(+), 7 deletions(-) delete mode 100644 test/dynamo_expected_failures/CPython313-test_math-IsCloseTests.test_negative_tolerances delete mode 100644 test/dynamo_expected_failures/CPython313-test_math-MathTests.testAcos delete mode 100644 test/dynamo_expected_failures/CPython313-test_math-MathTests.testAsin delete mode 100644 test/dynamo_expected_failures/CPython313-test_math-MathTests.testAsinh delete mode 100644 test/dynamo_expected_failures/CPython313-test_math-MathTests.testAtan delete mode 100644 test/dynamo_expected_failures/CPython313-test_math-MathTests.testAtan2 delete mode 100644 test/dynamo_expected_failures/CPython313-test_math-MathTests.testCopysign delete mode 100644 test/dynamo_expected_failures/CPython313-test_math-MathTests.testCosh delete mode 100644 test/dynamo_expected_failures/CPython313-test_math-MathTests.testDegrees delete mode 100644 test/dynamo_expected_failures/CPython313-test_math-MathTests.testExp delete mode 100644 test/dynamo_expected_failures/CPython313-test_math-MathTests.testFabs delete mode 100644 test/dynamo_expected_failures/CPython313-test_math-MathTests.testFactorialHugeInputs delete mode 100644 test/dynamo_expected_failures/CPython313-test_math-MathTests.testFmod delete mode 100644 test/dynamo_expected_failures/CPython313-test_math-MathTests.testFrexp delete mode 100644 test/dynamo_expected_failures/CPython313-test_math-MathTests.testLdexp delete mode 100644 test/dynamo_expected_failures/CPython313-test_math-MathTests.testLog10 delete mode 100644 test/dynamo_expected_failures/CPython313-test_math-MathTests.testLog1p delete mode 100644 test/dynamo_expected_failures/CPython313-test_math-MathTests.testModf delete mode 100644 test/dynamo_expected_failures/CPython313-test_math-MathTests.testPow delete mode 100644 test/dynamo_expected_failures/CPython313-test_math-MathTests.testRadians delete mode 100644 test/dynamo_expected_failures/CPython313-test_math-MathTests.testSin delete mode 100644 test/dynamo_expected_failures/CPython313-test_math-MathTests.testSqrt delete mode 100644 test/dynamo_expected_failures/CPython313-test_math-MathTests.testTanh delete mode 100644 test/dynamo_expected_failures/CPython313-test_math-MathTests.test_exceptions delete mode 100644 test/dynamo_expected_failures/CPython313-test_math-MathTests.test_input_exceptions delete mode 100644 test/dynamo_expected_failures/CPython313-test_math-MathTests.test_math_dist_leak delete mode 100644 test/dynamo_expected_failures/CPython313-test_math-MathTests.test_nextafter diff --git a/test/dynamo_expected_failures/CPython313-test_math-IsCloseTests.test_negative_tolerances b/test/dynamo_expected_failures/CPython313-test_math-IsCloseTests.test_negative_tolerances deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAcos b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAcos deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAsin b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAsin deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAsinh b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAsinh deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAtan b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAtan deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAtan2 b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAtan2 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testCopysign b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testCopysign deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testCosh b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testCosh deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testDegrees b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testDegrees deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testExp b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testExp deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testFabs b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testFabs deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testFactorialHugeInputs b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testFactorialHugeInputs deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testFmod b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testFmod deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testFrexp b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testFrexp deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testLdexp b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testLdexp deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testLog10 b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testLog10 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testLog1p b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testLog1p deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testModf b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testModf deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testPow b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testPow deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testRadians b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testRadians deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testSin b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testSin deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testSqrt b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testSqrt deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testTanh b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testTanh deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_exceptions b/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_exceptions deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_input_exceptions b/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_input_exceptions deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_math_dist_leak b/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_math_dist_leak deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_nextafter b/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_nextafter deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index fc1f9646ffdf4..9a83acd61b1d4 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -52,7 +52,7 @@ tracable_create_parameter, ) from ..device_interface import get_registered_device_interfaces -from ..exc import unimplemented_v2 +from ..exc import raise_observed_exception, unimplemented_v2 from ..guards import GuardBuilder, install_guard from ..source import ( AttrSource, @@ -1394,12 +1394,19 @@ def patched_fn(*args, **kwargs): source = CallFunctionNoArgsSource(self.source) install_guard(source.make_guard(GuardBuilder.EQUALS_MATCH)) # constant fold - return ConstantVariable.create( - self.as_python_constant()( - *[x.as_python_constant() for x in args], - **{k: v.as_python_constant() for k, v in kwargs.items()}, - ), - ) + try: + return ConstantVariable.create( + self.as_python_constant()( + *[x.as_python_constant() for x in args], + **{k: v.as_python_constant() for k, v in kwargs.items()}, + ), + ) + except (OverflowError, TypeError, ValueError) as exc: + raise_observed_exception( + type(exc), + tx, + args=list(map(ConstantVariable.create, exc.args)), + ) if self.is_tensor_method(): name = self.value.__name__ From f5314f89c89d0794ae7cdb3a29bb915e2db27035 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Tue, 22 Jul 2025 14:32:04 -0300 Subject: [PATCH 442/457] [struct] Add `struct.pack` and `struct.unpack` polyfills (#156977) Pull Request resolved: https://github.com/pytorch/pytorch/pull/156977 Approved by: https://github.com/XuehaiPan, https://github.com/jansel ghstack dependencies: #156975 --- ...matTestCase.test_double_specials_do_unpack | 0 ...rmatTestCase.test_float_specials_do_unpack | 0 ...matTestCase.test_serialized_float_rounding | 0 .../CPython313-test_math-MathTests.testAcosh | 0 .../CPython313-test_math-MathTests.testAtanh | 0 .../CPython313-test_math-MathTests.testCbrt | 0 .../CPython313-test_math-MathTests.testCos | 0 .../CPython313-test_math-MathTests.testExp2 | 0 .../CPython313-test_math-MathTests.testLog | 0 .../CPython313-test_math-MathTests.testSinh | 0 .../CPython313-test_math-MathTests.testTan | 0 torch/_dynamo/polyfills/__init__.py | 1 + torch/_dynamo/polyfills/loader.py | 1 + torch/_dynamo/polyfills/struct.py | 27 +++++++++++++++++++ 14 files changed, 29 insertions(+) delete mode 100644 test/dynamo_expected_failures/CPython313-test_float-IEEEFormatTestCase.test_double_specials_do_unpack delete mode 100644 test/dynamo_expected_failures/CPython313-test_float-IEEEFormatTestCase.test_float_specials_do_unpack delete mode 100644 test/dynamo_expected_failures/CPython313-test_float-IEEEFormatTestCase.test_serialized_float_rounding delete mode 100644 test/dynamo_expected_failures/CPython313-test_math-MathTests.testAcosh delete mode 100644 test/dynamo_expected_failures/CPython313-test_math-MathTests.testAtanh delete mode 100644 test/dynamo_expected_failures/CPython313-test_math-MathTests.testCbrt delete mode 100644 test/dynamo_expected_failures/CPython313-test_math-MathTests.testCos delete mode 100644 test/dynamo_expected_failures/CPython313-test_math-MathTests.testExp2 delete mode 100644 test/dynamo_expected_failures/CPython313-test_math-MathTests.testLog delete mode 100644 test/dynamo_expected_failures/CPython313-test_math-MathTests.testSinh delete mode 100644 test/dynamo_expected_failures/CPython313-test_math-MathTests.testTan create mode 100644 torch/_dynamo/polyfills/struct.py diff --git a/test/dynamo_expected_failures/CPython313-test_float-IEEEFormatTestCase.test_double_specials_do_unpack b/test/dynamo_expected_failures/CPython313-test_float-IEEEFormatTestCase.test_double_specials_do_unpack deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_float-IEEEFormatTestCase.test_float_specials_do_unpack b/test/dynamo_expected_failures/CPython313-test_float-IEEEFormatTestCase.test_float_specials_do_unpack deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_float-IEEEFormatTestCase.test_serialized_float_rounding b/test/dynamo_expected_failures/CPython313-test_float-IEEEFormatTestCase.test_serialized_float_rounding deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAcosh b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAcosh deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAtanh b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAtanh deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testCbrt b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testCbrt deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testCos b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testCos deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testExp2 b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testExp2 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testLog b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testLog deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testSinh b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testSinh deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testTan b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testTan deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/torch/_dynamo/polyfills/__init__.py b/torch/_dynamo/polyfills/__init__.py index 936d1c62e9e6e..db2493f26caf0 100644 --- a/torch/_dynamo/polyfills/__init__.py +++ b/torch/_dynamo/polyfills/__init__.py @@ -30,6 +30,7 @@ operator as operator, os as os, pytree as pytree, + struct as struct, sys as sys, ) diff --git a/torch/_dynamo/polyfills/loader.py b/torch/_dynamo/polyfills/loader.py index f60aa57a5d409..f306d47ba5f8a 100644 --- a/torch/_dynamo/polyfills/loader.py +++ b/torch/_dynamo/polyfills/loader.py @@ -19,6 +19,7 @@ "operator", "os", "pytree", + "struct", "sys", "fx", "tensor", diff --git a/torch/_dynamo/polyfills/struct.py b/torch/_dynamo/polyfills/struct.py new file mode 100644 index 0000000000000..f4522a12f7323 --- /dev/null +++ b/torch/_dynamo/polyfills/struct.py @@ -0,0 +1,27 @@ +""" +Python polyfills for struct +""" + +from __future__ import annotations + +import struct +from typing import Any +from typing_extensions import Buffer + +from ..decorators import substitute_in_graph + + +__all__ = [ + "pack", + "unpack", +] + + +@substitute_in_graph(struct.pack, can_constant_fold_through=True) # type: ignore[arg-type] +def pack(fmt: bytes | str, /, *v: Any) -> bytes: + return struct.pack(fmt, *v) + + +@substitute_in_graph(struct.unpack, can_constant_fold_through=True) # type: ignore[arg-type] +def unpack(format: bytes | str, buffer: Buffer, /) -> tuple[Any, ...]: + return struct.unpack(format, buffer) From 576253c47603baff6709353631e92e8da7d8d7dd Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Tue, 22 Jul 2025 14:32:05 -0300 Subject: [PATCH 443/457] [math] Trace `float.fromhex` (#156976) Pull Request resolved: https://github.com/pytorch/pytorch/pull/156976 Approved by: https://github.com/zou3519 ghstack dependencies: #156975, #156977 --- ...hon313-test_float-HexFloatTestCase.test_from_hex | 0 ...-test_float-HexFloatTestCase.test_invalid_inputs | 0 ...n313-test_float-HexFloatTestCase.test_whitespace | 0 torch/_dynamo/variables/builder.py | 8 ++++++++ torch/_dynamo/variables/builtin.py | 13 +++++++++++++ 5 files changed, 21 insertions(+) delete mode 100644 test/dynamo_expected_failures/CPython313-test_float-HexFloatTestCase.test_from_hex delete mode 100644 test/dynamo_expected_failures/CPython313-test_float-HexFloatTestCase.test_invalid_inputs delete mode 100644 test/dynamo_expected_failures/CPython313-test_float-HexFloatTestCase.test_whitespace diff --git a/test/dynamo_expected_failures/CPython313-test_float-HexFloatTestCase.test_from_hex b/test/dynamo_expected_failures/CPython313-test_float-HexFloatTestCase.test_from_hex deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_float-HexFloatTestCase.test_invalid_inputs b/test/dynamo_expected_failures/CPython313-test_float-HexFloatTestCase.test_invalid_inputs deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_float-HexFloatTestCase.test_whitespace b/test/dynamo_expected_failures/CPython313-test_float-HexFloatTestCase.test_whitespace deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 9c13267c25bf3..0862d9da83110 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -132,6 +132,7 @@ get_locals_to_steal, get_static_address_type, is_frozen_dataclass, + is_function, is_function_or_wrapper, is_invoke_subgraph, is_lru_cache_wrapped_function, @@ -161,6 +162,7 @@ VariableTracker, VariableTrackerMeta, ) +from .builtin import BuiltinVariable from .constant import ConstantVariable, EnumVariable from .ctx_manager import ( AutocastModeVariable, @@ -1223,6 +1225,12 @@ def build_key_value(i, k, v): ) and BuiltinMethodVariable.is_supported_builtin_method(value): self.install_guards(GuardBuilder.ID_MATCH) return BuiltinMethodVariable(value, source=self.source) + elif is_function(value) and value in (float.fromhex, float.hex): + self.install_guards(GuardBuilder.ID_MATCH) + return GetAttrVariable( + BuiltinVariable(float, source=self.source), + value.__name__, + ) elif is_function_or_wrapper(value): value, attr_name = unwrap_with_attr_name_if_wrapper(value) # For these wrappers, Dynamo points to the wrapped function, diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 9269043968873..137108f5fac3c 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -1271,6 +1271,19 @@ def call_method( args[1:], ) + if self.fn is float and len(args) == 1 and name in ("fromhex", "hex"): + if isinstance(args[0], ConstantVariable): + try: + fn = getattr(float, name) + res = fn(args[0].as_python_constant()) + return variables.ConstantVariable.create(res) + except (OverflowError, ValueError) as e: + raise_observed_exception( + type(e), + tx, + args=list(map(ConstantVariable.create, e.args)), + ) + if self.fn is object and name == "__init__": # object.__init__ is a no-op return variables.ConstantVariable(None) From 00da8e63ebb3bea5cf4382ea37ad1ae5598ac90d Mon Sep 17 00:00:00 2001 From: iremyux Date: Wed, 23 Jul 2025 16:12:17 +0000 Subject: [PATCH 444/457] CI for Windows Arm64 (#148753) This pull request adds a new CI workflow for Windows Arm64, named win-arm64-build-test.yml. It can be triggered on any pull request by including the ciflow/win-arm64 tag. Pull Request resolved: https://github.com/pytorch/pytorch/pull/148753 Approved by: https://github.com/malfet --- .ci/pytorch/win-arm64-build.ps1 | 34 ++++ .ci/pytorch/win-arm64-test.sh | 24 +++ .../win-test-helpers/arm64/build_pytorch.ps1 | 98 +++++++++ .github/pytorch-probot.yml | 1 + .github/workflows/win-arm64-build-test.yml | 187 ++++++++++++++++++ 5 files changed, 344 insertions(+) create mode 100644 .ci/pytorch/win-arm64-build.ps1 create mode 100644 .ci/pytorch/win-arm64-test.sh create mode 100644 .ci/pytorch/win-test-helpers/arm64/build_pytorch.ps1 create mode 100644 .github/workflows/win-arm64-build-test.yml diff --git a/.ci/pytorch/win-arm64-build.ps1 b/.ci/pytorch/win-arm64-build.ps1 new file mode 100644 index 0000000000000..2cb162b8a301c --- /dev/null +++ b/.ci/pytorch/win-arm64-build.ps1 @@ -0,0 +1,34 @@ +# If you want to rebuild, run this with $env:REBUILD=1 +# If you want to build with CUDA, run this with $env:USE_CUDA=1 +# If you want to build without CUDA, run this with $env:USE_CUDA=0 + +# Check for setup.py in the current directory +if (-not (Test-Path "setup.py")) { + Write-Host "ERROR: Please run this build script from PyTorch root directory." + exit 1 +} + +# Get the script's parent directory +$ScriptParentDir = Split-Path -Parent $MyInvocation.MyCommand.Definition + +# Set TMP_DIR and convert to Windows path +$env:TMP_DIR = Join-Path (Get-Location) "build\win_tmp" +$env:TMP_DIR_WIN = $env:TMP_DIR # Already in Windows format, no cygpath needed + +# Set final package directory with default fallback +if (-not $env:PYTORCH_FINAL_PACKAGE_DIR) { + $env:PYTORCH_FINAL_PACKAGE_DIR = "C:\w\build-results" +} + +# Create the final package directory if it doesn't exist +if (-not (Test-Path $env:PYTORCH_FINAL_PACKAGE_DIR)) { + New-Item -Path $env:PYTORCH_FINAL_PACKAGE_DIR -ItemType Directory -Force | Out-Null +} + +# Set script helpers directory +$env:SCRIPT_HELPERS_DIR = Join-Path $ScriptParentDir "win-test-helpers\arm64" + +# Run the main build script +& "$env:SCRIPT_HELPERS_DIR\build_pytorch.ps1" + +Write-Host "BUILD PASSED" diff --git a/.ci/pytorch/win-arm64-test.sh b/.ci/pytorch/win-arm64-test.sh new file mode 100644 index 0000000000000..662c561aa8962 --- /dev/null +++ b/.ci/pytorch/win-arm64-test.sh @@ -0,0 +1,24 @@ +#!/bin/bash +set -ex -o pipefail + +SCRIPT_PARENT_DIR=$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd ) +# shellcheck source=./common.sh +source "$SCRIPT_PARENT_DIR/common.sh" + +run_tests() { + echo Running smoke_test.py... + python ./.ci/pytorch/smoke_test/smoke_test.py --package torchonly + + echo Running test_autograd.oy, test_nn.py, test_torch.py... + cd test + + CORE_TEST_LIST=("test_autograd.py" "test_nn.py" "test_modules.py") + + for t in "${CORE_TEST_LIST[@]}"; do + echo "Running test: $t" + python "$t" --verbose --save-xml --use-pytest -vvvv -rfEsxXP -p no:xdist + done +} + +run_tests +echo "TEST PASSED" diff --git a/.ci/pytorch/win-test-helpers/arm64/build_pytorch.ps1 b/.ci/pytorch/win-test-helpers/arm64/build_pytorch.ps1 new file mode 100644 index 0000000000000..29b3e913439cb --- /dev/null +++ b/.ci/pytorch/win-test-helpers/arm64/build_pytorch.ps1 @@ -0,0 +1,98 @@ +# TODO: we may can use existing build_pytorch.bat for arm64 + +if ($env:DEBUG -eq "1") { + $env:BUILD_TYPE = "debug" +} else { + $env:BUILD_TYPE = "release" +} + +# This inflates our log size slightly, but it is REALLY useful to be +# able to see what our cl.exe commands are. (since you can actually +# just copy-paste them into a local Windows setup to just rebuild a +# single file.) +# log sizes are too long, but leaving this here in case someone wants to use it locally +# $env:CMAKE_VERBOSE_MAKEFILE = "1" + +$env:INSTALLER_DIR = Join-Path $env:SCRIPT_HELPERS_DIR "installation-helpers" + +cd .. + +# Environment variables +$env:SCCACHE_IDLE_TIMEOUT = "0" +$env:SCCACHE_IGNORE_SERVER_IO_ERROR = "1" +$env:CMAKE_BUILD_TYPE = $env:BUILD_TYPE +$env:CMAKE_C_COMPILER_LAUNCHER = "sccache" +$env:CMAKE_CXX_COMPILER_LAUNCHER = "sccache" +$env:libuv_ROOT = Join-Path $env:DEPENDENCIES_DIR "libuv\install" +$env:MSSdk = "1" + +if ($env:PYTORCH_BUILD_VERSION) { + $env:PYTORCH_BUILD_VERSION = $env:PYTORCH_BUILD_VERSION + $env:PYTORCH_BUILD_NUMBER = "1" +} + +$env:CMAKE_POLICY_VERSION_MINIMUM = "3.5" + +# Set BLAS type +if ($env:ENABLE_APL -eq "1") { + $env:BLAS = "APL" + $env:USE_LAPACK = "1" +} elseif ($env:ENABLE_OPENBLAS -eq "1") { + $env:BLAS = "OpenBLAS" + $env:OpenBLAS_HOME = Join-Path $env:DEPENDENCIES_DIR "OpenBLAS\install" +} + +# Change to source directory +Set-Location $env:PYTORCH_ROOT + +# Copy libuv.dll +Copy-Item -Path (Join-Path $env:libuv_ROOT "lib\Release\uv.dll") -Destination "torch\lib\uv.dll" -Force + +# Create virtual environment +python -m venv .venv +.\.venv\Scripts\Activate.ps1 +where.exe python + +# Python install dependencies +python -m pip install --upgrade pip +pip install setuptools pyyaml +pip install -r requirements.txt + +# Set after installing psutil +$env:DISTUTILS_USE_SDK = "1" + +# Print all environment variables +Get-ChildItem Env: + +# Start and inspect sccache +sccache --start-server +sccache --zero-stats +sccache --show-stats + +# Build the wheel +python setup.py bdist_wheel +if ($LASTEXITCODE -ne 0) { exit 1 } + +# Install the wheel locally +$whl = Get-ChildItem -Path "dist\*.whl" | Select-Object -First 1 +if ($whl) { + python -mpip install --no-index --no-deps $whl.FullName +} + +# Copy final wheel +robocopy "dist" "$env:PYTORCH_FINAL_PACKAGE_DIR" *.whl + +# Export test times +python tools/stats/export_test_times.py + +# Copy additional CI files +robocopy ".additional_ci_files" "$env:PYTORCH_FINAL_PACKAGE_DIR\.additional_ci_files" /E + +# Save ninja log +Copy-Item -Path "build\.ninja_log" -Destination $env:PYTORCH_FINAL_PACKAGE_DIR -Force + +# Final sccache stats and stop +sccache --show-stats +sccache --stop-server + +exit 0 diff --git a/.github/pytorch-probot.yml b/.github/pytorch-probot.yml index 5288aca852931..a5982b63b70fc 100644 --- a/.github/pytorch-probot.yml +++ b/.github/pytorch-probot.yml @@ -31,6 +31,7 @@ ciflow_push_tags: - ciflow/pull - ciflow/h100 - ciflow/h100-distributed +- ciflow/win-arm64 - ciflow/h100-symm-mem - ciflow/h100-cutlass-backend retryable_workflows: diff --git a/.github/workflows/win-arm64-build-test.yml b/.github/workflows/win-arm64-build-test.yml new file mode 100644 index 0000000000000..627a43b56bf70 --- /dev/null +++ b/.github/workflows/win-arm64-build-test.yml @@ -0,0 +1,187 @@ +name: windows-arm64-build-test + +on: + push: + tags: + - ciflow/win-arm64/* + +env: + GIT_DEFAULT_BRANCH: ${{ github.event.repository.default_branch }} + PYTHON_VERSION: "3.12" + PYTORCH_ROOT: ${{ github.workspace }}/pytorch + DOWNLOADS_DIR: c:\temp\downloads + DEPENDENCIES_DIR: c:\temp\dependencies + ENABLE_APL: 1 + ENABLE_OPENBLAS: 0 + BUILD_TYPE: release + +permissions: + id-token: write + contents: read + +jobs: + build: + # Don't run on forked repos. + if: github.repository_owner == 'pytorch' + runs-on: "windows-11-arm64-preview" + timeout-minutes: 240 + steps: + - name: configure aws credentials + id: aws_creds + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_sscache + aws-region: us-east-1 + role-duration-seconds: 18000 + + - name: Enable long paths + shell: cmd + run: | + git config --system --get core.longpaths || echo "core.longpaths is not set, setting it now" + git config --system core.longpaths true + + - name: Git checkout PyTorch + uses: actions/checkout@v4 + with: + path: pytorch + submodules: recursive + + - name: Bootstrap Python + shell: cmd + run: | + "pytorch/.ci/pytorch/windows/arm64/bootstrap_python.bat" + + - name: Parse ref + id: parse-ref + shell: bash + run: python pytorch/.github/scripts/parse_ref.py + + - name: Get workflow job id + shell: bash + id: get-job-id + run: | + set -eux + python pytorch/.github/scripts/get_workflow_job_id.py "${GITHUB_RUN_ID}" "${RUNNER_NAME}" + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + - name: Bootstrap APL + shell: cmd + run: | + "pytorch/.ci/pytorch/windows/arm64/bootstrap_apl.bat" + + - name: Bootstrap Rust + shell: cmd + run: | + "pytorch/.ci/pytorch/windows/arm64/bootstrap_rust.bat" + + - name: Bootstrap sccache + shell: cmd + run: | + "pytorch/.ci/pytorch/windows/arm64/bootstrap_sccache.bat" + + - name: Bootstrap Libuv + shell: cmd + run: | + "pytorch/.ci/pytorch/windows/arm64/bootstrap_libuv.bat" + + - name: Build + id: build + shell: cmd + env: + PYTORCH_FINAL_PACKAGE_DIR: C:/${{ github.run_id }}/build-results/ + BRANCH: ${{ steps.parse-ref.outputs.branch }} + BUILD_WHEEL: 1 + MAX_JOBS: 8 + PYTHON_VERSION: "3.12" + SCCACHE_BUCKET: "ossci-compiler-cache" + SCCACHE_S3_KEY_PREFIX: ${{ github.workflow }} + SCCACHE_REGION: us-east-1 + VC_PRODUCT: "BuildTools" + VC_VERSION: "" + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + AWS_DEFAULT_REGION: us-east-1 + USE_CUDA: '0' + USE_XPU: '0' + OUR_GITHUB_JOB_ID: ${{ steps.get-job-id.outputs.job-id }} + run: | + cd pytorch + call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvarsall.bat" arm64 + powershell -ExecutionPolicy Bypass -File ".ci/pytorch/win-arm64-build.ps1" + + - name: Upload artifacts + uses: actions/upload-artifact@v4.4.0 + if: always() + with: + name: torch-wheel-win-arm64-py3-12 + retention-days: 14 + if-no-files-found: error + path: C:\${{ github.run_id }}\build-results + + test: + if: github.repository_owner == 'pytorch' + strategy: + fail-fast: false + runs-on: "windows-11-arm64-preview" + needs: build + steps: + - name: Enable long paths + shell: cmd + run: | + git config --system --get core.longpaths || echo "core.longpaths is not set, setting it now" + git config --system core.longpaths true + + - name: Git checkout PyTorch + uses: actions/checkout@v4 + with: + path: pytorch + submodules: recursive + + - name: Bootstrap Python + shell: cmd + run: | + "pytorch/.ci/pytorch/windows/arm64/bootstrap_python.bat" + + - name: Bootstrap Rust + shell: cmd + run: | + "pytorch/.ci/pytorch/windows/arm64/bootstrap_rust.bat" + + - name: Get workflow job id + shell: bash + id: get-job-id + run: | + set -eux + python pytorch/.github/scripts/get_workflow_job_id.py "${GITHUB_RUN_ID}" "${RUNNER_NAME}" + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + - name: Download Build Artifacts + uses: actions/download-artifact@v4.1.7 + with: + name: torch-wheel-win-arm64-py3-12 + path: C:\${{ github.run_id }}\build-results + + - name: Test + id: test + shell: cmd + env: + USE_CUDA: '0' + INSTALL_WINDOWS_SDK: 1 + PYTHON_VERSION: "3.12" + VC_PRODUCT: "BuildTools" + AWS_DEFAULT_REGION: us-east-1 + GITHUB_REPOSITORY: ${{ github.repository }} + GITHUB_WORKFLOW: ${{ github.workflow }} + GITHUB_JOB: ${{ github.job }} + GITHUB_RUN_ID: ${{ github.run_id }} + GITHUB_RUN_NUMBER: ${{ github.run_number }} + GITHUB_RUN_ATTEMPT: ${{ github.run_attempt }} + JOB_ID: ${{ steps.get-job-id.outputs.job-id }} + JOB_NAME: ${{ steps.get-job-id.outputs.job-name }} + PYTORCH_FINAL_PACKAGE_DIR: C:/${{ github.run_id }}/build-results/ + run: | + mkdir "%PYTORCH_FINAL_PACKAGE_DIR%" + call pytorch/.ci/pytorch/windows/arm64/bootstrap_tests.bat + set GIT_BASH=C:\Program Files\Git\usr\bin\bash.exe + "%GIT_BASH%" -c "bash --noprofile --norc .ci/pytorch/win-arm64-test.sh" \ No newline at end of file From 5e386eec9426f174eea130c0c012d9f65ebe65fb Mon Sep 17 00:00:00 2001 From: Xu Han Date: Wed, 23 Jul 2025 16:29:15 +0000 Subject: [PATCH 445/457] [AOTI] enable aot inductor on Windows (#158915) With many PRs landed, we can run the first aot inductor example on Windows. image Let's remove the Windows check on `AotCodeCompiler`. CC: @angelayi , @desertfire , @jansel Pull Request resolved: https://github.com/pytorch/pytorch/pull/158915 Approved by: https://github.com/desertfire --- torch/_inductor/codecache.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 442d36e0d117e..6ba9147c4e9bd 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -1643,9 +1643,6 @@ def compile( """ generated_files: list[Union[str, Weights]] = additional_files # type: ignore[assignment] - if sys.platform == "win32": - raise RuntimeError("AotCodeCompiler not yet supported for inductor") - _set_gpu_runtime_env() # cpp_extension consults the env picked_vec_isa = pick_vec_isa() From 1b456c580d8d2b85e5eeb3e8ca92d5284e0e9156 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Tue, 22 Jul 2025 23:45:05 -0700 Subject: [PATCH 446/457] [dynamo][guards] Add type info of the guarded value in guard managers (#158765) tlparse looks like this image This will aid in reading guards. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158765 Approved by: https://github.com/Lucaskabela, https://github.com/StrongerXi --- test/dynamo/test_guard_manager.py | 63 ++++++++++++++++- test/dynamo/test_repros.py | 2 +- torch/_dynamo/guards.py | 5 +- torch/csrc/dynamo/guards.cpp | 108 +++++++++++++++++++++++++++--- 4 files changed, 163 insertions(+), 15 deletions(-) diff --git a/test/dynamo/test_guard_manager.py b/test/dynamo/test_guard_manager.py index 1aeafaf5dd33c..83d7cec8a7f18 100644 --- a/test/dynamo/test_guard_manager.py +++ b/test/dynamo/test_guard_manager.py @@ -811,7 +811,7 @@ def test_clone(self): except ImportError: from utils import install_guard_manager_testing_hook - def hook(guard_wrapper, f_locals): + def hook(guard_wrapper, f_locals, builder): root = guard_wrapper.root # Check full cloning works as expected @@ -851,7 +851,7 @@ def test_diff_guard_manager(self): from utils import install_guard_manager_testing_hook counter = 0 - def hook(guard_wrapper, f_locals): + def hook(guard_wrapper, f_locals, builder): nonlocal counter root = guard_wrapper.root diff_guard_root = guard_wrapper.diff_guard_root @@ -898,6 +898,65 @@ def fn(x, foo, bar): opt_fn(x, foo, bar) +class TypePropagationTests(torch._dynamo.test_case.TestCase): + @torch._dynamo.config.patch(skip_tensor_guards_with_matching_dict_tags=True) + def test_basic_types(self): + class Foo: + def __init__(self): + self.x = {"a": 2} + self.y = torch.randn(4) + self.z = {} + + foo = Foo() + + mod = torch.nn.Linear(4, 4) + + def fn(x): + return x + foo.x["a"] + foo.y + mod(x) + + try: + from .utils import install_guard_manager_testing_hook + except ImportError: + from utils import install_guard_manager_testing_hook + + def hook(guard_wrapper, f_locals, builder): + from torch._dynamo.source import AttrSource, DictGetItemSource, LocalSource + + foo_source = LocalSource("foo") + foo_x_source = AttrSource(foo_source, "x") + + self.assertTrue(builder.get(foo_source.name()) is foo) + self.assertTrue(builder.get(foo_x_source.name()) is foo.x) + + # Check types of foo.x + foo_x_mgr = builder.get_guard_manager_from_source(foo_x_source) + self.assertTrue(foo_x_mgr.is_guarded_value_dict()) + + # Check types of foo.x["a"] + foo_x_a_source = DictGetItemSource(foo_x_source, "a") + foo_x_a_mgr = builder.get_guard_manager_from_source(foo_x_a_source) + self.assertTrue(foo_x_a_mgr.is_guarded_value_immutable()) + + # Check types of foo.y + foo_y_source = AttrSource(foo_source, "y") + foo_y_mgr = builder.get_guard_manager_from_source(foo_y_source) + self.assertTrue(foo_y_mgr.is_guarded_value_immutable()) + + # Check types of foo.z + foo_z_source = AttrSource(foo_source, "z") + foo_z_mgr = builder.get_guard_manager_from_source(foo_z_source) + self.assertTrue(foo_z_mgr.is_guarded_value_empty_dict()) + + # Check types of mod + mod_source = LocalSource("mod") + mod_mgr = builder.get_guard_manager_from_source(mod_source) + self.assertTrue(mod_mgr.is_guarded_value_nn_module()) + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + with install_guard_manager_testing_hook(hook): + opt_fn(torch.randn(4, 4)) + + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index e0b2fdbf8611a..db1288fe5bf9b 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -992,7 +992,7 @@ def tearDown(self) -> None: self.exit_stack.close() super().tearDown() - def guard_manager_clone_hook_fn(self, guard_manager_wrapper, f_locals): + def guard_manager_clone_hook_fn(self, guard_manager_wrapper, f_locals, builder): root = guard_manager_wrapper.root cloned_root = root.clone_manager(lambda x: True) cloned_wrapper = torch._dynamo.guards.GuardManagerWrapper(cloned_root) diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 7b1203bae265d..e8ddb9b31e3af 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -166,7 +166,7 @@ ) -guard_manager_testing_hook_fn: Optional[Callable[[Any, Any], Any]] = None +guard_manager_testing_hook_fn: Optional[Callable[[Any, Any, Any], Any]] = None try: import numpy as np @@ -311,6 +311,7 @@ def get_manager_line(self, guard_manager, accessor_str=None): s = t + ": source=" + source if accessor_str: s += ", " + accessor_str + s += f", type={guard_manager.type_of_guarded_value()}" return s def construct_dict_manager_string(self, mgr, body): @@ -2969,7 +2970,7 @@ def make_guard_filter_entry(guard): if guard_manager_testing_hook_fn is not None: guard_manager_testing_hook_fn( - self.guard_manager, output_graph.local_scope + self.guard_manager, output_graph.local_scope, builder ) # NB for developers: n_iters is chosen to be 1 to prevent excessive diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp index e8a2ebfce6f77..eb0f20f1c86eb 100644 --- a/torch/csrc/dynamo/guards.cpp +++ b/torch/csrc/dynamo/guards.cpp @@ -1139,6 +1139,22 @@ std::string get_exception_message() { return exc_message; } +bool is_nn_module(py::handle example_value) { + py::object torch_module_cls = py::module_::import("torch.nn").attr("Module"); + return py::isinstance(example_value, torch_module_cls); +} + +std::string get_type_str(py::handle example_value) { + std::string type_name; + try { + type_name = py::str(py::type::of(example_value)).cast(); + } catch (const py::error_already_set& e) { + // Fallback that never throws in release builds + type_name = ""; + } + return type_name; +} + bool is_immutable_object(py::handle example_value) { py::object config_module = py::module_::import("torch._dynamo.config"); @@ -2554,9 +2570,13 @@ class GuardManager { py::handle example_value) : _root(root), _source(std::move(source)), - _is_dict(py::isinstance(example_value)) { + _is_dict(py::isinstance(example_value)), + _is_immutable(is_immutable_object(example_value)), + _is_nn_module(is_nn_module(example_value)), + _type_str(get_type_str(example_value)) { if (_is_dict) { _dict_tag = get_dict_version_unchecked(example_value.ptr()); + _is_empty_dict = PyDict_Size(example_value.ptr()) == 0; } } @@ -2576,10 +2596,45 @@ class GuardManager { _leaf_guards.emplace_back(std::move(leaf_guard)); } + public: + // type related helpers + bool is_guarded_value_immutable() { + return _is_immutable; + } + + bool is_guarded_value_nn_module() { + return _is_nn_module; + } + + bool is_guarded_value_dict() { + return _is_dict; + } + + bool is_guarded_value_empty_dict() { + return _is_empty_dict; + } + + std::string type_of_guarded_value() { + return _type_str; + } + public: // For cloning - GuardManager(RootGuardManager* root, std::string source, bool is_dict) - : _root(root), _source(std::move(source)), _is_dict(is_dict) {} + GuardManager( + RootGuardManager* root, + std::string source, + bool is_dict, + bool is_empty_dict, + bool is_immutable, + bool is_nn_module, + std::string type_str) + : _root(root), + _source(std::move(source)), + _is_dict(is_dict), + _is_empty_dict(is_empty_dict), + _is_immutable(is_immutable), + _is_nn_module(is_nn_module), + _type_str(std::move(type_str)) {} void clone_common( RootGuardManager* cloned_root, @@ -2610,7 +2665,14 @@ class GuardManager { if (!py::cast(clone_filter_fn(this))) { return nullptr; } - GuardManager* cloned_mgr = new GuardManager(cloned_root, _source, _is_dict); + GuardManager* cloned_mgr = new GuardManager( + cloned_root, + _source, + _is_dict, + _is_empty_dict, + _is_immutable, + _is_nn_module, + _type_str); clone_common(cloned_root, cloned_mgr, clone_filter_fn); return cloned_mgr; } @@ -2890,7 +2952,11 @@ class GuardManager { // to enable fail fast for the next check. std::vector> _accessors; - bool _is_dict; + bool _is_dict = false; + bool _is_empty_dict = false; + bool _is_immutable = false; + bool _is_nn_module = false; + std::string _type_str; uint64_t _dict_tag{0}; }; @@ -3176,7 +3242,7 @@ class DictGuardManager : public GuardManager { RootGuardManager* root, std::string source, py::handle example_value) - : GuardManager(root, std::move(source)), + : GuardManager(root, std::move(source), example_value), _size(PyDict_Size(example_value.ptr())), _expected_type(Py_TYPE(example_value.ptr())), _is_exact_dict_type(PyDict_CheckExact(example_value.ptr())) {} @@ -3391,8 +3457,17 @@ class DictGuardManager : public GuardManager { Py_ssize_t size, PyTypeObject* expected_type, bool is_exact_dict_type, - std::vector indices) - : GuardManager(cloned_root, std::move(source), true), + std::vector indices, + std::string type_of, + bool is_empty_dict) + : GuardManager( + cloned_root, + std::move(source), + true, // _is_dict + is_empty_dict, + false, // _is_nn_module + false, // _is_immutable + std::move(type_of)), _size(size), _expected_type(expected_type), _is_exact_dict_type(is_exact_dict_type), @@ -3411,7 +3486,9 @@ class DictGuardManager : public GuardManager { _size, _expected_type, _is_exact_dict_type, - _indices); + _indices, + type_of_guarded_value(), + is_guarded_value_empty_dict()); clone_common(cloned_root, cloned_mgr, clone_filter_fn); for (auto index : _indices) { @@ -3534,7 +3611,7 @@ std::unique_ptr make_guard_manager( throw py::type_error("Invalid guard manager enum"); } } - return std::make_unique(root, std::move(source)); + return std::make_unique(root, std::move(source), example_value); } class TORCH_FUNCTION_MODE_STACK : public LeafGuard { @@ -5925,6 +6002,17 @@ PyObject* torch_c_dynamo_guards_init() { // return by reference because GuardManager has the ownership of accessors .def("get_source", &GuardManager::get_source) .def("fail_count", &GuardManager::fail_count) + .def( + "is_guarded_value_immutable", + &GuardManager::is_guarded_value_immutable) + .def( + "is_guarded_value_nn_module", + &GuardManager::is_guarded_value_nn_module) + .def("is_guarded_value_dict", &GuardManager::is_guarded_value_dict) + .def( + "is_guarded_value_empty_dict", + &GuardManager::is_guarded_value_empty_dict) + .def("type_of_guarded_value", &GuardManager::type_of_guarded_value) .def( "get_accessors", &GuardManager::get_accessors, From 41b6cdaf76180a3d1308c898c094736305c7ceec Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 23 Jul 2025 17:42:10 +0000 Subject: [PATCH 447/457] Revert "Fix Triton GEMM templates with k=1 (#158650)" This reverts commit 9df0f565972a8a034fd77d65aff2c53e6e9856d1. Reverted https://github.com/pytorch/pytorch/pull/158650 on behalf of https://github.com/ZainRizvi due to Sorry but this is breaking internally, see D78805560 for details. To validate your fixes internally, you can follow the instructions here: https://fburl.com/fixing-ghfirst-reverts ([comment](https://github.com/pytorch/pytorch/pull/158650#issuecomment-3109538827)) --- test/inductor/test_max_autotune.py | 19 ------------------- torch/_inductor/kernel/mm.py | 4 ++-- 2 files changed, 2 insertions(+), 21 deletions(-) diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py index 43ed8eda83084..096e924a47826 100644 --- a/test/inductor/test_max_autotune.py +++ b/test/inductor/test_max_autotune.py @@ -1551,25 +1551,6 @@ def f(a, b): if "benchmark_gpu" in counter: self.assertEqual(counters["inductor"][counter], 2) - @config.patch( - { - "max_autotune": True, - "max_autotune_gemm_backends": "TRITON", - } - ) - def test_mm_k_1(self): - def mm(x, y): - return x @ y - - for i in range(90, 100): - torch._dynamo.reset() - a = torch.randn((i, 1), device="cuda", dtype=torch.float32) - b = torch.randn((1, i), device="cuda", dtype=torch.float32) - compiled_f = torch.compile(mm) - - out, code = run_and_get_code(compiled_f, a, b) - torch.testing.assert_close(out, mm(a, b), atol=1e-2, rtol=1e-2) - class TestMaxAutotunePrecompile(TestCase): def test_precompilation_threads(self): diff --git a/torch/_inductor/kernel/mm.py b/torch/_inductor/kernel/mm.py index 951494d6c3d55..f1c77afd52fd2 100644 --- a/torch/_inductor/kernel/mm.py +++ b/torch/_inductor/kernel/mm.py @@ -110,11 +110,11 @@ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - if ((stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1)) and (M >= BLOCK_M and K > 1): + if ((stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1)) and M >= BLOCK_M: offs_a_m = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) else: offs_a_m = rm % M - if ((stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1)) and (N >= BLOCK_N and K > 1): + if ((stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1)) and N >= BLOCK_N: offs_b_n = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) else: offs_b_n = rn % N From 30b0ad5c683ec0a391ae8b6e12de9fdfced67ddb Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 23 Jul 2025 17:47:35 +0000 Subject: [PATCH 448/457] Revert "Fix decorators skipping NCCL tests (#158846)" This reverts commit 57024913c409764f129d6a7792625f5b05462e31. Reverted https://github.com/pytorch/pytorch/pull/158846 on behalf of https://github.com/ZainRizvi due to Sorry but this is breaking trunk. See distributed/_composable/fsdp/test_fully_shard_logging.py::LoggingTests::test_fsdp_logging [GH job link](https://github.com/pytorch/pytorch/actions/runs/16472103496/job/46564570609) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/57024913c409764f129d6a7792625f5b05462e31) ([comment](https://github.com/pytorch/pytorch/pull/158846#issuecomment-3109553414)) --- test/distributed/test_functional_api.py | 26 ++++-- torch/testing/_internal/common_distributed.py | 80 +++++++++++++++---- .../_shard/sharded_tensor/__init__.py | 6 +- .../distributed/_tensor/common_dtensor.py | 7 +- .../_internal/distributed/distributed_test.py | 18 +++-- 5 files changed, 100 insertions(+), 37 deletions(-) diff --git a/test/distributed/test_functional_api.py b/test/distributed/test_functional_api.py index 61f52b2dc60ab..3b93e4d2b19ad 100644 --- a/test/distributed/test_functional_api.py +++ b/test/distributed/test_functional_api.py @@ -13,7 +13,6 @@ from torch._inductor.utils import run_and_get_code from torch.testing import FileCheck from torch.testing._internal.common_device_type import instantiate_device_type_tests -from torch.testing._internal.common_distributed import exit_if_lt_x_gpu from torch.testing._internal.distributed.fake_pg import FakeStore from torch.testing._internal.inductor_utils import HAS_GPU @@ -26,7 +25,7 @@ DistributedTestBase, MultiThreadedTestCase, requires_nccl, - skip_if_no_gpu, + TEST_SKIPS, ) from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, @@ -477,14 +476,26 @@ def allred_mesh_dim(input): BACKEND = dist.Backend.HCCL +# allows you to check for multiple accelerator irrespective of device type +# to add new device types to this check simply follow the same format +# and append an elif with the conditional and appropriate device count function for your new device +def exit_if_lt_x_accelerators(x): + if TEST_CUDA: + if torch.cuda.device_count() < x: + sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code) + elif TEST_HPU: + if torch.hpu.device_count() < x: + sys.exit(TEST_SKIPS[f"multi-hpu-{x}"].exit_code) + + def with_comms(func=None): if func is None: return partial(with_comms) @wraps(func) def wrapper(self, *args, **kwargs): - if BACKEND == dist.Backend.NCCL: - exit_if_lt_x_gpu(self.world_size) + if BACKEND == dist.Backend.NCCL and torch.cuda.device_count() < self.world_size: + sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code) kwargs["device"] = DEVICE self.pg = self.create_pg(device=DEVICE) @@ -497,9 +508,9 @@ def wrapper(self, *args, **kwargs): class TestCollectivesWithDistributedBackend(DistributedTestBase): - @skip_if_no_gpu @with_comms() def test_all_gather_into_tensor_coalesced(self, device): + exit_if_lt_x_accelerators(self.world_size) tensors = [ torch.ones([4], device=device), torch.ones([4], device=device) + 1, @@ -571,8 +582,9 @@ def allreduce(t, pg): compiled_allreduce(torch.randn(8, device=device), self.pg) @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") - @skip_if_no_gpu def test_tracing_with_fakepg(self, device=DEVICE): + exit_if_lt_x_accelerators(self.world_size) + def allreduce(t, pg): return ft_c.all_reduce(t, "sum", pg) @@ -614,9 +626,9 @@ class TestDistributedBackendCollectivesWithWorldSize4( def world_size(self): return 4 - @skip_if_no_gpu @with_comms() def test_permute_tensor_with_sub_group(self, device): + exit_if_lt_x_accelerators(self.world_size) mesh_dim_names = ["dp", "tp"] mesh_2d = dt.init_device_mesh( diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index c13d7f3c42c14..9b311411e34a2 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -118,17 +118,14 @@ def requires_ddp_rank(device): return device in DDP_RANK_DEVICES -def exit_if_lt_x_gpu(x): - if torch.cuda.device_count() < x: - sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code) - - def skip_if_no_gpu(func): """Skips if the world size exceeds the number of GPUs, ensuring that if the test is run, each rank has its own GPU via ``torch.cuda.device(rank)``.""" @wraps(func) def wrapper(*args, **kwargs): + if not (TEST_CUDA or TEST_HPU or TEST_XPU): + sys.exit(TEST_SKIPS["no_cuda"].exit_code) world_size = int(os.environ["WORLD_SIZE"]) if TEST_CUDA and torch.cuda.device_count() < world_size: sys.exit(TEST_SKIPS[f"multi-gpu-{world_size}"].exit_code) @@ -139,9 +136,7 @@ def wrapper(*args, **kwargs): return func(*args, **kwargs) - return unittest.skipUnless( - TEST_CUDA or TEST_HPU or TEST_XPU, TEST_SKIPS["no_cuda"].message - )(wrapper) + return wrapper # TODO (kwen2501): what is the purpose of this decorator? Tests with this @@ -173,16 +168,33 @@ def wrapper(*args, **kwargs): def require_n_gpus_for_nccl_backend(n, backend): - return skip_if_lt_x_gpu(n) if backend == "nccl" else unittest.skipIf(False, None) + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + if backend == "nccl" and torch.cuda.device_count() < n: + sys.exit(TEST_SKIPS[f"multi-gpu-{n}"].exit_code) + else: + return func(*args, **kwargs) + + return wrapper + + return decorator def import_transformers_or_skip(): - try: - from transformers import AutoModelForMaskedLM, BertConfig # noqa: F401 + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + try: + from transformers import AutoModelForMaskedLM, BertConfig # noqa: F401 - return unittest.skipIf(False) - except ImportError: - return unittest.skip(TEST_SKIPS["importerror"].message) + return func(*args, **kwargs) + except ImportError: + sys.exit(TEST_SKIPS["importerror"].exit_code) + + return wrapper + + return decorator def at_least_x_gpu(x): @@ -196,7 +208,36 @@ def at_least_x_gpu(x): def skip_if_lt_x_gpu(x): - return unittest.skipUnless(at_least_x_gpu(x), TEST_SKIPS[f"multi-gpu-{x}"].message) + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + if torch.cuda.is_available() and torch.cuda.device_count() >= x: + return func(*args, **kwargs) + if TEST_HPU and torch.hpu.device_count() >= x: + return func(*args, **kwargs) + if TEST_XPU and torch.xpu.device_count() >= x: + return func(*args, **kwargs) + sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code) + + return wrapper + + return decorator + + +# This decorator helps avoiding initializing cuda while testing other backends +def nccl_skip_if_lt_x_gpu(backend, x): + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + if backend != "nccl": + return func(*args, **kwargs) + if torch.cuda.is_available() and torch.cuda.device_count() >= x: + return func(*args, **kwargs) + sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code) + + return wrapper + + return decorator def verify_ddp_error_logged(model_DDP, err_substr): @@ -372,7 +413,14 @@ def requires_multicast_support(): def skip_if_rocm_multiprocess(func): """Skips a test for ROCm""" func.skip_if_rocm_multiprocess = True - return unittest.skipUnless(TEST_WITH_ROCM, TEST_SKIPS["skipIfRocm"].message)(func) + + @wraps(func) + def wrapper(*args, **kwargs): + if not TEST_WITH_ROCM: + return func(*args, **kwargs) + sys.exit(TEST_SKIPS["skipIfRocm"].exit_code) + + return wrapper def skip_if_win32(): diff --git a/torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py b/torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py index 8b52acdbeeb04..60c744ac1a84c 100644 --- a/torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py +++ b/torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py @@ -7,8 +7,8 @@ import torch.distributed as dist from torch.distributed import rpc from torch.testing._internal.common_distributed import ( - exit_if_lt_x_gpu, MultiProcessTestCase, + TEST_SKIPS, tp_transports, ) @@ -94,8 +94,8 @@ def with_comms(func=None, init_rpc=True, backend="nccl"): @wraps(func) def wrapper(self, *args, **kwargs): - if backend == "nccl": - exit_if_lt_x_gpu(self.world_size) + if backend == "nccl" and torch.cuda.device_count() < self.world_size: + sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code) self.init_comms(init_rpc=init_rpc, backend=backend) func(self, *args, **kwargs) self.destroy_comms(destroy_rpc=init_rpc) diff --git a/torch/testing/_internal/distributed/_tensor/common_dtensor.py b/torch/testing/_internal/distributed/_tensor/common_dtensor.py index f84d326ae3bf6..94bfead8a0c03 100644 --- a/torch/testing/_internal/distributed/_tensor/common_dtensor.py +++ b/torch/testing/_internal/distributed/_tensor/common_dtensor.py @@ -3,6 +3,7 @@ # Copyright (c) Meta Platforms, Inc. and affiliates import itertools +import sys from collections.abc import Iterator, Sequence from dataclasses import dataclass from functools import partial, wraps @@ -30,11 +31,11 @@ SequenceParallel, ) from torch.testing._internal.common_distributed import ( - exit_if_lt_x_gpu, MultiProcessTestCase, MultiThreadedTestCase, run_subtests, skip_if_lt_x_gpu, + TEST_SKIPS, ) from torch.testing._internal.common_utils import TEST_CUDA, TEST_HPU, TEST_XPU from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec @@ -355,8 +356,8 @@ def build_device_mesh(self) -> DeviceMesh: return init_device_mesh(self.device_type, (self.world_size,)) def init_pg(self, eager_init) -> None: - if "nccl" in self.backend: - exit_if_lt_x_gpu(self.world_size) + if "nccl" in self.backend and torch.cuda.device_count() < self.world_size: + sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code) if self.backend not in [ "nccl", diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index c2ff09d9297f1..28b761a37d58c 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -59,10 +59,10 @@ captured_output, cleanup_temp_dir, DistTestCases, - exit_if_lt_x_gpu, init_multigpu_helper, initialize_temp_directories, MultiProcessTestCase, + nccl_skip_if_lt_x_gpu, require_n_gpus_for_nccl_backend, requires_nccl_version, simple_sparse_reduce_tests, @@ -601,8 +601,10 @@ def _run(cls, rank, test_name, file_name, pipe, **kwargs): self.rank = rank self.file_name = file_name - if torch.cuda.is_available(): - exit_if_lt_x_gpu(int(self.world_size)) + if torch.cuda.is_available() and torch.cuda.device_count() < int( + self.world_size + ): + sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code) try: pg_timeout_seconds = CUSTOM_PG_TIMEOUT.get(test_name, default_pg_timeout) timeout = timedelta(seconds=pg_timeout_seconds) @@ -5334,7 +5336,7 @@ def step_model(model, input, target): BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo", "get_future is only supported on mpi, nccl and gloo", ) - @require_n_gpus_for_nccl_backend(2, BACKEND) + @nccl_skip_if_lt_x_gpu(BACKEND, 2) def test_accumulate_gradients_no_sync(self): """ Runs _test_accumulate_gradients_no_sync using default inputs @@ -5345,7 +5347,7 @@ def test_accumulate_gradients_no_sync(self): BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo", "get_future is only supported on mpi, nccl and gloo", ) - @require_n_gpus_for_nccl_backend(2, BACKEND) + @nccl_skip_if_lt_x_gpu(BACKEND, 2) def test_accumulate_gradients_no_sync_grad_is_view(self): """ Runs _test_accumulate_gradients_no_sync using default inputs @@ -5356,7 +5358,7 @@ def test_accumulate_gradients_no_sync_grad_is_view(self): BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo", "get_future is only supported on mpi, nccl and gloo", ) - @require_n_gpus_for_nccl_backend(2, BACKEND) + @nccl_skip_if_lt_x_gpu(BACKEND, 2) def test_accumulate_gradients_no_sync_allreduce_hook(self): """ Runs multiple iterations on _test_accumulate_gradients_no_sync @@ -5384,7 +5386,7 @@ def allreduce_hook( BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo", "get_future is only supported on mpi, nccl and gloo", ) - @require_n_gpus_for_nccl_backend(2, BACKEND) + @nccl_skip_if_lt_x_gpu(BACKEND, 2) def test_accumulate_gradients_no_sync_allreduce_with_then_hook(self): """ Runs multiple iterations on _test_accumulate_gradients_no_sync using allreduce @@ -5418,7 +5420,7 @@ def div(fut): BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo", "get_future is only supported on mpi, nccl and gloo", ) - @require_n_gpus_for_nccl_backend(2, BACKEND) + @nccl_skip_if_lt_x_gpu(BACKEND, 2) def test_get_future(self): def mult(fut): return [t * 3 for t in fut.wait()] From 9905ed616a65a3195c7ebc2bd44301c2c442f050 Mon Sep 17 00:00:00 2001 From: PaulZhang12 Date: Wed, 23 Jul 2025 06:20:10 -0700 Subject: [PATCH 449/457] [Inductor] Expose decomposeK knobs as envvars (#158745) Fix up decomposeK autotuning, by removing condition to return more than `k_splits_limit` and setting default to 10 instead of 5. Allow `k_splits_limit` to be configurable to the user via `TORCHINDUCTOR_NUM_DECOMPOSE_K_SPLITS` and also allow user to configure threshold in which to use decompose_k via `TORCHINDUCTOR_DECOMPOSE_K_THRESHOLD` Pull Request resolved: https://github.com/pytorch/pytorch/pull/158745 Approved by: https://github.com/eellison --- test/inductor/test_max_autotune.py | 56 +++++++++++++++++++++++------- torch/_inductor/config.py | 14 ++++++-- torch/_inductor/utils.py | 33 ++++++++---------- 3 files changed, 68 insertions(+), 35 deletions(-) diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py index 096e924a47826..e451067be59a0 100644 --- a/test/inductor/test_max_autotune.py +++ b/test/inductor/test_max_autotune.py @@ -50,7 +50,12 @@ aten = torch.ops.aten from torch._inductor.mock_cache import global_stats, PatchCaches, Stats from torch._inductor.test_case import run_tests, TestCase -from torch._inductor.utils import fresh_cache, run_and_get_code +from torch._inductor.utils import ( + fresh_cache, + get_k_splits, + run_and_get_code, + use_decompose_k_choice, +) from torch._inductor.virtualized import V from torch.fx.experimental.proxy_tensor import make_fx from torch.testing import FileCheck @@ -1494,7 +1499,9 @@ def misses(): self.assertEqual(hits(), 4) self.assertEqual(misses(), 4) + @fresh_cache() @skipIfXpu + @unittest.skipIf(TEST_WITH_ROCM, "decompose_k not supported on ROCm") @unittest.skipIf( config.cpp_wrapper, "decompose_k not supported for cpp_wrapper yet" ) @@ -1502,19 +1509,42 @@ def misses(): max_autotune=True, max_autotune_gemm_backends="TRITON", autotune_fallback_to_aten=False, - disable_decompose_k=True, ) - def test_max_autotune_disable_decompose_K(self): - M, N, K = (32, 32, 32768) - - a = torch.randn(M, K, dtype=torch.float16, device="cuda", requires_grad=True) - b = torch.randn(K, N, dtype=torch.float16, device="cuda", requires_grad=True) - - compiled_func = torch.compile(lambda a, b: a @ b) - out, code = run_and_get_code(compiled_func, a, b) - - for codegen in code: - FileCheck().check_not("decompose_k").run(codegen) + @parametrize("num_decompose_k_splits", (0, 5, 20)) + @parametrize("decompose_k_threshold", (8, 16)) + def test_max_autotune_decompose_k_envvars( + self, num_decompose_k_splits, decompose_k_threshold + ): + shapes = [(32, 32, 32768), (32, 32, 256)] + for M, N, K in shapes: + get_k_splits.cache_clear() + use_decompose_k_choice.cache_clear() + a = torch.randn(M, K, dtype=torch.float16, device="cuda") + b = torch.randn(K, N, dtype=torch.float16, device="cuda") + + with config.patch( + { + "triton.num_decompose_k_splits": num_decompose_k_splits, + "triton.decompose_k_threshold": decompose_k_threshold, + } + ): + compiled_func = torch.compile(lambda a, b: a @ b) + _, code = run_and_get_code(compiled_func, a, b) + + decompose_count = 0 + for codegen in code: + if "benchmark_decompose_k_mm" in codegen: + decompose_count += 1 + + if ( + K // M < decompose_k_threshold + or K // N < decompose_k_threshold + or num_decompose_k_splits == 0 + ): + self.assertEqual(decompose_count, 0) + else: + self.assertTrue(decompose_count > 0) + self.assertTrue(decompose_count <= num_decompose_k_splits) @skipIfXpu @unittest.skipIf( diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 2404f397ba54c..ae2ee6a574c73 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -425,9 +425,6 @@ def prologue_fusion_enabled() -> bool: # enable slow autotuning passes to select gemm algorithms max_autotune_gemm = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE_GEMM") == "1" -# disable decomposek autotune choice for gemm -disable_decompose_k = os.environ.get("TORCHINDUCTOR_DISABLE_DECOMPOSE_K") == "1" - # Modifies the number of autotuning choices displayed, set to None for all autotune_num_choices_displayed: Optional[int] = 10 @@ -1345,6 +1342,17 @@ class triton: # Note: it may also need to be used with config.compile_threads = 1 disallow_failing_autotune_kernels_TESTING_ONLY = False + # specify number of splits to autotune on for decompose_k. 0 disables decompose_k + num_decompose_k_splits = int( + os.environ.get("TORCHINDUCTOR_NUM_DECOMPOSE_K_SPLITS", "10") + ) + + # specify minimum ratio of K to M AND N in order to autotune on decompose_k. 0 enables + # it as an autotuning choice for all matmuls + decompose_k_threshold = int( + os.environ.get("TORCHINDUCTOR_DECOMPOSE_K_THRESHOLD", "32") + ) + class aot_inductor: """ diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index d95642b75f9d1..aef81712d17eb 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -1665,20 +1665,15 @@ def _use_cutlass_for_op(op_name: str) -> bool: return op_name.upper() in [x.strip() for x in enabled_ops.split(",")] -decompose_k_threshold = 32 - -# To limit compile time -k_splits_limit = 5 - -# Hand-tuned -default_k_splits = [16, 32, 64, 128, 256] - _IntLike: TypeAlias = Union[int, sympy.Expr] +@functools.cache def use_decompose_k_choice(m: _IntLike, n: _IntLike, k: _IntLike) -> bool: from torch._inductor.virtualized import V + decompose_k_threshold = config.triton.decompose_k_threshold + return ( not torch.version.hip and V.graph.sizevars.statically_known_true( @@ -1689,15 +1684,21 @@ def use_decompose_k_choice(m: _IntLike, n: _IntLike, k: _IntLike) -> bool: ) and not V.graph.aot_mode # TODO: Support AOTI for decomposeK and not V.graph.cpp_wrapper - and not config.disable_decompose_k ) @functools.cache def get_k_splits(m: _IntLike, n: _IntLike, k: _IntLike) -> list[int]: + # To limit compile time + k_splits_limit = config.triton.num_decompose_k_splits + + # Hand-tuned + default_k_splits = [16, 32, 64, 128, 256] # If k is a sympy expression, we can't do any splitting if isinstance(k, sympy.Expr) and not k.is_number: return default_k_splits + elif k_splits_limit == 0: + return [] if (isinstance(m, sympy.Expr) and not m.is_number) or ( isinstance(n, sympy.Expr) and not n.is_number @@ -1737,15 +1738,10 @@ def get_k_splits(m: _IntLike, n: _IntLike, k: _IntLike) -> list[int]: if config.max_autotune_gemm_search_space == "EXHAUSTIVE": return pow_of_2_divisors + mul_of_32_divisors + rest_of_splits - # If the # of power of 2 divisors are greater than k_splits_limit, return all - # This should be ok for compile time, all perfect squares between 128 and min(k / m, k / n) - # should never be a massive amount - if len(pow_of_2_divisors) >= k_splits_limit: - return pow_of_2_divisors - else: - best_splits = pow_of_2_divisors + mul_of_32_divisors + rest_of_splits - # Otherwise, conform results to k_splits_limit - return best_splits[:k_splits_limit] + + best_splits = pow_of_2_divisors + mul_of_32_divisors + rest_of_splits + # Otherwise, conform results to k_splits_limit + return best_splits[:k_splits_limit] @functools.cache @@ -2020,7 +2016,6 @@ def call(self, *args: Any, **kwargs: Any) -> None: self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen() ) # Skip all the actual compiling. - nonlocal save_output_code save_output_code(wrapper_code.value) if kernel_code: save_output_code(kernel_code.value) From 76be282e3a4893e4c4d2761e862428c615f9e260 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 23 Jul 2025 18:25:46 +0000 Subject: [PATCH 450/457] Revert "[Precompile] Various small bugfixes, add CachingPrecompile to torchbench (#158847)" This reverts commit d898d0d437bfdc0719e6c69d5005606c5e64fca8. Reverted https://github.com/pytorch/pytorch/pull/158847 on behalf of https://github.com/jithunnair-amd due to Broke ROCm CI jobs on MI200 and MI300 ([comment](https://github.com/pytorch/pytorch/pull/158847#issuecomment-3109664713)) --- benchmarks/dynamo/common.py | 38 +-------------------- torch/_dynamo/config.py | 2 +- torch/_dynamo/convert_frame.py | 51 +++++++++++++---------------- torch/_dynamo/eval_frame.py | 3 +- torch/_dynamo/guards.py | 5 +-- torch/_dynamo/package.py | 3 +- torch/_dynamo/precompile_context.py | 3 -- 7 files changed, 29 insertions(+), 76 deletions(-) diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index 69ed64d8489a6..900a93c552b46 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -3264,12 +3264,6 @@ def get_example_inputs(self): instead of deleting it and creating a new one.", ) - parser.add_argument( - "--caching-precompile", - action="store_true", - help="Enables caching precompile, serializing artifacts to DynamoCache between runs", - ) - group_latency = parser.add_mutually_exclusive_group() group_latency.add_argument( "--cold-start-latency", @@ -3420,29 +3414,6 @@ def get_example_inputs(self): return parser.parse_args(args) -def process_caching_precompile(): - """ - After every process_entry, save precompile artifacts to DynamoCache - """ - assert torch._dynamo.config.caching_precompile, ( - "Caching precompile should be enabled with --caching-precompile" - ) - from torch._dynamo.precompile_context import PrecompileContext - - # Serialize all callables, clear PrecompileContext - # TODO: put this under torch.compiler API once ready - serialized = PrecompileContext.serialize() - PrecompileContext.clear() - if serialized is not None: - artifacts, info = serialized - print( - f"Saving {len(info.precompile_dynamo_artifacts)} Precompile Artifact(s)..." - ) - results = PrecompileContext.deserialize(artifacts) - assert results is not None - PrecompileContext.populate_caches(results) - - def process_entry(rank, runner, original_dir, args): args.rank = rank with maybe_init_distributed( @@ -3451,10 +3422,7 @@ def process_entry(rank, runner, original_dir, args): world_size=args.world_size, port=args.distributed_master_port, ): - result = run(runner, args, original_dir) - if args.caching_precompile: - process_caching_precompile() - return result + return run(runner, args, original_dir) def maybe_fresh_cache(args): @@ -3490,10 +3458,6 @@ def main(runner, original_dir=None, args=None): ) with maybe_fresh_cache(args): - if args.caching_precompile: - os.environ["TORCH_CACHING_PRECOMPILE"] = "1" - torch._dynamo.config.caching_precompile = True - args.init_distributed = args.only and args.multiprocess if args.init_distributed: # NB: Do NOT query device count before CUDA initialization; we're diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index adfd2ab4f00e8..7ef748b85f3e3 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -549,7 +549,7 @@ def default_debug_dir_root() -> str: # Experimental feature for running automatic caching precompile. # Enables automatic DynamoCache save/load -caching_precompile = os.environ.get("TORCH_CACHING_PRECOMPILE", "0") == "1" +caching_precompile = False # Enables the Compiled Autograd engine to trace autograd calls made under torch.compile(). # Note: AOTAutograd will still trace and partition an AOT backward graph local to that diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index bba4d9c980869..149a1c400d99a 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -225,31 +225,6 @@ def fx_forward_from_src_skip_result( return result -def log_dynamo_start(code: CodeType, skip: int = 0) -> None: - convert_frame_intern = structured.intern_string(__file__) - # Initialize the ChromiumEventLogger on start - torch._logging.trace_structured( - "dynamo_start", - lambda: { - "stack": list( - itertools.takewhile( - lambda f: f["filename"] != convert_frame_intern, - structured.from_traceback( - CapturedTraceback.extract(skip=4 + skip).summary() - ), - ) - ) - + [ - { - "line": code.co_firstlineno, - "name": code.co_name, - "filename": structured.intern_string(code.co_filename), - } - ] - }, - ) - - def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]: """ Context manager to: @@ -1160,7 +1135,28 @@ def format_func_info(code: CodeType) -> str: # # 2 extra here # torch/_logging/_internal.py:1064 in trace_structured # torch/_dynamo/convert_frame.py:780 in - log_dynamo_start(code, skip) + convert_frame_intern = structured.intern_string(__file__) + # Initialize the ChromiumEventLogger on start + torch._logging.trace_structured( + "dynamo_start", + lambda: { + "stack": list( + itertools.takewhile( + lambda f: f["filename"] != convert_frame_intern, + structured.from_traceback( + CapturedTraceback.extract(skip=4 + skip).summary() + ), + ) + ) + + [ + { + "line": code.co_firstlineno, + "name": code.co_name, + "filename": structured.intern_string(code.co_filename), + } + ] + }, + ) start_time_ns = time.time_ns() fail_type: Optional[str] = None fail_reason: Optional[str] = None @@ -1592,10 +1588,9 @@ def __call__( with compile_lock, _disable_current_modes(): # skip=1: skip this frame - result = self._torchdynamo_orig_backend( + return self._torchdynamo_orig_backend( frame, cache_entry, self.hooks, frame_state, skip=1 ) - return result def catch_errors_wrapper( diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index bfe6801fc4b5d..f47ca4185bed0 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -679,7 +679,8 @@ def get_compiler_config() -> Any: # If self._package is lazily initialized, we should check the dynamo cache now if config.caching_precompile: - if self._package is not None and not self._package.is_initialized(): + assert self._package is not None + if not self._package.is_initialized(): result = DynamoCache.load(fn) if result is None: # Create a fresh CompilePackage diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index e8ddb9b31e3af..c6444d3acc6c3 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -1970,8 +1970,6 @@ def DUPLICATE_INPUT(self, guard, source_b): if self.serialization_mode == "save": if name := get_local_source_name(source_b): self.check_fn_manager.additional_used_local_vars.add(name) - if name := get_global_source_name(source_b): - self.check_fn_manager.additional_used_global_vars.add(name) ref_a = self.arg_ref(guard) ref_b = self.arg_ref(source_b.name()) @@ -2851,7 +2849,6 @@ def __init__( self.guards_serialization_mode = guards_serialization_mode self.used_builtin_vars: OrderedSet[str] = OrderedSet() self.additional_used_local_vars: OrderedSet[str] = OrderedSet() - self.additional_used_global_vars: OrderedSet[str] = OrderedSet() if runtime_global_scope: assert self.guards_serialization_mode == "load" self.runtime_global_scope = runtime_global_scope @@ -3042,7 +3039,7 @@ def _ref(x): global_scope_state = { k: v for k, v in output_graph_guards_state.global_scope.items() - if k in used_global_vars or k in self.additional_used_global_vars + if k in used_global_vars } global_scope_state[builtins_dict_name] = { k: v diff --git a/torch/_dynamo/package.py b/torch/_dynamo/package.py index a466267035596..be750d41a1dc9 100644 --- a/torch/_dynamo/package.py +++ b/torch/_dynamo/package.py @@ -380,7 +380,7 @@ def install(self, backends: dict[_BackendId, Any]) -> None: 3. Install the precompiled cache entries to ExtraStates on the code object. """ from torch._C._dynamo.eval_frame import _load_precompile_entry - from torch._dynamo.convert_frame import get_compile_id, log_dynamo_start + from torch._dynamo.convert_frame import get_compile_id from torch._guards import compile_context, CompileContext from .output_graph import get_builtins_dict @@ -394,7 +394,6 @@ def install(self, backends: dict[_BackendId, Any]) -> None: # collapsed into 0/0, 1/0 on warm. increment_frame() compile_id = get_compile_id(frame_state={}) - log_dynamo_start(code) with ( compile_context(CompileContext(compile_id)), dynamo_timed( diff --git a/torch/_dynamo/precompile_context.py b/torch/_dynamo/precompile_context.py index 31d858fe3fc33..040f54ce70db2 100644 --- a/torch/_dynamo/precompile_context.py +++ b/torch/_dynamo/precompile_context.py @@ -141,9 +141,6 @@ def serialize_artifact_by_key(cls, key: str) -> Optional[CacheArtifact]: @classmethod def serialize(cls) -> Optional[tuple[bytes, CacheInfo]]: cls._save_artifacts_by_type() - # No need to serialize if there are no new dynamo compiles - if "precompile_dynamo" not in cls._new_cache_artifacts: - return None return super().serialize() @staticmethod From fef236da6924bc8105a830f88060df95ea304de7 Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Tue, 22 Jul 2025 14:01:54 -0700 Subject: [PATCH 451/457] Add zero_() and empty_like(t) to torch/csrc/stable/ops.h (#158866) Pull Request resolved: https://github.com/pytorch/pytorch/pull/158866 Approved by: https://github.com/janeyx99 --- .../libtorch_agnostic/csrc/kernel.cpp | 29 +++++++++++++++++++ .../libtorch_agnostic/ops.py | 24 +++++++++++++++ .../test/test_libtorch_agnostic.py | 24 +++++++++++++++ torch/csrc/stable/ops.h | 29 +++++++++++++++++++ 4 files changed, 106 insertions(+) diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp index 63e7821e9dfd4..a46974e511d53 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp +++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp @@ -269,10 +269,39 @@ void boxed_my_transpose(StableIValue* stack, uint64_t num_args, uint64_t num_out stack[0] = from(res); } +Tensor my_empty_like(Tensor t) { + return empty_like(t); +} + +void boxed_empty_like(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { + auto res = my_empty_like(to(stack[0])); + stack[0] = from(res); +} + STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { m.def("my_transpose(Tensor t, int dim0, int dim1) -> Tensor"); + m.def("my_empty_like(Tensor t) -> Tensor"); } STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { m.impl("my_transpose", &boxed_my_transpose); + m.impl("my_empty_like", &boxed_empty_like); +} + + +Tensor my_zero_(Tensor t) { + return zero_(t); +} + +void boxed_my_zero_(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { + auto res = my_zero_(to(stack[0])); + stack[0] = from(res); +} + +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { + m.def("my_zero_(Tensor(a!) t) -> Tensor(a!)"); +} + +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) { + m.impl("my_zero_", &boxed_my_zero_); } diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py index 4a193cc73593a..371d8b455e185 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py +++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py @@ -128,3 +128,27 @@ def my_transpose(t, dim0, dim1) -> Tensor: Returns: my_transpose(t, dim0, dim1) """ return torch.ops.libtorch_agnostic.my_transpose.default(t, dim0, dim1) + + +def my_empty_like(t) -> Tensor: + """ + Returns t.empty_like() + + Args: + t: Tensor + + Returns: my_empty_like(t) + """ + return torch.ops.libtorch_agnostic.my_empty_like.default(t) + + +def my_zero_(t) -> Tensor: + """ + Returns t.zero_() + + Args: + t: Tensor + + Returns: my_zero_(t) + """ + return torch.ops.libtorch_agnostic.my_zero_.default(t) diff --git a/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py b/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py index 3d9e1ae929289..e1b62a8d3c3c6 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py +++ b/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py @@ -183,6 +183,30 @@ def test_my_transpose(self, device): with self.assertRaisesRegex(RuntimeError, "API call failed"): libtorch_agnostic.ops.my_transpose(t, 1, 2) + def test_my_empty_like(self, device): + import libtorch_agnostic + + deterministic = torch.are_deterministic_algorithms_enabled() + try: + # set use_deterministic_algorithms to fill unintialized memory + torch.use_deterministic_algorithms(True) + + t = torch.rand(2, 7, device=device) + out = libtorch_agnostic.ops.my_empty_like(t) + self.assertTrue(id(out != id(t))) + self.assertEqual(out, torch.empty_like(t)) + finally: + torch.use_deterministic_algorithms(deterministic) + + @onlyCPU + def test_my_zero_(self, device): + import libtorch_agnostic + + t = torch.rand(2, 7, device=device) + out = libtorch_agnostic.ops.my_zero_(t) + self.assertEqual(id(out), id(t)) + self.assertEqual(out, torch.zeros_like(t)) + instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None) if __name__ == "__main__": diff --git a/torch/csrc/stable/ops.h b/torch/csrc/stable/ops.h index 4105339e569c1..d469abbd55ace 100644 --- a/torch/csrc/stable/ops.h +++ b/torch/csrc/stable/ops.h @@ -3,9 +3,27 @@ #include #include #include +#include using torch::stable::Tensor; +// We expect this to be the stable version of the empty_like op that takes in +// no kwargs (device, dtype, layout, memory_format). We will add kwargs +// support in the future. +inline Tensor empty_like(const Tensor& self) { + const auto num_args = 6; + std::array stack{ + from(self), + from(std::nullopt), + from(std::nullopt), + from(std::nullopt), + from(std::nullopt), + from(std::nullopt)}; + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_call_dispatcher("aten::empty_like", "", stack.data())); + return to(stack[0]); +} + // We expect this to be the stable version of the transpose op with identical // semantics to the existing transpose.int op. inline Tensor transpose(const Tensor& self, int64_t dim0, int64_t dim1) { @@ -15,3 +33,14 @@ inline Tensor transpose(const Tensor& self, int64_t dim0, int64_t dim1) { aoti_torch_call_dispatcher("aten::transpose", "int", stack.data())); return to(stack[0]); } + +// We expect this to be the stable version of the zero_ op with identical +// semantics to the existing zero_ op (except that it will not be called as +// a tensor method but only as a function i.e. zero_(t) not t.zero_()). +inline Tensor zero_(Tensor& self) { + const auto num_args = 1; + std::array stack{from(self)}; + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_call_dispatcher("aten::zero_", "", stack.data())); + return to(stack[0]); +} From 5fe1f5f6e611daec266d67ff37bc82aa844374a8 Mon Sep 17 00:00:00 2001 From: "Han, Chao1" Date: Mon, 14 Jul 2025 15:45:17 +0800 Subject: [PATCH 452/457] Device agnostic for DCP --- .../checkpoint/_experimental/test_builder.py | 2 +- .../checkpoint/_experimental/test_staging.py | 14 ++++++------ .../checkpoint/e2e/test_e2e_save_and_load.py | 2 +- .../checkpoint/_experimental/staging.py | 22 ++++++++++--------- torch/distributed/checkpoint/staging.py | 22 ++++++++++--------- 5 files changed, 33 insertions(+), 29 deletions(-) diff --git a/test/distributed/checkpoint/_experimental/test_builder.py b/test/distributed/checkpoint/_experimental/test_builder.py index 7eed02755610b..3c5210bab9a8e 100644 --- a/test/distributed/checkpoint/_experimental/test_builder.py +++ b/test/distributed/checkpoint/_experimental/test_builder.py @@ -123,7 +123,7 @@ def test_make_async_checkpointer(self) -> None: # Create async checkpointer using factory function with default parameters config: CheckpointerConfig = CheckpointerConfig() config.staging_config = CheckpointStagerConfig( - use_cuda_non_blocking_copy=torch.cuda.is_available(), + use_non_blocking_copy=torch.cuda.is_available(), use_pinned_memory=torch.cuda.is_available(), ) checkpointer = make_async_checkpointer(config=config, rank_info=self.rank_info) diff --git a/test/distributed/checkpoint/_experimental/test_staging.py b/test/distributed/checkpoint/_experimental/test_staging.py index 0eeba5d63524d..f817718fd53b2 100644 --- a/test/distributed/checkpoint/_experimental/test_staging.py +++ b/test/distributed/checkpoint/_experimental/test_staging.py @@ -74,7 +74,7 @@ def test_cuda_non_blocking_without_cuda(self) -> None: if torch.cuda.is_available(): self.skipTest("CUDA is available, cannot test CUDA unavailable scenario") - options = CheckpointStagerConfig(use_cuda_non_blocking_copy=True) + options = CheckpointStagerConfig(use_non_blocking_copy=True) with self.assertRaises(AssertionError): DefaultStager(options) @@ -86,21 +86,21 @@ def test_different_option_combinations(self) -> None: use_pinned_memory=False, use_shared_memory=False, use_async_staging=False, - use_cuda_non_blocking_copy=False, + use_non_blocking_copy=False, ), # Only pinned memory CheckpointStagerConfig( use_pinned_memory=True, use_shared_memory=False, use_async_staging=False, - use_cuda_non_blocking_copy=False, + use_non_blocking_copy=False, ), # Only shared memory CheckpointStagerConfig( use_pinned_memory=False, use_shared_memory=True, use_async_staging=False, - use_cuda_non_blocking_copy=False, + use_non_blocking_copy=False, ), ] @@ -111,7 +111,7 @@ def test_different_option_combinations(self) -> None: use_pinned_memory=torch.cuda.is_available(), use_shared_memory=False, use_async_staging=True, - use_cuda_non_blocking_copy=False, + use_non_blocking_copy=False, ) ) # Only CUDA non-blocking copy @@ -120,7 +120,7 @@ def test_different_option_combinations(self) -> None: use_pinned_memory=torch.cuda.is_available(), use_shared_memory=False, use_async_staging=False, - use_cuda_non_blocking_copy=torch.cuda.is_available(), + use_non_blocking_copy=torch.cuda.is_available(), ) ) @@ -185,7 +185,7 @@ def test_multiple_staging_operations(self) -> None: use_async_staging=False, use_pinned_memory=torch.cuda.is_available(), use_shared_memory=False, - use_cuda_non_blocking_copy=torch.cuda.is_available(), + use_non_blocking_copy=torch.cuda.is_available(), ) stager = DefaultStager(options) diff --git a/test/distributed/checkpoint/e2e/test_e2e_save_and_load.py b/test/distributed/checkpoint/e2e/test_e2e_save_and_load.py index c2e37850d9d70..e1b1041875afb 100644 --- a/test/distributed/checkpoint/e2e/test_e2e_save_and_load.py +++ b/test/distributed/checkpoint/e2e/test_e2e_save_and_load.py @@ -279,7 +279,7 @@ def _run_e2e_test( use_async_staging=zoc, use_shared_memory=use_shared_memory, use_pinned_memory=zoc, - use_cuda_non_blocking_copy=zoc, + use_non_blocking_copy=zoc, ) stager = DefaultStager(staging_options) async_save_response_or_future = saver.async_save( diff --git a/torch/distributed/checkpoint/_experimental/staging.py b/torch/distributed/checkpoint/_experimental/staging.py index 55e4c15921a2d..2907f5c33c965 100644 --- a/torch/distributed/checkpoint/_experimental/staging.py +++ b/torch/distributed/checkpoint/_experimental/staging.py @@ -82,7 +82,7 @@ class CheckpointStagerConfig: use_async_staging (bool): Enable asynchronous staging using a background thread pool. Allows overlapping computation with staging operations. Requires CUDA. Default: True - use_cuda_non_blocking_copy (bool): Use non-blocking CUDA memory + use_non_blocking_copy (bool): Use non-blocking CUDA memory copies with stream synchronization. Improves performance by allowing CPU work to continue during GPU transfers. Default: True @@ -93,7 +93,7 @@ class CheckpointStagerConfig: use_pinned_memory: bool = True use_shared_memory: bool = True use_async_staging: bool = True - use_cuda_non_blocking_copy: bool = True + use_non_blocking_copy: bool = True class DefaultStager(CheckpointStager): @@ -153,15 +153,17 @@ def __init__( if self._config.use_async_staging: self._staging_executor = ThreadPoolExecutor(max_workers=1) - if torch.cuda.is_available(): + if torch.accelerator.is_available(): # Note: stream needs to be initialized on the main thread after default cuda # stream is setup/used to avoid the risk of accidentally reusing the main # compute stream or in other cases kernels actually launching from the # main thread. - self._staging_stream = torch.cuda.Stream() + self._staging_stream = torch.Stream() - if self._config.use_cuda_non_blocking_copy: - assert torch.cuda.is_available(), "Non-blocking copy requires CUDA" + if self._config.use_non_blocking_copy: + assert torch.accelerator.is_available(), ( + "Non-blocking copy requires CUDA/XPU" + ) def stage( self, @@ -182,16 +184,16 @@ def stage( def _stage(self, state_dict: STATE_DICT, **kwargs: Any) -> STATE_DICT: state_dict = self._state_dict_stager.stage( - state_dict, non_blocking=self._config.use_cuda_non_blocking_copy, **kwargs + state_dict, non_blocking=self._config.use_non_blocking_copy, **kwargs ) - if self._config.use_cuda_non_blocking_copy: + if self._config.use_non_blocking_copy: assert self._staging_stream or not self._config.use_async_staging, ( - "Non-blocking cuda copy in a background thread for async staging needs staging_stream to be initialized." + "Non-blocking copy in a background thread for async staging needs staging_stream to be initialized." ) # waits for the enqued copy operations to finish. - self._staging_stream.synchronize() if self._staging_stream else torch.cuda.synchronize() + self._staging_stream.synchronize() if self._staging_stream else torch.accelerator.synchronize() return state_dict diff --git a/torch/distributed/checkpoint/staging.py b/torch/distributed/checkpoint/staging.py index a2093f803ee6d..e3545600bb4e4 100644 --- a/torch/distributed/checkpoint/staging.py +++ b/torch/distributed/checkpoint/staging.py @@ -110,7 +110,7 @@ class StagingOptions: use_async_staging (bool): Enable asynchronous staging using a background thread pool. Allows overlapping computation with staging operations. Requires CUDA. Default: True - use_cuda_non_blocking_copy (bool): Use non-blocking CUDA memory + use_non_blocking_copy (bool): Use non-blocking CUDA memory copies with stream synchronization. Improves performance by allowing CPU work to continue during GPU transfers. Default: True @@ -121,7 +121,7 @@ class StagingOptions: use_pinned_memory: bool = True use_shared_memory: bool = True use_async_staging: bool = True - use_cuda_non_blocking_copy: bool = True + use_non_blocking_copy: bool = True class DefaultStager(AsyncStager): @@ -177,15 +177,17 @@ def __init__( self._staging_stream = None if self._config.use_async_staging: self._staging_executor = ThreadPoolExecutor(max_workers=1) - if torch.cuda.is_available(): + if torch.accelerator.is_available(): # Note: stream needs to be initialized on the main thread after default cuda # stream is setup/used to avoid the risk of accidentally reusing the main # compute stream or in other cases kernels actually launching from the # main thread. - self._staging_stream = torch.cuda.Stream() + self._staging_stream = torch.Stream() - if self._config.use_cuda_non_blocking_copy: - assert torch.cuda.is_available(), "Non-blocking copy requires CUDA" + if self._config.use_non_blocking_copy: + assert torch.accelerator.is_available(), ( + "Non-blocking copy requires CUDA/XPU" + ) self._staging_future: Optional[Future[STATE_DICT_TYPE]] = None @@ -216,9 +218,9 @@ def stage( return self._stage(state_dict, **kwargs) def _stage(self, state_dict: STATE_DICT_TYPE, **kwargs: Any) -> STATE_DICT_TYPE: - if self._config.use_cuda_non_blocking_copy: + if self._config.use_non_blocking_copy: assert self._staging_stream or not self._config.use_async_staging, ( - "Non-blocking cuda copy in a background thread for async staging needs staging_stream to be initialized." + "Non-blocking copy in a background thread for async staging needs staging_stream to be initialized." ) with ( self._staging_stream @@ -226,10 +228,10 @@ def _stage(self, state_dict: STATE_DICT_TYPE, **kwargs: Any) -> STATE_DICT_TYPE: else nullcontext() ): state_dict = self._state_dict_stager.stage( - state_dict, non_blocking=self._config.use_cuda_non_blocking_copy + state_dict, non_blocking=self._config.use_non_blocking_copy ) # waits for the enqued copy operations to finish. - self._staging_stream.synchronize() if self._staging_stream else torch.cuda.synchronize() + self._staging_stream.synchronize() if self._staging_stream else torch.accelerator.synchronize() else: state_dict = self._state_dict_stager.stage(state_dict, non_blocking=False) return state_dict From cecca5ed65f1c44e039f15a25de7fa287d679478 Mon Sep 17 00:00:00 2001 From: "Han, Chao1" Date: Tue, 15 Jul 2025 16:05:08 +0800 Subject: [PATCH 453/457] Commit suggestion --- torch/distributed/checkpoint/_experimental/staging.py | 4 ++-- torch/distributed/checkpoint/staging.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/torch/distributed/checkpoint/_experimental/staging.py b/torch/distributed/checkpoint/_experimental/staging.py index 2907f5c33c965..3e57650b58806 100644 --- a/torch/distributed/checkpoint/_experimental/staging.py +++ b/torch/distributed/checkpoint/_experimental/staging.py @@ -82,7 +82,7 @@ class CheckpointStagerConfig: use_async_staging (bool): Enable asynchronous staging using a background thread pool. Allows overlapping computation with staging operations. Requires CUDA. Default: True - use_non_blocking_copy (bool): Use non-blocking CUDA memory + use_non_blocking_copy (bool): Use non-blocking device memory copies with stream synchronization. Improves performance by allowing CPU work to continue during GPU transfers. Default: True @@ -162,7 +162,7 @@ def __init__( if self._config.use_non_blocking_copy: assert torch.accelerator.is_available(), ( - "Non-blocking copy requires CUDA/XPU" + "Non-blocking copy requires that the current accelerator is available." ) def stage( diff --git a/torch/distributed/checkpoint/staging.py b/torch/distributed/checkpoint/staging.py index e3545600bb4e4..9e1031c7fddae 100644 --- a/torch/distributed/checkpoint/staging.py +++ b/torch/distributed/checkpoint/staging.py @@ -110,7 +110,7 @@ class StagingOptions: use_async_staging (bool): Enable asynchronous staging using a background thread pool. Allows overlapping computation with staging operations. Requires CUDA. Default: True - use_non_blocking_copy (bool): Use non-blocking CUDA memory + use_non_blocking_copy (bool): Use non-blocking device memory copies with stream synchronization. Improves performance by allowing CPU work to continue during GPU transfers. Default: True @@ -186,7 +186,7 @@ def __init__( if self._config.use_non_blocking_copy: assert torch.accelerator.is_available(), ( - "Non-blocking copy requires CUDA/XPU" + "Non-blocking copy requires that the current accelerator is available." ) self._staging_future: Optional[Future[STATE_DICT_TYPE]] = None From b80449564ea74d9be2c3089de81a9bed8709d40e Mon Sep 17 00:00:00 2001 From: Chao Han Date: Wed, 23 Jul 2025 09:06:03 +0800 Subject: [PATCH 454/457] Update test/distributed/checkpoint/_experimental/test_staging.py Co-authored-by: Yu, Guangye <106960996+guangyey@users.noreply.github.com> --- test/distributed/checkpoint/_experimental/test_staging.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/distributed/checkpoint/_experimental/test_staging.py b/test/distributed/checkpoint/_experimental/test_staging.py index f817718fd53b2..a7008a9ec5e4e 100644 --- a/test/distributed/checkpoint/_experimental/test_staging.py +++ b/test/distributed/checkpoint/_experimental/test_staging.py @@ -185,7 +185,7 @@ def test_multiple_staging_operations(self) -> None: use_async_staging=False, use_pinned_memory=torch.cuda.is_available(), use_shared_memory=False, - use_non_blocking_copy=torch.cuda.is_available(), + use_non_blocking_copy=torch.accelerator.is_available(), ) stager = DefaultStager(options) From 12e06c2ad9d04991204e932734b441b90732f1e8 Mon Sep 17 00:00:00 2001 From: Chao Han Date: Wed, 23 Jul 2025 09:06:13 +0800 Subject: [PATCH 455/457] Update test/distributed/checkpoint/_experimental/test_staging.py Co-authored-by: Yu, Guangye <106960996+guangyey@users.noreply.github.com> --- test/distributed/checkpoint/_experimental/test_staging.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/distributed/checkpoint/_experimental/test_staging.py b/test/distributed/checkpoint/_experimental/test_staging.py index a7008a9ec5e4e..9595779039d86 100644 --- a/test/distributed/checkpoint/_experimental/test_staging.py +++ b/test/distributed/checkpoint/_experimental/test_staging.py @@ -120,7 +120,7 @@ def test_different_option_combinations(self) -> None: use_pinned_memory=torch.cuda.is_available(), use_shared_memory=False, use_async_staging=False, - use_non_blocking_copy=torch.cuda.is_available(), + use_non_blocking_copy=torch.accelerator.is_available(), ) ) From adb526100a9f62cf48d398a23083ab288ff4cbf8 Mon Sep 17 00:00:00 2001 From: Chao Han Date: Wed, 23 Jul 2025 09:06:22 +0800 Subject: [PATCH 456/457] Update test/distributed/checkpoint/_experimental/test_builder.py Co-authored-by: Yu, Guangye <106960996+guangyey@users.noreply.github.com> --- test/distributed/checkpoint/_experimental/test_builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/distributed/checkpoint/_experimental/test_builder.py b/test/distributed/checkpoint/_experimental/test_builder.py index 3c5210bab9a8e..4f009fd18dd80 100644 --- a/test/distributed/checkpoint/_experimental/test_builder.py +++ b/test/distributed/checkpoint/_experimental/test_builder.py @@ -123,7 +123,7 @@ def test_make_async_checkpointer(self) -> None: # Create async checkpointer using factory function with default parameters config: CheckpointerConfig = CheckpointerConfig() config.staging_config = CheckpointStagerConfig( - use_non_blocking_copy=torch.cuda.is_available(), + use_non_blocking_copy=torch.accelerator.is_available(), use_pinned_memory=torch.cuda.is_available(), ) checkpointer = make_async_checkpointer(config=config, rank_info=self.rank_info) From 615fb7715e8ab30998cf393090bd34ac5a550c6c Mon Sep 17 00:00:00 2001 From: "Han, Chao1" Date: Thu, 24 Jul 2025 09:26:16 +0800 Subject: [PATCH 457/457] acc comment --- test/distributed/checkpoint/_experimental/test_builder.py | 2 +- test/distributed/checkpoint/_experimental/test_staging.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/test/distributed/checkpoint/_experimental/test_builder.py b/test/distributed/checkpoint/_experimental/test_builder.py index 4f009fd18dd80..788f78892fbbe 100644 --- a/test/distributed/checkpoint/_experimental/test_builder.py +++ b/test/distributed/checkpoint/_experimental/test_builder.py @@ -124,7 +124,7 @@ def test_make_async_checkpointer(self) -> None: config: CheckpointerConfig = CheckpointerConfig() config.staging_config = CheckpointStagerConfig( use_non_blocking_copy=torch.accelerator.is_available(), - use_pinned_memory=torch.cuda.is_available(), + use_pinned_memory=torch.accelerator.is_available(), ) checkpointer = make_async_checkpointer(config=config, rank_info=self.rank_info) diff --git a/test/distributed/checkpoint/_experimental/test_staging.py b/test/distributed/checkpoint/_experimental/test_staging.py index 9595779039d86..3fdb3bc022f25 100644 --- a/test/distributed/checkpoint/_experimental/test_staging.py +++ b/test/distributed/checkpoint/_experimental/test_staging.py @@ -108,7 +108,7 @@ def test_different_option_combinations(self) -> None: # Only async staging test_cases.append( CheckpointStagerConfig( - use_pinned_memory=torch.cuda.is_available(), + use_pinned_memory=torch.accelerator.is_available(), use_shared_memory=False, use_async_staging=True, use_non_blocking_copy=False, @@ -117,7 +117,7 @@ def test_different_option_combinations(self) -> None: # Only CUDA non-blocking copy test_cases.append( CheckpointStagerConfig( - use_pinned_memory=torch.cuda.is_available(), + use_pinned_memory=torch.accelerator.is_available(), use_shared_memory=False, use_async_staging=False, use_non_blocking_copy=torch.accelerator.is_available(), @@ -129,7 +129,7 @@ def test_different_option_combinations(self) -> None: stager = DefaultStager(options) # Test staging works with these options - if options.use_async_staging and torch.cuda.is_available(): + if options.use_async_staging and torch.accelerator.is_available(): result = stager.stage(self.state_dict) self.assertIsInstance(result, Future) staged_dict = result.result() @@ -183,7 +183,7 @@ def test_multiple_staging_operations(self) -> None: """Test multiple staging operations with the same stager.""" options = CheckpointStagerConfig( use_async_staging=False, - use_pinned_memory=torch.cuda.is_available(), + use_pinned_memory=torch.accelerator.is_available(), use_shared_memory=False, use_non_blocking_copy=torch.accelerator.is_available(), )