@@ -52,31 +52,35 @@ if [[ -z "$USE_CUDA" || "$USE_CUDA" == 1 ]]; then
 fi
 if [[ -n "$build_with_cuda" ]]; then
     export TORCH_NVCC_FLAGS="-Xfatbin -compress-all"
-    export TORCH_CUDA_ARCH_LIST="3.7+PTX;5.0"
+    TORCH_CUDA_ARCH_LIST="3.7+PTX;5.0"
     export USE_STATIC_CUDNN=1 # links cudnn statically (driven by tools/setup_helpers/cudnn.py)
 
     if [[ $CUDA_VERSION == 11.6* ]]; then
-        export TORCH_CUDA_ARCH_LIST="$TORCH_CUDA_ARCH_LIST;6.0;6.1;7.0;7.5;8.0;8.6"
+        TORCH_CUDA_ARCH_LIST="$TORCH_CUDA_ARCH_LIST;6.0;6.1;7.0;7.5;8.0;8.6"
         # for cuda 11.5 we use cudnn 8.3.2.44 https://docs.nvidia.com/deeplearning/cudnn/release-notes/rel_8.html
         # which does not have single static libcudnn_static.a deliverable to link with
         export USE_STATIC_CUDNN=0
         # for cuda 11.5 include all dynamic loading libraries
         DEPS_LIST=(/usr/local/cuda/lib64/libcudnn*.so.8 /usr/local/cuda-11.6/extras/CUPTI/lib64/libcupti.so.11.6)
     elif [[ $CUDA_VERSION == 11.7* ]]; then
-        export TORCH_CUDA_ARCH_LIST="$TORCH_CUDA_ARCH_LIST;6.0;6.1;7.0;7.5;8.0;8.6"
+        TORCH_CUDA_ARCH_LIST="$TORCH_CUDA_ARCH_LIST;6.0;6.1;7.0;7.5;8.0;8.6"
         # for cuda 11.7 we use cudnn 8.5
         # which does not have single static libcudnn_static.a deliverable to link with
         export USE_STATIC_CUDNN=0
         # for cuda 11.7 include all dynamic loading libraries
         DEPS_LIST=(/usr/local/cuda/lib64/libcudnn*.so.8 /usr/local/cuda-11.7/extras/CUPTI/lib64/libcupti.so.11.7)
     elif [[ $CUDA_VERSION == 11.8* ]]; then
-        export TORCH_CUDA_ARCH_LIST="$TORCH_CUDA_ARCH_LIST;6.0;6.1;7.0;7.5;8.0;8.6;9.0"
+        TORCH_CUDA_ARCH_LIST="$TORCH_CUDA_ARCH_LIST;6.0;6.1;7.0;7.5;8.0;8.6;9.0"
         # for cuda 11.8 we use cudnn 8.7
         # which does not have single static libcudnn_static.a deliverable to link with
         export USE_STATIC_CUDNN=0
         # for cuda 11.8 include all dynamic loading libraries
         DEPS_LIST=(/usr/local/cuda/lib64/libcudnn*.so.8 /usr/local/cuda-11.8/extras/CUPTI/lib64/libcupti.so.11.8)
     fi
+    if [[ -n "$OVERRIDE_TORCH_CUDA_ARCH_LIST" ]]; then
+        TORCH_CUDA_ARCH_LIST="$OVERRIDE_TORCH_CUDA_ARCH_LIST"
+    fi
+    export TORCH_CUDA_ARCH_LIST="$TORCH_CUDA_ARCH_LIST"
     export NCCL_ROOT_DIR=/usr/local/cuda
     export USE_STATIC_NCCL=1 # links nccl statically (driven by tools/setup_helpers/nccl.py, some of the NCCL cmake files such as FindNCCL.cmake and gloo/FindNCCL.cmake)
 
0 commit comments