Skip to content

Commit 4e10974

Browse files
authored
Enable CUDA 12.4 builds (#1785)
GHA results show this is needed to fix errors in pytorch/pytorch#121684 Reference: #1374
1 parent 87cdc8c commit 4e10974

File tree

3 files changed

+11
-4
lines changed

3 files changed

+11
-4
lines changed

conda/build_pytorch.sh

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,10 @@ else
268268
. ./switch_cuda_version.sh "$desired_cuda"
269269
# TODO, simplify after anaconda fixes their cudatoolkit versioning inconsistency.
270270
# see: https://github.com/conda-forge/conda-forge.github.io/issues/687#issuecomment-460086164
271-
if [[ "$desired_cuda" == "12.1" ]]; then
271+
if [[ "$desired_cuda" == "12.4" ]]; then
272+
export CONDA_CUDATOOLKIT_CONSTRAINT=" - pytorch-cuda >=12.4,<12.5 # [not osx]"
273+
export MAGMA_PACKAGE=" - magma-cuda124 # [not osx and not win]"
274+
elif [[ "$desired_cuda" == "12.1" ]]; then
272275
export CONDA_CUDATOOLKIT_CONSTRAINT=" - pytorch-cuda >=12.1,<12.2 # [not osx]"
273276
export MAGMA_PACKAGE=" - magma-cuda121 # [not osx and not win]"
274277
elif [[ "$desired_cuda" == "11.8" ]]; then

conda/pytorch-nightly/build.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,10 @@ if [[ -n "$build_with_cuda" ]]; then
6060
TORCH_CUDA_ARCH_LIST="$TORCH_CUDA_ARCH_LIST;3.7+PTX;9.0"
6161
#for cuda 11.8 include all dynamic loading libraries
6262
DEPS_LIST=(/usr/local/cuda/lib64/libcudnn*.so.8 /usr/local/cuda-11.8/extras/CUPTI/lib64/libcupti.so.11.8 /usr/local/cuda/lib64/libcusparseLt.so.0)
63-
elif [[ $CUDA_VERSION == 12.1* ]]; then
63+
elif [[ $CUDA_VERSION == 12.1* || $CUDA_VERSION == 12.4* ]]; then
6464
# cuda 12 does not support sm_3x
6565
TORCH_CUDA_ARCH_LIST="$TORCH_CUDA_ARCH_LIST;9.0"
66-
# for cuda 12.1 we use cudnn 8.8 and include all dynamic loading libraries
66+
# for cuda 12.1 (12.4) we use cudnn 8.8 (8.9) and include all dynamic loading libraries
6767
DEPS_LIST=(/usr/local/cuda/lib64/libcudnn*.so.8 /usr/local/cuda-12.1/extras/CUPTI/lib64/libcupti.so.12 /usr/local/cuda/lib64/libcusparseLt.so.0)
6868
fi
6969
if [[ -n "$OVERRIDE_TORCH_CUDA_ARCH_LIST" ]]; then

manywheel/build_cuda.sh

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,10 @@ cuda_version_nodot=$(echo $CUDA_VERSION | tr -d '.')
5959

6060
TORCH_CUDA_ARCH_LIST="5.0;6.0;7.0;7.5;8.0;8.6"
6161
case ${CUDA_VERSION} in
62+
12.4)
63+
TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};9.0"
64+
EXTRA_CAFFE2_CMAKE_FLAGS+=("-DATEN_NO_TEST=ON")
65+
;;
6266
12.1)
6367
TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};9.0"
6468
EXTRA_CAFFE2_CMAKE_FLAGS+=("-DATEN_NO_TEST=ON")
@@ -131,7 +135,7 @@ if [[ $USE_CUSPARSELT == "1" ]]; then
131135
)
132136
fi
133137

134-
if [[ $CUDA_VERSION == "12.1" ]]; then
138+
if [[ $CUDA_VERSION == "12.1" || $CUDA_VERSION == "12.4" ]]; then
135139
export USE_STATIC_CUDNN=0
136140
# Try parallelizing nvcc as well
137141
export TORCH_NVCC_FLAGS="-Xfatbin -compress-all --threads 2"

0 commit comments

Comments
 (0)