|
56 | 56 |
|
57 | 57 | cuda_version_nodot=$(echo $CUDA_VERSION | tr -d '.')
|
58 | 58 |
|
59 |
| -TORCH_CUDA_ARCH_LIST="3.7;5.0;6.0;7.0" |
| 59 | +TORCH_CUDA_ARCH_LIST="5.0;6.0;7.0;7.5;8.0;8.6" |
60 | 60 | case ${CUDA_VERSION} in
|
| 61 | + 12.1) |
| 62 | + TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};9.0" |
| 63 | + EXTRA_CAFFE2_CMAKE_FLAGS+=("-DATEN_NO_TEST=ON") |
| 64 | + ;; |
61 | 65 | 11.8)
|
62 |
| - TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};7.5;8.0;8.6;9.0" |
| 66 | + TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};3.7;9.0" |
63 | 67 | EXTRA_CAFFE2_CMAKE_FLAGS+=("-DATEN_NO_TEST=ON")
|
64 | 68 | ;;
|
65 | 69 | 11.[67])
|
66 |
| - TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};7.5;8.0;8.6" |
| 70 | + TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};3.7" |
67 | 71 | EXTRA_CAFFE2_CMAKE_FLAGS+=("-DATEN_NO_TEST=ON")
|
68 | 72 | ;;
|
69 | 73 | *)
|
@@ -108,7 +112,77 @@ elif [[ "$OS_NAME" == *"Ubuntu"* ]]; then
|
108 | 112 | LIBGOMP_PATH="/usr/lib/x86_64-linux-gnu/libgomp.so.1"
|
109 | 113 | fi
|
110 | 114 |
|
111 |
| -if [[ $CUDA_VERSION == "11.7" || $CUDA_VERSION == "11.8" ]]; then |
| 115 | +if [[ $CUDA_VERSION == "12.1" ]]; then |
| 116 | + export USE_STATIC_CUDNN=0 |
| 117 | + # Try parallelizing nvcc as well |
| 118 | + export TORCH_NVCC_FLAGS="-Xfatbin -compress-all --threads 2" |
| 119 | + DEPS_LIST=( |
| 120 | + "$LIBGOMP_PATH" |
| 121 | + ) |
| 122 | + DEPS_SONAME=( |
| 123 | + "libgomp.so.1" |
| 124 | + ) |
| 125 | + |
| 126 | + if [[ -z "$PYTORCH_EXTRA_INSTALL_REQUIREMENTS" ]]; then |
| 127 | + echo "Bundling with cudnn and cublas." |
| 128 | + DEPS_LIST+=( |
| 129 | + "/usr/local/cuda/lib64/libcudnn_adv_infer.so.8" |
| 130 | + "/usr/local/cuda/lib64/libcudnn_adv_train.so.8" |
| 131 | + "/usr/local/cuda/lib64/libcudnn_cnn_infer.so.8" |
| 132 | + "/usr/local/cuda/lib64/libcudnn_cnn_train.so.8" |
| 133 | + "/usr/local/cuda/lib64/libcudnn_ops_infer.so.8" |
| 134 | + "/usr/local/cuda/lib64/libcudnn_ops_train.so.8" |
| 135 | + "/usr/local/cuda/lib64/libcudnn.so.8" |
| 136 | + "/usr/local/cuda/lib64/libcublas.so.12" |
| 137 | + "/usr/local/cuda/lib64/libcublasLt.so.12" |
| 138 | + "/usr/local/cuda/lib64/libcudart.so.12" |
| 139 | + "/usr/local/cuda/lib64/libnvToolsExt.so.1" |
| 140 | + "/usr/local/cuda/lib64/libnvrtc.so.12" |
| 141 | + "/usr/local/cuda/lib64/libnvrtc-builtins.so.12.1" |
| 142 | + ) |
| 143 | + DEPS_SONAME+=( |
| 144 | + "libcudnn_adv_infer.so.8" |
| 145 | + "libcudnn_adv_train.so.8" |
| 146 | + "libcudnn_cnn_infer.so.8" |
| 147 | + "libcudnn_cnn_train.so.8" |
| 148 | + "libcudnn_ops_infer.so.8" |
| 149 | + "libcudnn_ops_train.so.8" |
| 150 | + "libcudnn.so.8" |
| 151 | + "libcublas.so.12" |
| 152 | + "libcublasLt.so.12" |
| 153 | + "libcudart.so.12" |
| 154 | + "libnvToolsExt.so.1" |
| 155 | + "libnvrtc.so.12" |
| 156 | + "libnvrtc-builtins.so.12.1" |
| 157 | + ) |
| 158 | + else |
| 159 | + echo "Using nvidia libs from pypi." |
| 160 | + CUDA_RPATHS=( |
| 161 | + '$ORIGIN/../../nvidia/cublas/lib' |
| 162 | + '$ORIGIN/../../nvidia/cuda_cupti/lib' |
| 163 | + '$ORIGIN/../../nvidia/cuda_nvrtc/lib' |
| 164 | + '$ORIGIN/../../nvidia/cuda_runtime/lib' |
| 165 | + '$ORIGIN/../../nvidia/cudnn/lib' |
| 166 | + '$ORIGIN/../../nvidia/cufft/lib' |
| 167 | + '$ORIGIN/../../nvidia/curand/lib' |
| 168 | + '$ORIGIN/../../nvidia/cusolver/lib' |
| 169 | + '$ORIGIN/../../nvidia/cusparse/lib' |
| 170 | + '$ORIGIN/../../nvidia/nccl/lib' |
| 171 | + '$ORIGIN/../../nvidia/nvtx/lib' |
| 172 | + ) |
| 173 | + CUDA_RPATHS=$(IFS=: ; echo "${CUDA_RPATHS[*]}") |
| 174 | + export C_SO_RPATH=$CUDA_RPATHS':$ORIGIN:$ORIGIN/lib' |
| 175 | + export LIB_SO_RPATH=$CUDA_RPATHS':$ORIGIN' |
| 176 | + export FORCE_RPATH="--force-rpath" |
| 177 | + export USE_STATIC_NCCL=0 |
| 178 | + export USE_SYSTEM_NCCL=1 |
| 179 | + export ATEN_STATIC_CUDA=0 |
| 180 | + export USE_CUDA_STATIC_LINK=0 |
| 181 | + export USE_CUPTI_SO=1 |
| 182 | + export NCCL_INCLUDE_DIR="/usr/local/cuda/include/" |
| 183 | + export NCCL_LIB_DIR="/usr/local/cuda/lib64/" |
| 184 | + fi |
| 185 | +elif [[ $CUDA_VERSION == "11.7" || $CUDA_VERSION == "11.8" ]]; then |
112 | 186 | export USE_STATIC_CUDNN=0
|
113 | 187 | # Try parallelizing nvcc as well
|
114 | 188 | export TORCH_NVCC_FLAGS="-Xfatbin -compress-all --threads 2"
|
|
0 commit comments