@@ -147,13 +147,9 @@ elif [[ $CUDA_VERSION == "11.7" ]]; then
147
147
# Try parallelizing nvcc as well
148
148
export TORCH_NVCC_FLAGS=" -Xfatbin -compress-all --threads 2"
149
149
DEPS_LIST=(
150
- " /usr/local/cuda/lib64/libcudart.so.11.0"
151
- " /usr/local/cuda/lib64/libnvToolsExt.so.1"
152
150
" $LIBGOMP_PATH "
153
151
)
154
152
DEPS_SONAME=(
155
- " libcudart.so.11.0"
156
- " libnvToolsExt.so.1"
157
153
" libgomp.so.1"
158
154
)
159
155
@@ -171,6 +167,8 @@ elif [[ $CUDA_VERSION == "11.7" ]]; then
171
167
" /usr/local/cuda/lib64/libcublasLt.so.11"
172
168
" /usr/local/cuda/lib64/libnvrtc.so.11.2" # the 11.2 suffix is not a mistake for CUDA 11.7: this file links to 11.7.50
173
169
" /usr/local/cuda/lib64/libnvrtc-builtins.so.11.7"
170
+ " /usr/local/cuda/lib64/libcudart.so.11.0"
171
+ " /usr/local/cuda/lib64/libnvToolsExt.so.1"
174
172
)
175
173
DEPS_SONAME+=(
176
174
" libcudnn_adv_infer.so.8"
@@ -238,21 +236,33 @@ elif [[ $CUDA_VERSION == "11.8" ]]; then
238
236
" libcublasLt.so.11"
239
237
" libnvrtc.so.11.2"
240
238
" libnvrtc-builtins.so.11.7"
239
+ " libcudart.so.11.0"
240
+ " libnvToolsExt.so.1"
241
241
)
242
242
else
243
- echo " Using cudnn, cublas, nccl, and nvrtc from pypi."
243
+ echo " Using nvidia libs from pypi."
244
244
CUDA_RPATHS=(
245
245
' $ORIGIN/../../nvidia/cublas/lib'
246
+ ' $ORIGIN/../../nvidia/cuda_cupti/lib'
246
247
' $ORIGIN/../../nvidia/cuda_nvrtc/lib'
248
+ ' $ORIGIN/../../nvidia/cuda_runtime/lib'
247
249
' $ORIGIN/../../nvidia/cudnn/lib'
250
+ ' $ORIGIN/../../nvidia/cufft/lib'
251
+ ' $ORIGIN/../../nvidia/curand/lib'
252
+ ' $ORIGIN/../../nvidia/cusolver/lib'
253
+ ' $ORIGIN/../../nvidia/cusparse/lib'
248
254
' $ORIGIN/../../nvidia/nccl/lib'
255
+ ' $ORIGIN/../../nvidia/nvtx/lib'
249
256
)
250
257
CUDA_RPATHS=$( IFS=: ; echo " ${CUDA_RPATHS[*]} " )
251
258
export C_SO_RPATH=$CUDA_RPATHS ' :$ORIGIN:$ORIGIN/lib'
252
259
export LIB_SO_RPATH=$CUDA_RPATHS ' :$ORIGIN'
253
260
export FORCE_RPATH=" --force-rpath"
254
261
export USE_STATIC_NCCL=0
255
262
export USE_SYSTEM_NCCL=1
263
+ export ATEN_STATIC_CUDA=0
264
+ export USE_CUDA_STATIC_LINK=0
265
+ export USE_CUPTI_SO=1
256
266
fi
257
267
else
258
268
echo " Unknown cuda version $CUDA_VERSION "
0 commit comments