File tree 3 files changed +11
-4
lines changed 3 files changed +11
-4
lines changed Original file line number Diff line number Diff line change @@ -268,7 +268,10 @@ else
268
268
. ./switch_cuda_version.sh " $desired_cuda "
269
269
# TODO, simplify after anaconda fixes their cudatoolkit versioning inconsistency.
270
270
# see: https://github.com/conda-forge/conda-forge.github.io/issues/687#issuecomment-460086164
271
- if [[ " $desired_cuda " == " 12.1" ]]; then
271
+ if [[ " $desired_cuda " == " 12.4" ]]; then
272
+ export CONDA_CUDATOOLKIT_CONSTRAINT=" - pytorch-cuda >=12.4,<12.5 # [not osx]"
273
+ export MAGMA_PACKAGE=" - magma-cuda124 # [not osx and not win]"
274
+ elif [[ " $desired_cuda " == " 12.1" ]]; then
272
275
export CONDA_CUDATOOLKIT_CONSTRAINT=" - pytorch-cuda >=12.1,<12.2 # [not osx]"
273
276
export MAGMA_PACKAGE=" - magma-cuda121 # [not osx and not win]"
274
277
elif [[ " $desired_cuda " == " 11.8" ]]; then
Original file line number Diff line number Diff line change @@ -60,10 +60,10 @@ if [[ -n "$build_with_cuda" ]]; then
60
60
TORCH_CUDA_ARCH_LIST=" $TORCH_CUDA_ARCH_LIST ;3.7+PTX;9.0"
61
61
# for cuda 11.8 include all dynamic loading libraries
62
62
DEPS_LIST=(/usr/local/cuda/lib64/libcudnn* .so.8 /usr/local/cuda-11.8/extras/CUPTI/lib64/libcupti.so.11.8 /usr/local/cuda/lib64/libcusparseLt.so.0)
63
- elif [[ $CUDA_VERSION == 12.1* ]]; then
63
+ elif [[ $CUDA_VERSION == 12.1* || $CUDA_VERSION == 12.4 * ]]; then
64
64
# cuda 12 does not support sm_3x
65
65
TORCH_CUDA_ARCH_LIST=" $TORCH_CUDA_ARCH_LIST ;9.0"
66
- # for cuda 12.1 we use cudnn 8.8 and include all dynamic loading libraries
66
+ # for cuda 12.1 (12.4) we use cudnn 8.8 (8.9) and include all dynamic loading libraries
67
67
DEPS_LIST=(/usr/local/cuda/lib64/libcudnn* .so.8 /usr/local/cuda-12.1/extras/CUPTI/lib64/libcupti.so.12 /usr/local/cuda/lib64/libcusparseLt.so.0)
68
68
fi
69
69
if [[ -n " $OVERRIDE_TORCH_CUDA_ARCH_LIST " ]]; then
Original file line number Diff line number Diff line change @@ -59,6 +59,10 @@ cuda_version_nodot=$(echo $CUDA_VERSION | tr -d '.')
59
59
60
60
TORCH_CUDA_ARCH_LIST=" 5.0;6.0;7.0;7.5;8.0;8.6"
61
61
case ${CUDA_VERSION} in
62
+ 12.4)
63
+ TORCH_CUDA_ARCH_LIST=" ${TORCH_CUDA_ARCH_LIST} ;9.0"
64
+ EXTRA_CAFFE2_CMAKE_FLAGS+=(" -DATEN_NO_TEST=ON" )
65
+ ;;
62
66
12.1)
63
67
TORCH_CUDA_ARCH_LIST=" ${TORCH_CUDA_ARCH_LIST} ;9.0"
64
68
EXTRA_CAFFE2_CMAKE_FLAGS+=(" -DATEN_NO_TEST=ON" )
@@ -131,7 +135,7 @@ if [[ $USE_CUSPARSELT == "1" ]]; then
131
135
)
132
136
fi
133
137
134
- if [[ $CUDA_VERSION == " 12.1" ]]; then
138
+ if [[ $CUDA_VERSION == " 12.1" || $CUDA_VERSION == " 12.4 " ]]; then
135
139
export USE_STATIC_CUDNN=0
136
140
# Try parallelizing nvcc as well
137
141
export TORCH_NVCC_FLAGS=" -Xfatbin -compress-all --threads 2"
You can’t perform that action at this time.
0 commit comments