diff --git a/manywheel/build_cuda.sh b/manywheel/build_cuda.sh index 8a6f10269..318273ba4 100644 --- a/manywheel/build_cuda.sh +++ b/manywheel/build_cuda.sh @@ -262,17 +262,20 @@ else exit 1 fi -# No triton dependency for now on 3.12 since we don't have binaries for it -# and torch.compile doesn't work. -if [[ $(uname) == "Linux" && "$DESIRED_PYTHON" != "3.12" ]]; then - TRITON_VERSION=$(cat $PYTORCH_ROOT/.ci/docker/triton_version.txt) - TRITON_REQUIREMENT="triton==${TRITON_VERSION}; platform_system == 'Linux' and platform_machine == 'x86_64'" - if [[ -z "$PYTORCH_EXTRA_INSTALL_REQUIREMENTS" ]]; then - export PYTORCH_EXTRA_INSTALL_REQUIREMENTS="${TRITON_REQUIREMENT}" - else - export PYTORCH_EXTRA_INSTALL_REQUIREMENTS="${PYTORCH_EXTRA_INSTALL_REQUIREMENTS} | ${TRITON_REQUIREMENT}" - fi +TRITON_VERSION=$(cat $PYTORCH_ROOT/.ci/docker/triton_version.txt) +# Only linux Python < 3.12 are supported wheels for triton +TRITON_CONSTRAINT="platform_system == 'Linux' and platform_machine == 'x86_64' and python_version < '3.12'" +TRITON_REQUIREMENT="pytorch-triton==${TRITON_VERSION}; ${TRITON_CONSTRAINT}" +if [[ -n "$OVERRIDE_PACKAGE_VERSION" && "$OVERRIDE_PACKAGE_VERSION" =~ .*dev.* ]]; then + TRITON_SHORTHASH=$(cut -c1-10 $PYTORCH_ROOT/.github/ci_commit_pins/triton.txt) + TRITON_REQUIREMENT="pytorch-triton==${TRITON_VERSION}+${TRITON_SHORTHASH}; ${TRITON_CONSTRAINT}" +fi + +if [[ -z "$PYTORCH_EXTRA_INSTALL_REQUIREMENTS" ]]; then + export PYTORCH_EXTRA_INSTALL_REQUIREMENTS="${TRITON_REQUIREMENT}" +else + export PYTORCH_EXTRA_INSTALL_REQUIREMENTS="${PYTORCH_EXTRA_INSTALL_REQUIREMENTS} | ${TRITON_REQUIREMENT}" fi # builder/test.sh requires DESIRED_CUDA to know what tests to exclude