diff --git a/aarch64_linux/aarch64_wheel_ci_build.py b/aarch64_linux/aarch64_wheel_ci_build.py index 61efa10f2..fe612f326 100755 --- a/aarch64_linux/aarch64_wheel_ci_build.py +++ b/aarch64_linux/aarch64_wheel_ci_build.py @@ -218,7 +218,13 @@ def parse_arguments(): version = ( check_output(["cat", "version.txt"], cwd="/pytorch").decode().strip()[:-2] ) - build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={version}.dev{build_date} PYTORCH_BUILD_NUMBER=1 " + if enable_cuda: + desired_cuda = os.getenv("DESIRED_CUDA") + build_vars += ( + f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={version}.dev{build_date}+{desired_cuda} PYTORCH_BUILD_NUMBER=1 " + ) + else: + build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={version}.dev{build_date} PYTORCH_BUILD_NUMBER=1 " elif branch.startswith(("v1.", "v2.")): build_vars += f"BUILD_TEST=0 PYTORCH_BUILD_VERSION={branch[1:branch.find('-')]} PYTORCH_BUILD_NUMBER=1 "