diff --git a/llvm/include/llvm/Transforms/Utils/LoopUtils.h b/llvm/include/llvm/Transforms/Utils/LoopUtils.h index 193f505fb03fe..416a0a70325d1 100644 --- a/llvm/include/llvm/Transforms/Utils/LoopUtils.h +++ b/llvm/include/llvm/Transforms/Utils/LoopUtils.h @@ -423,7 +423,7 @@ Value *createAnyOfReduction(IRBuilderBase &B, Value *Src, /// Create a reduction of the given vector \p Src for a reduction of the /// kind RecurKind::IFindLastIV or RecurKind::FFindLastIV. The reduction /// operation is described by \p Desc. -Value *createFindLastIVReduction(IRBuilderBase &B, Value *Src, +Value *createFindLastIVReduction(IRBuilderBase &B, Value *Src, Value *Start, const RecurrenceDescriptor &Desc); /// Create an ordered reduction intrinsic using the given recurrence diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp index 2e7685254f512..f57d95e7722dc 100644 --- a/llvm/lib/Transforms/Utils/LoopUtils.cpp +++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp @@ -1233,11 +1233,11 @@ Value *llvm::createAnyOfReduction(IRBuilderBase &Builder, Value *Src, } Value *llvm::createFindLastIVReduction(IRBuilderBase &Builder, Value *Src, + Value *Start, const RecurrenceDescriptor &Desc) { assert(RecurrenceDescriptor::isFindLastIVRecurrenceKind( Desc.getRecurrenceKind()) && "Unexpected reduction kind"); - Value *StartVal = Desc.getRecurrenceStartValue(); Value *Sentinel = Desc.getSentinelValue(); Value *MaxRdx = Src->getType()->isVectorTy() ? Builder.CreateIntMaxReduce(Src, true) @@ -1246,7 +1246,7 @@ Value *llvm::createFindLastIVReduction(IRBuilderBase &Builder, Value *Src, // reduction is sentinel value. Value *Cmp = Builder.CreateCmp(CmpInst::ICMP_NE, MaxRdx, Sentinel, "rdx.select.cmp"); - return Builder.CreateSelect(Cmp, MaxRdx, StartVal, "rdx.select"); + return Builder.CreateSelect(Cmp, MaxRdx, Start, "rdx.select"); } Value *llvm::getReductionIdentity(Intrinsic::ID RdxID, Type *Ty, diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp index 5244a5e7b1c41..fc3e7a4e3d10e 100644 --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -9866,14 +9866,19 @@ void LoopVectorizationPlanner::adjustRecipesForReductions( // bc.merge.rdx phi nodes, hence it needs to be created unconditionally here // even for in-loop reductions, until the reduction resume value handling is // also modeled in VPlan. + VPInstruction *FinalReductionResult; VPBuilder::InsertPointGuard Guard(Builder); Builder.setInsertPoint(MiddleVPBB, IP); - auto *FinalReductionResult = - Builder.createNaryOp(RecurrenceDescriptor::isFindLastIVRecurrenceKind( - RdxDesc.getRecurrenceKind()) - ? VPInstruction::ComputeFindLastIVResult - : VPInstruction::ComputeReductionResult, - {PhiR, NewExitingVPV}, ExitDL); + if (RecurrenceDescriptor::isFindLastIVRecurrenceKind( + RdxDesc.getRecurrenceKind())) { + VPValue *Start = PhiR->getStartValue(); + FinalReductionResult = + Builder.createNaryOp(VPInstruction::ComputeFindLastIVResult, + {PhiR, Start, NewExitingVPV}, ExitDL); + } else { + FinalReductionResult = Builder.createNaryOp( + VPInstruction::ComputeReductionResult, {PhiR, NewExitingVPV}, ExitDL); + } // Update all users outside the vector region. OrigExitingVPV->replaceUsesWithIf( FinalReductionResult, [FinalReductionResult](VPUser &User, unsigned) { diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp index d404ce46fae4a..24a166bd336d1 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp @@ -51,6 +51,7 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) { switch (Opcode) { case Instruction::ExtractElement: + case Instruction::Freeze: return inferScalarType(R->getOperand(0)); case Instruction::Select: { Type *ResTy = inferScalarType(R->getOperand(1)); diff --git a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h index 8c11d93734667..3a727866a2875 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h +++ b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h @@ -216,6 +216,16 @@ using BinaryVPInstruction_match = BinaryRecipe_match; +template +using TernaryRecipe_match = Recipe_match, + Opcode, Commutative, RecipeTys...>; + +template +using TernaryVPInstruction_match = + TernaryRecipe_match; + template using AllBinaryRecipe_match = @@ -234,6 +244,13 @@ m_VPInstruction(const Op0_t &Op0, const Op1_t &Op1) { return BinaryVPInstruction_match(Op0, Op1); } +template +inline TernaryVPInstruction_match +m_VPInstruction(const Op0_t &Op0, const Op1_t &Op1, const Op2_t &Op2) { + return TernaryVPInstruction_match( + {Op0, Op1, Op2}); +} + template inline UnaryVPInstruction_match m_Not(const Op0_t &Op0) { diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp index efa238228f6c3..d92417c163a49 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp @@ -627,14 +627,15 @@ Value *VPInstruction::generate(VPTransformState &State) { // The recipe's operands are the reduction phi, followed by one operand for // each part of the reduction. - unsigned UF = getNumOperands() - 1; - Value *ReducedPartRdx = State.get(getOperand(1)); + unsigned UF = getNumOperands() - 2; + Value *ReducedPartRdx = State.get(getOperand(2)); for (unsigned Part = 1; Part < UF; ++Part) { ReducedPartRdx = createMinMaxOp(Builder, RecurKind::SMax, ReducedPartRdx, - State.get(getOperand(1 + Part))); + State.get(getOperand(2 + Part))); } - return createFindLastIVReduction(Builder, ReducedPartRdx, RdxDesc); + return createFindLastIVReduction(Builder, ReducedPartRdx, + State.get(getOperand(1), true), RdxDesc); } case VPInstruction::ComputeReductionResult: { // FIXME: The cross-recipe dependency on VPReductionPHIRecipe is temporary @@ -951,6 +952,8 @@ bool VPInstruction::onlyFirstLaneUsed(const VPValue *Op) const { return true; case VPInstruction::PtrAdd: return Op == getOperand(0) || vputils::onlyFirstLaneUsed(this); + case VPInstruction::ComputeFindLastIVResult: + return Op == getOperand(1); }; llvm_unreachable("switch should return"); } @@ -1592,7 +1595,6 @@ void VPWidenRecipe::execute(VPTransformState &State) { } case Instruction::Freeze: { Value *Op = State.get(getOperand(0)); - Value *Freeze = Builder.CreateFreeze(Op); State.set(this, Freeze); break; diff --git a/llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp b/llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp index ad957f33ee699..a513a255344cc 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp @@ -350,7 +350,7 @@ void UnrollState::unrollBlock(VPBlockBase *VPB) { if (match(&R, m_VPInstruction( m_VPValue(), m_VPValue(Op1))) || match(&R, m_VPInstruction( - m_VPValue(), m_VPValue(Op1)))) { + m_VPValue(), m_VPValue(), m_VPValue(Op1)))) { addUniformForAllParts(cast(&R)); for (unsigned Part = 1; Part != UF; ++Part) R.addOperand(getValueForPart(Op1, Part)); diff --git a/llvm/test/Transforms/LoopVectorize/vplan-printing-reductions.ll b/llvm/test/Transforms/LoopVectorize/vplan-printing-reductions.ll index b357be63a49cd..11b4efb08bb2e 100644 --- a/llvm/test/Transforms/LoopVectorize/vplan-printing-reductions.ll +++ b/llvm/test/Transforms/LoopVectorize/vplan-printing-reductions.ll @@ -234,7 +234,7 @@ define i64 @find_last_iv(ptr %a, i64 %n, i64 %start) { ; CHECK-NEXT: Successor(s): middle.block ; CHECK-EMPTY: ; CHECK-NEXT: middle.block: -; CHECK-NEXT: EMIT vp<[[RDX_RES:%.+]]> = compute-find-last-iv-result ir<%rdx>, ir<%cond> +; CHECK-NEXT: EMIT vp<[[RDX_RES:%.+]]> = compute-find-last-iv-result ir<%rdx>, ir<%start>, ir<%cond> ; CHECK-NEXT: EMIT vp<[[EXT:%.+]]> = extract-from-end vp<[[RDX_RES]]>, ir<1> ; CHECK-NEXT: EMIT vp<%cmp.n> = icmp eq ir<%n>, vp<{{.+}}> ; CHECK-NEXT: EMIT branch-on-cond vp<%cmp.n>