diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp index 2854c1462014f..d13770a35c108 100644 --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -8864,47 +8864,47 @@ static void addScalarResumePhis(VPRecipeBuilder &Builder, VPlan &Plan) { } } -// Collect VPIRInstructions for phis in the original exit block that are modeled +// Collect VPIRInstructions for phis in the exit blocks that are modeled // in VPlan and add the exiting VPValue as operand. Some exiting values are not // modeled explicitly yet and won't be included. Those are un-truncated // VPWidenIntOrFpInductionRecipe, VPWidenPointerInductionRecipe and induction // increments. -static SetVector collectUsersInExitBlock( +static SetVector collectUsersInExitBlocks( Loop *OrigLoop, VPRecipeBuilder &Builder, VPlan &Plan, const MapVector &Inductions) { - auto *MiddleVPBB = Plan.getMiddleBlock(); - // No edge from the middle block to the unique exit block has been inserted - // and there is nothing to fix from vector loop; phis should have incoming - // from scalar loop only. - if (MiddleVPBB->getNumSuccessors() != 2) - return {}; SetVector ExitUsersToFix; - VPBasicBlock *ExitVPBB = cast(MiddleVPBB->getSuccessors()[0]); - BasicBlock *ExitingBB = OrigLoop->getExitingBlock(); - for (VPRecipeBase &R : *ExitVPBB) { - auto *ExitIRI = dyn_cast(&R); - if (!ExitIRI) - continue; - auto *ExitPhi = dyn_cast(&ExitIRI->getInstruction()); - if (!ExitPhi) - break; - Value *IncomingValue = ExitPhi->getIncomingValueForBlock(ExitingBB); - VPValue *V = Builder.getVPValueOrAddLiveIn(IncomingValue); - // Exit values for inductions are computed and updated outside of VPlan and - // independent of induction recipes. - // TODO: Compute induction exit values in VPlan. 
- if ((isa(V) && - !cast(V)->getTruncInst()) || - isa(V) || - (isa(IncomingValue) && - OrigLoop->contains(cast(IncomingValue)) && - any_of(IncomingValue->users(), [&Inductions](User *U) { - auto *P = dyn_cast(U); - return P && Inductions.contains(P); - }))) - continue; - ExitUsersToFix.insert(ExitIRI); - ExitIRI->addOperand(V); + for (VPIRBasicBlock *ExitVPBB : Plan.getExitBlocks()) { + BasicBlock *ExitBB = ExitVPBB->getIRBasicBlock(); + BasicBlock *ExitingBB = find_singleton( + to_vector(predecessors(ExitBB)), + [OrigLoop](BasicBlock *Pred, bool AllowRepeats) { + return OrigLoop->contains(Pred) ? Pred : nullptr; + }); + for (VPRecipeBase &R : *ExitVPBB) { + auto *ExitIRI = dyn_cast(&R); + if (!ExitIRI) + continue; + auto *ExitPhi = dyn_cast(&ExitIRI->getInstruction()); + if (!ExitPhi) + break; + Value *IncomingValue = ExitPhi->getIncomingValueForBlock(ExitingBB); + VPValue *V = Builder.getVPValueOrAddLiveIn(IncomingValue); + // Exit values for inductions are computed and updated outside of VPlan + // and independent of induction recipes. + // TODO: Compute induction exit values in VPlan. + if ((isa(V) && + !cast(V)->getTruncInst()) || + isa(V) || + (isa(IncomingValue) && + OrigLoop->contains(cast(IncomingValue)) && + any_of(IncomingValue->users(), [&Inductions](User *U) { + auto *P = dyn_cast(U); + return P && Inductions.contains(P); + }))) + continue; + ExitUsersToFix.insert(ExitIRI); + ExitIRI->addOperand(V); + } } return ExitUsersToFix; } @@ -8912,8 +8912,8 @@ static SetVector collectUsersInExitBlock( // Add exit values to \p Plan. Extracts are added for each entry in \p // ExitUsersToFix if needed and their operands are updated. 
static void -addUsersInExitBlock(VPlan &Plan, - const SetVector &ExitUsersToFix) { +addUsersInExitBlocks(VPlan &Plan, + const SetVector &ExitUsersToFix) { if (ExitUsersToFix.empty()) return; @@ -8929,6 +8929,8 @@ addUsersInExitBlock(VPlan &Plan, if (V->isLiveIn()) continue; + assert(ExitIRI->getParent()->getSinglePredecessor() == MiddleVPBB && + "Exit value not handled yet for this edge."); LLVMContext &Ctx = ExitIRI->getInstruction().getContext(); VPValue *Ext = B.createNaryOp(VPInstruction::ExtractFromEnd, {V, Plan.getOrAddLiveIn(ConstantInt::get( @@ -9206,10 +9208,10 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) { RecipeBuilder.fixHeaderPhis(); addScalarResumePhis(RecipeBuilder, *Plan); - SetVector ExitUsersToFix = collectUsersInExitBlock( + SetVector ExitUsersToFix = collectUsersInExitBlocks( OrigLoop, RecipeBuilder, *Plan, Legal->getInductionVars()); addExitUsersForFirstOrderRecurrences(*Plan, ExitUsersToFix); - addUsersInExitBlock(*Plan, ExitUsersToFix); + addUsersInExitBlocks(*Plan, ExitUsersToFix); // --------------------------------------------------------------------------- // Transform initial VPlan: Apply previously taken decisions, in order, to // bring the VPlan to its final state. diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h index 9ef85a7f7a752..70221e7af7dbb 100644 --- a/llvm/lib/Transforms/Vectorize/VPlan.h +++ b/llvm/lib/Transforms/Vectorize/VPlan.h @@ -3839,6 +3839,12 @@ class VPlan { return cast(getVectorLoopRegion()->getSingleSuccessor()); } + /// Return an iterator range over the VPIRBasicBlock wrapping the exit blocks + /// of the VPlan, that is leaf nodes except the scalar header. Defined in + /// VPlanCFG.h, as the definition of the type needs access to the definitions + /// of VPBlockShallowTraversalWrapper. + auto getExitBlocks(); + /// The trip count of the original loop. 
VPValue *getTripCount() const { assert(TripCount && "trip count needs to be set before accessing it"); diff --git a/llvm/lib/Transforms/Vectorize/VPlanCFG.h b/llvm/lib/Transforms/Vectorize/VPlanCFG.h index 89e2e7514dac2..6ca388a953a6f 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanCFG.h +++ b/llvm/lib/Transforms/Vectorize/VPlanCFG.h @@ -306,6 +306,15 @@ template <> struct GraphTraits { } }; +inline auto VPlan::getExitBlocks() { + VPBlockBase *ScalarHeader = getScalarHeader(); + return make_filter_range( + VPBlockUtils::blocksOnly( + vp_depth_first_shallow(getVectorLoopRegion()->getSingleSuccessor())), + [ScalarHeader](VPIRBasicBlock *VPIRBB) { + return VPIRBB != ScalarHeader && VPIRBB->getNumSuccessors() == 0; + }); +} } // namespace llvm #endif // LLVM_TRANSFORMS_VECTORIZE_VPLANCFG_H