diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index d1354ccf37660..f5a316d1d8be1 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -21738,74 +21738,72 @@ SDValue tryLowerPartialReductionToDot(SDNode *N, SDLoc DL(N); - // The narrower of the two operands. Used as the accumulator - auto NarrowOp = N->getOperand(1); - auto MulOp = N->getOperand(2); - if (MulOp->getOpcode() != ISD::MUL) + SDValue Op2 = N->getOperand(2); + if (Op2->getOpcode() != ISD::MUL || + !ISD::isExtOpcode(Op2->getOperand(0)->getOpcode()) || + !ISD::isExtOpcode(Op2->getOperand(1)->getOpcode())) return SDValue(); - auto ExtA = MulOp->getOperand(0); - auto ExtB = MulOp->getOperand(1); - - if (!ISD::isExtOpcode(ExtA->getOpcode()) || - !ISD::isExtOpcode(ExtB->getOpcode())) - return SDValue(); - bool AIsSigned = ExtA->getOpcode() == ISD::SIGN_EXTEND; - bool BIsSigned = ExtB->getOpcode() == ISD::SIGN_EXTEND; + SDValue Acc = N->getOperand(1); + SDValue Mul = N->getOperand(2); + SDValue ExtMulOpLHS = Mul->getOperand(0); + SDValue ExtMulOpRHS = Mul->getOperand(1); - auto A = ExtA->getOperand(0); - auto B = ExtB->getOperand(0); - if (A.getValueType() != B.getValueType()) + SDValue MulOpLHS = ExtMulOpLHS->getOperand(0); + SDValue MulOpRHS = ExtMulOpRHS->getOperand(0); + if (MulOpLHS.getValueType() != MulOpRHS.getValueType()) return SDValue(); - EVT ReducedType = N->getValueType(0); - EVT MulSrcType = A.getValueType(); + EVT ReducedVT = N->getValueType(0); + EVT MulSrcVT = MulOpLHS.getValueType(); // Dot products operate on chunks of four elements so there must be four times // as many elements in the wide type - if (!(ReducedType == MVT::nxv4i64 && MulSrcType == MVT::nxv16i8) && - !(ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8) && - !(ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16) && - !(ReducedType == MVT::v4i64 && MulSrcType == MVT::v16i8) && - !(ReducedType == MVT::v4i32 && MulSrcType == MVT::v16i8) && - !(ReducedType == MVT::v2i32 && MulSrcType == MVT::v8i8)) + if (!(ReducedVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) && + !(ReducedVT == MVT::nxv4i32 && MulSrcVT == MVT::nxv16i8) && + !(ReducedVT == MVT::nxv2i64 && MulSrcVT == MVT::nxv8i16) && + !(ReducedVT == MVT::v4i64 && MulSrcVT == MVT::v16i8) && + !(ReducedVT == MVT::v4i32 && MulSrcVT == MVT::v16i8) && + !(ReducedVT == MVT::v2i32 && MulSrcVT == MVT::v8i8)) return SDValue(); + bool MulOpLHSIsSigned = ExtMulOpLHS->getOpcode() == ISD::SIGN_EXTEND; + bool MulOpRHSIsSigned = ExtMulOpRHS->getOpcode() == ISD::SIGN_EXTEND; // If the extensions are mixed, we should lower it to a usdot instead unsigned Opcode = 0; - if (AIsSigned != BIsSigned) { + if (MulOpLHSIsSigned != MulOpRHSIsSigned) { if (!Subtarget->hasMatMulInt8()) return SDValue(); bool Scalable = N->getValueType(0).isScalableVT(); // There's no nxv2i64 version of usdot - if (Scalable && ReducedType != MVT::nxv4i32 && ReducedType != MVT::nxv4i64) + if (Scalable && ReducedVT != MVT::nxv4i32 && ReducedVT != MVT::nxv4i64) return SDValue(); Opcode = AArch64ISD::USDOT; // USDOT expects the signed operand to be last - if (!BIsSigned) - std::swap(A, B); - } else if (AIsSigned) + if (!MulOpRHSIsSigned) + std::swap(MulOpLHS, MulOpRHS); + } else if (MulOpLHSIsSigned) Opcode = AArch64ISD::SDOT; else Opcode = AArch64ISD::UDOT; // Partial reduction lowering for (nx)v16i8 to (nx)v4i64 requires an i32 dot // product followed by a zero / sign extension - if ((ReducedType == MVT::nxv4i64 && MulSrcType == MVT::nxv16i8) || - (ReducedType == MVT::v4i64 && MulSrcType == MVT::v16i8)) { - EVT ReducedTypeI32 = - (ReducedType.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32; + if ((ReducedVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) || + (ReducedVT == MVT::v4i64 && MulSrcVT == MVT::v16i8)) { + EVT ReducedVTI32 = + (ReducedVT.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32; - auto DotI32 = DAG.getNode(Opcode, DL, ReducedTypeI32, - DAG.getConstant(0, DL, ReducedTypeI32), A, B); - auto Extended = DAG.getSExtOrTrunc(DotI32, DL, ReducedType); - return DAG.getNode(ISD::ADD, DL, NarrowOp.getValueType(), NarrowOp, - Extended); + SDValue DotI32 = + DAG.getNode(Opcode, DL, ReducedVTI32, + DAG.getConstant(0, DL, ReducedVTI32), MulOpLHS, MulOpRHS); + SDValue Extended = DAG.getSExtOrTrunc(DotI32, DL, ReducedVT); + return DAG.getNode(ISD::ADD, DL, ReducedVT, Acc, Extended); } - return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B); + return DAG.getNode(Opcode, DL, ReducedVT, Acc, MulOpLHS, MulOpRHS); } SDValue tryLowerPartialReductionToWideAdd(SDNode *N, @@ -21822,32 +21820,29 @@ SDValue tryLowerPartialReductionToWideAdd(SDNode *N, SDLoc DL(N); - auto Acc = N->getOperand(1); - auto ExtInput = N->getOperand(2); - - EVT AccVT = Acc.getValueType(); - EVT AccElemVT = AccVT.getVectorElementType(); - - if (ExtInput.getValueType().getVectorElementType() != AccElemVT) + if (!ISD::isExtOpcode(N->getOperand(2).getOpcode())) return SDValue(); - - unsigned ExtInputOpcode = ExtInput->getOpcode(); - if (!ISD::isExtOpcode(ExtInputOpcode)) + SDValue Acc = N->getOperand(1); + SDValue Ext = N->getOperand(2); + EVT AccVT = Acc.getValueType(); + EVT ExtVT = Ext.getValueType(); + if (ExtVT.getVectorElementType() != AccVT.getVectorElementType()) return SDValue(); - auto Input = ExtInput->getOperand(0); - EVT InputVT = Input.getValueType(); + SDValue ExtOp = Ext->getOperand(0); + EVT ExtOpVT = ExtOp.getValueType(); - if (!(InputVT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) && - !(InputVT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) && - !(InputVT == MVT::nxv16i8 && AccVT == MVT::nxv8i16)) + if (!(ExtOpVT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) && + !(ExtOpVT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) && + !(ExtOpVT == MVT::nxv16i8 && AccVT == MVT::nxv8i16)) return SDValue(); - bool InputIsSigned = ExtInputOpcode == ISD::SIGN_EXTEND; - auto BottomOpcode = InputIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB; - auto TopOpcode = InputIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT; - auto BottomNode = DAG.getNode(BottomOpcode, DL, AccVT, Acc, Input); - return DAG.getNode(TopOpcode, DL, AccVT, BottomNode, Input); + bool ExtOpIsSigned = Ext.getOpcode() == ISD::SIGN_EXTEND; + unsigned BottomOpcode = + ExtOpIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB; + unsigned TopOpcode = ExtOpIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT; + SDValue BottomNode = DAG.getNode(BottomOpcode, DL, AccVT, Acc, ExtOp); + return DAG.getNode(TopOpcode, DL, AccVT, BottomNode, ExtOp); } static SDValue performIntrinsicCombine(SDNode *N, @@ -21859,9 +21854,9 @@ static SDValue performIntrinsicCombine(SDNode *N, default: break; case Intrinsic::experimental_vector_partial_reduce_add: { - if (auto Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG)) + if (SDValue Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG)) return Dot; - if (auto WideAdd = tryLowerPartialReductionToWideAdd(N, Subtarget, DAG)) + if (SDValue WideAdd = tryLowerPartialReductionToWideAdd(N, Subtarget, DAG)) return WideAdd; return DAG.getPartialReduceAdd(SDLoc(N), N->getValueType(0), N->getOperand(1), N->getOperand(2));