@@ -4484,6 +4484,14 @@ bool LoopVectorizationPlanner::isCandidateForEpilogueVectorization(
4484
4484
return false ;
4485
4485
}
4486
4486
4487
+ // TODO: support epilogue vectorization for min/max with index.
4488
+ if (any_of (Legal->getReductionVars (), [](const auto &Reduction) {
4489
+ const RecurrenceDescriptor &RdxDesc = Reduction.second ;
4490
+ return RecurrenceDescriptor::isMinMaxIdxRecurrenceKind (
4491
+ RdxDesc.getRecurrenceKind ());
4492
+ }))
4493
+ return false ;
4494
+
4487
4495
// Epilogue vectorization code has not been audited to ensure it handles
4488
4496
// non-latch exits properly. It may be fine, but it needs to be audited and
4489
4497
// tested.
@@ -5176,7 +5184,8 @@ LoopVectorizationCostModel::selectInterleaveCount(VPlan &Plan, ElementCount VF,
5176
5184
const RecurrenceDescriptor &RdxDesc = Reduction.second ;
5177
5185
RecurKind RK = RdxDesc.getRecurrenceKind ();
5178
5186
return RecurrenceDescriptor::isAnyOfRecurrenceKind (RK) ||
5179
- RecurrenceDescriptor::isFindLastIVRecurrenceKind (RK);
5187
+ RecurrenceDescriptor::isFindLastIVRecurrenceKind (RK) ||
5188
+ RecurrenceDescriptor::isMinMaxIdxRecurrenceKind (RK);
5180
5189
});
5181
5190
if (HasSelectCmpReductions) {
5182
5191
LLVM_DEBUG (dbgs () << " LV: Not interleaving select-cmp reductions.\n " );
@@ -6893,6 +6902,10 @@ void LoopVectorizationCostModel::collectInLoopReductions() {
6893
6902
6894
6903
for (const auto &Reduction : Legal->getReductionVars ()) {
6895
6904
PHINode *Phi = Reduction.first ;
6905
+ // TODO: support in-loop min/max with index.
6906
+ if (Legal->isMinMaxRecurrence (Phi))
6907
+ continue ;
6908
+
6896
6909
const RecurrenceDescriptor &RdxDesc = Reduction.second ;
6897
6910
6898
6911
// We don't collect reductions that are type promoted (yet).
@@ -7552,6 +7565,8 @@ static void fixReductionScalarResumeWhenVectorizingEpilog(
7552
7565
EpiRedResult->getOpcode () != VPInstruction::ComputeFindLastIVResult))
7553
7566
return ;
7554
7567
7568
+ assert (EpiRedResult->getOpcode () != VPInstruction::ComputeMinMaxIdxResult);
7569
+
7555
7570
auto *EpiRedHeaderPhi =
7556
7571
cast<VPReductionPHIRecipe>(EpiRedResult->getOperand (0 ));
7557
7572
const RecurrenceDescriptor &RdxDesc =
@@ -8464,10 +8479,9 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
8464
8479
// Find all possible partial reductions.
8465
8480
SmallVector<std::pair<PartialReductionChain, unsigned >>
8466
8481
PartialReductionChains;
8467
- for (const auto &[Phi, RdxDesc] : Legal->getReductionVars ()) {
8468
- getScaledReductions (Phi, RdxDesc.getLoopExitInstr (), Range,
8469
- PartialReductionChains);
8470
- }
8482
+ for (const auto &[Phi, RdxDesc] : Legal->getReductionVars ())
8483
+ if (auto *ExitInstr = RdxDesc.getLoopExitInstr ())
8484
+ getScaledReductions (Phi, ExitInstr, Range, PartialReductionChains);
8471
8485
8472
8486
// A partial reduction is invalid if any of its extends are used by
8473
8487
// something that isn't another partial reduction. This is because the
@@ -8605,8 +8619,9 @@ VPRecipeBase *VPRecipeBuilder::tryToCreateWidenRecipe(VPSingleDefRecipe *R,
8605
8619
Phi->getIncomingValueForBlock (OrigLoop->getLoopPreheader ()));
8606
8620
8607
8621
// If the PHI is used by a partial reduction, set the scale factor.
8608
- unsigned ScaleFactor =
8609
- getScalingForReduction (RdxDesc.getLoopExitInstr ()).value_or (1 );
8622
+ unsigned ScaleFactor = 1 ;
8623
+ if (auto *ExitInstr = RdxDesc.getLoopExitInstr ())
8624
+ ScaleFactor = getScalingForReduction (ExitInstr).value_or (1 );
8610
8625
PhiRecipe = new VPReductionPHIRecipe (
8611
8626
Phi, RdxDesc, *StartV, CM.isInLoopReduction (Phi),
8612
8627
CM.useOrderedReductions (RdxDesc), ScaleFactor);
@@ -9361,6 +9376,7 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
9361
9376
assert (
9362
9377
!RecurrenceDescriptor::isAnyOfRecurrenceKind (Kind) &&
9363
9378
!RecurrenceDescriptor::isFindLastIVRecurrenceKind (Kind) &&
9379
+ !RecurrenceDescriptor::isMinMaxIdxRecurrenceKind (Kind) &&
9364
9380
" AnyOf and FindLast reductions are not allowed for in-loop reductions" );
9365
9381
9366
9382
// Collect the chain of "link" recipes for the reduction starting at PhiR.
@@ -9484,15 +9500,32 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
9484
9500
PreviousLink = RedRecipe;
9485
9501
}
9486
9502
}
9503
+
9504
+ // Collect all VPReductionPHIRecipes in the header block, and sort them based
9505
+ // on the dependency order of the reductions. This ensures that results of
9506
+ // min/max reductions are computed before their corresponding index
9507
+ // reductions, since the index reduction relies on the result of the min/max
9508
+ // reduction to determine which lane produced the min/max.
9509
+ SmallVector<VPReductionPHIRecipe *> VPReductionPHIs;
9510
+ for (VPRecipeBase &R : Header->phis ())
9511
+ if (auto *PhiR = dyn_cast<VPReductionPHIRecipe>(&R))
9512
+ VPReductionPHIs.push_back (PhiR);
9513
+
9514
+ stable_sort (VPReductionPHIs, [this ](const VPReductionPHIRecipe *R1,
9515
+ const VPReductionPHIRecipe *R2) {
9516
+ auto *Phi1 = cast<PHINode>(R1->getUnderlyingInstr ());
9517
+ if (!Legal->isMinMaxRecurrence (Phi1))
9518
+ return false ;
9519
+
9520
+ auto *Phi2 = cast<PHINode>(R2->getUnderlyingInstr ());
9521
+ return Legal->getMinMaxRecurrences ().find (Phi1)->second == Phi2;
9522
+ });
9523
+
9487
9524
VPBasicBlock *LatchVPBB = VectorLoopRegion->getExitingBasicBlock ();
9488
9525
Builder.setInsertPoint (&*std::prev (std::prev (LatchVPBB->end ())));
9489
9526
VPBasicBlock::iterator IP = MiddleVPBB->getFirstNonPhi ();
9490
- for (VPRecipeBase &R :
9491
- Plan->getVectorLoopRegion ()->getEntryBasicBlock ()->phis ()) {
9492
- VPReductionPHIRecipe *PhiR = dyn_cast<VPReductionPHIRecipe>(&R);
9493
- if (!PhiR)
9494
- continue ;
9495
-
9527
+ SmallDenseMap<VPReductionPHIRecipe *, VPValue *> IdxReductionMasks;
9528
+ for (auto *PhiR : VPReductionPHIs) {
9496
9529
const RecurrenceDescriptor &RdxDesc = PhiR->getRecurrenceDescriptor ();
9497
9530
// If tail is folded by masking, introduce selects between the phi
9498
9531
// and the users outside the vector region of each reduction, at the
@@ -9517,7 +9550,9 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
9517
9550
(cast<VPInstruction>(&U)->getOpcode () ==
9518
9551
VPInstruction::ComputeReductionResult ||
9519
9552
cast<VPInstruction>(&U)->getOpcode () ==
9520
- VPInstruction::ComputeFindLastIVResult);
9553
+ VPInstruction::ComputeFindLastIVResult ||
9554
+ cast<VPInstruction>(&U)->getOpcode () ==
9555
+ VPInstruction::ComputeMinMaxIdxResult);
9521
9556
});
9522
9557
if (CM.usePredicatedReductionSelect ())
9523
9558
PhiR->setOperand (1 , NewExitingVPV);
@@ -9562,23 +9597,50 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
9562
9597
VPInstruction *FinalReductionResult;
9563
9598
VPBuilder::InsertPointGuard Guard (Builder);
9564
9599
Builder.setInsertPoint (MiddleVPBB, IP);
9565
- if ( RecurrenceDescriptor::isFindLastIVRecurrenceKind (
9566
- RdxDesc. getRecurrenceKind () )) {
9600
+ RecurKind Kind = RdxDesc. getRecurrenceKind ();
9601
+ if ( RecurrenceDescriptor::isFindLastIVRecurrenceKind (Kind )) {
9567
9602
VPValue *Start = PhiR->getStartValue ();
9568
9603
FinalReductionResult =
9569
9604
Builder.createNaryOp (VPInstruction::ComputeFindLastIVResult,
9570
9605
{PhiR, Start, NewExitingVPV}, ExitDL);
9606
+ } else if (RecurrenceDescriptor::isMinMaxIdxRecurrenceKind (Kind)) {
9607
+ // Mask out lanes that cannot be the index of the min/max value.
9608
+ VPValue *Mask = IdxReductionMasks.at (PhiR);
9609
+ Value *Iden = llvm::getRecurrenceIdentity (
9610
+ Kind == RecurKind::MinMaxFirstIdx ? RecurKind::SMin : RecurKind::SMax,
9611
+ PhiTy, RdxDesc.getFastMathFlags ());
9612
+ NewExitingVPV = Builder.createSelect (Mask, NewExitingVPV,
9613
+ Plan->getOrAddLiveIn (Iden), ExitDL);
9614
+
9615
+ VPValue *Start = PhiR->getStartValue ();
9616
+ FinalReductionResult =
9617
+ Builder.createNaryOp (VPInstruction::ComputeMinMaxIdxResult,
9618
+ {PhiR, Start, NewExitingVPV}, ExitDL);
9571
9619
} else {
9572
9620
FinalReductionResult = Builder.createNaryOp (
9573
9621
VPInstruction::ComputeReductionResult, {PhiR, NewExitingVPV}, ExitDL);
9574
9622
}
9575
9623
// Update all users outside the vector region.
9576
9624
OrigExitingVPV->replaceUsesWithIf (
9577
- FinalReductionResult, [FinalReductionResult](VPUser &User, unsigned ) {
9625
+ FinalReductionResult,
9626
+ [FinalReductionResult, NewExitingVPV](VPUser &User, unsigned ) {
9578
9627
auto *Parent = cast<VPRecipeBase>(&User)->getParent ();
9579
- return FinalReductionResult != &User && !Parent->getParent ();
9628
+ return FinalReductionResult != &User &&
9629
+ NewExitingVPV->getDefiningRecipe () != &User &&
9630
+ !Parent->getParent ();
9580
9631
});
9581
9632
9633
+ // Generate a mask for the index reduction.
9634
+ auto *Phi = cast<PHINode>(PhiR->getUnderlyingInstr ());
9635
+ if (Legal->isMinMaxRecurrence (Phi)) {
9636
+ VPValue *IdxRdxMask = Builder.createICmp (CmpInst::ICMP_EQ, NewExitingVPV,
9637
+ FinalReductionResult, ExitDL);
9638
+ PHINode *IdxPhi = Legal->getMinMaxRecurrences ().find (Phi)->second ;
9639
+ IdxReductionMasks.try_emplace (
9640
+ cast<VPReductionPHIRecipe>(RecipeBuilder.getRecipe (IdxPhi)),
9641
+ IdxRdxMask);
9642
+ }
9643
+
9582
9644
// Adjust AnyOf reductions; replace the reduction phi for the selected value
9583
9645
// with a boolean reduction phi node to check if the condition is true in
9584
9646
// any iteration. The final value is selected by the final
@@ -9613,11 +9675,11 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
9613
9675
continue ;
9614
9676
}
9615
9677
9616
- if (RecurrenceDescriptor::isFindLastIVRecurrenceKind (
9617
- RdxDesc. getRecurrenceKind () )) {
9618
- // Adjust the start value for FindLastIV recurrences to use the sentinel
9619
- // value after generating the ResumePhi recipe, which uses the original
9620
- // start value.
9678
+ if (RecurrenceDescriptor::isFindLastIVRecurrenceKind (Kind) ||
9679
+ RecurrenceDescriptor::isMinMaxIdxRecurrenceKind (Kind )) {
9680
+ // Adjust the start value for FindLastIV/MinMaxIdx recurrences to use the
9681
+ // sentinel value after generating the ResumePhi recipe, which uses the
9682
+ // original start value.
9621
9683
PhiR->setOperand (0 , Plan->getOrAddLiveIn (RdxDesc.getSentinelValue ()));
9622
9684
}
9623
9685
}
0 commit comments