diff --git a/.circleci/unittest/linux/scripts/install.sh b/.circleci/unittest/linux/scripts/install.sh index ace8583e47b..c9c85bdd88a 100755 --- a/.circleci/unittest/linux/scripts/install.sh +++ b/.circleci/unittest/linux/scripts/install.sh @@ -22,11 +22,10 @@ else echo "Using CUDA $CUDA_VERSION as determined by CU_VERSION" version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")" - cuda_toolkit_pckg="cudatoolkit" - if [[ "$CU_VERSION" == cu116 ]]; then - cuda_toolkit_pckg="cuda" + cudatoolkit="nvidia::cudatoolkit=${version}" + if [[ "$version" == "11.6" || "$version" == "11.7" ]]; then + cudatoolkit=" pytorch-cuda=${version}" fi - cudatoolkit="nvidia::${cuda_toolkit_pckg}=${version}" fi case "$(uname -s)" in diff --git a/packaging/pkg_helpers.bash b/packaging/pkg_helpers.bash index 88b3d910270..ad0f4f94d2f 100644 --- a/packaging/pkg_helpers.bash +++ b/packaging/pkg_helpers.bash @@ -261,8 +261,11 @@ setup_conda_cudatoolkit_constraint() { export CONDA_BUILD_VARIANT="cpu" else case "$CU_VERSION" in + cu117) + export CONDA_CUDATOOLKIT_CONSTRAINT="- pytorch-cuda=11.7 # [not osx]" + ;; cu116) - export CONDA_CUDATOOLKIT_CONSTRAINT="- cuda >=11.6,<11.7 # [not osx]" + export CONDA_CUDATOOLKIT_CONSTRAINT="- pytorch-cuda=11.6 # [not osx]" ;; cu113) export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=11.3,<11.4 # [not osx]" @@ -290,8 +293,11 @@ setup_conda_cudatoolkit_plain_constraint() { export CMAKE_USE_CUDA=0 else case "$CU_VERSION" in + cu117) + export CONDA_CUDATOOLKIT_CONSTRAINT="pytorch-cuda=11.7" + ;; cu116) - export CONDA_CUDATOOLKIT_CONSTRAINT="cuda=11.6" + export CONDA_CUDATOOLKIT_CONSTRAINT="pytorch-cuda=11.6" ;; cu113) export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit=11.3"