Skip to content

Commit d0d7058

Browse files
authored
Replacing cudatoolkit by cuda for 11.6 (#5996)
1 parent 12bb887 commit d0d7058

File tree

4 files changed

+15
-9
lines changed

4 files changed

+15
-9
lines changed

.circleci/unittest/linux/scripts/install.sh

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,12 @@ else
2121
fi
2222
echo "Using CUDA $CUDA_VERSION as determined by CU_VERSION"
2323
version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")"
24-
cudatoolkit="nvidia::cudatoolkit=${version}"
24+
25+
cuda_toolkit_pckg="cudatoolkit"
26+
if [[ "$CU_VERSION" == cu116 ]]; then
27+
cuda_toolkit_pckg="cuda"
28+
fi
29+
cudatoolkit="nvidia::${cuda_toolkit_pckg}=${version}"
2530
fi
2631

2732
case "$(uname -s)" in

.circleci/unittest/windows/scripts/install.sh

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,15 @@ else
2222
elif [[ ${#CU_VERSION} -eq 5 ]]; then
2323
CUDA_VERSION="${CU_VERSION:2:2}.${CU_VERSION:4:1}"
2424
fi
25+
26+
cuda_toolkit_pckg="cudatoolkit"
27+
if [[ "$CU_VERSION" == cu116 ]]; then
28+
cuda_toolkit_pckg="cuda"
29+
fi
30+
2531
echo "Using CUDA $CUDA_VERSION as determined by CU_VERSION"
2632
version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")"
27-
cudatoolkit="cudatoolkit=${version}"
33+
cudatoolkit="${cuda_toolkit_pckg}=${version}"
2834
fi
2935

3036
printf "Installing PyTorch with %s\n" "${cudatoolkit}"

packaging/build_conda.sh

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,6 @@ setup_conda_pytorch_constraint
1111
setup_conda_cudatoolkit_constraint
1212
setup_visual_studio_constraint
1313
setup_junit_results_folder
14-
15-
# nvidia channel included for cudatoolkit >= 11 however for 11.5 and 11.6 we use conda-forge
1614
export CUDATOOLKIT_CHANNEL="nvidia"
17-
if [[ "$CU_VERSION" == cu116 ]]; then
18-
export CUDATOOLKIT_CHANNEL="conda-forge"
19-
fi
2015

2116
conda build -c $CUDATOOLKIT_CHANNEL -c defaults $CONDA_CHANNEL_FLAGS --no-anaconda-upload --python "$PYTHON_VERSION" packaging/torchvision

packaging/pkg_helpers.bash

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ setup_conda_cudatoolkit_constraint() {
257257
else
258258
case "$CU_VERSION" in
259259
cu116)
260-
export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=11.6,<11.7 # [not osx]"
260+
export CONDA_CUDATOOLKIT_CONSTRAINT="- cuda >=11.6,<11.7 # [not osx]"
261261
;;
262262
cu113)
263263
export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=11.3,<11.4 # [not osx]"
@@ -286,7 +286,7 @@ setup_conda_cudatoolkit_plain_constraint() {
286286
else
287287
case "$CU_VERSION" in
288288
cu116)
289-
export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit=11.6"
289+
export CONDA_CUDATOOLKIT_CONSTRAINT="cuda=11.6"
290290
;;
291291
cu113)
292292
export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit=11.3"

0 commit comments

Comments
 (0)