diff --git a/manywheel/build_rocm.sh b/manywheel/build_rocm.sh index 0c1650f9b..1dee874d2 100755 --- a/manywheel/build_rocm.sh +++ b/manywheel/build_rocm.sh @@ -212,6 +212,18 @@ elif [[ $ROCM_INT -ge 50600 ]]; then DEPS_AUX_DSTLIST+=(${RCCL_SHARE_FILES[@]/#/$RCCL_SHARE_DST/}) fi +# Add triton install dependency +if [[ $(uname) == "Linux" ]]; then + TRITON_SHORTHASH=$(cut -c1-10 $PYTORCH_ROOT/.ci/docker/ci_commit_pins/triton-rocm.txt) + TRITON_VERSION=$(cat $PYTORCH_ROOT/.ci/docker/triton_version.txt) + + if [[ -z "$PYTORCH_EXTRA_INSTALL_REQUIREMENTS" ]]; then + export PYTORCH_EXTRA_INSTALL_REQUIREMENTS="pytorch-triton-rocm==${TRITON_VERSION}+${TRITON_SHORTHASH}" + else + export PYTORCH_EXTRA_INSTALL_REQUIREMENTS="${PYTORCH_EXTRA_INSTALL_REQUIREMENTS} | pytorch-triton-rocm==${TRITON_VERSION}+${TRITON_SHORTHASH}" + fi +fi + echo "PYTORCH_ROCM_ARCH: ${PYTORCH_ROCM_ARCH}"