Skip to content

Commit db9d535

Browse files
committed
[LoopVectorize] Vectorize the reduction pattern of integer min/max with index. (2/2)
1 parent 084746e commit db9d535

File tree

12 files changed

+2173
-172
lines changed

12 files changed

+2173
-172
lines changed

llvm/include/llvm/Analysis/IVDescriptors.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,9 @@ class RecurrenceDescriptor {
302302
/// Returns the sentinel value that replaces the start value of FindLastIV
303303
/// and MinMaxIdx recurrences.
304304
Value *getSentinelValue() const {
305-
assert(isFindLastIVRecurrenceKind(Kind) && "Unexpected recurrence kind");
305+
assert(
306+
(isFindLastIVRecurrenceKind(Kind) || isMinMaxIdxRecurrenceKind(Kind)) &&
307+
"Unexpected recurrence kind");
306308
Type *Ty = StartValue->getType();
307309
return ConstantInt::get(Ty,
308310
APInt::getSignedMinValue(Ty->getIntegerBitWidth()));

llvm/include/llvm/Transforms/Utils/LoopUtils.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,12 @@ Value *createAnyOfReduction(IRBuilderBase &B, Value *Src,
426426
Value *createFindLastIVReduction(IRBuilderBase &B, Value *Src, Value *Start,
427427
const RecurrenceDescriptor &Desc);
428428

429+
/// Create a reduction of the given vector \p Src for a reduction of the
430+
/// kind RecurKind::MinMaxFirstIdx or RecurKind::MinMaxLastIdx. The reduction
431+
/// operation is described by \p Desc.
432+
Value *createMinMaxIdxReduction(IRBuilderBase &B, Value *Src, Value *Start,
433+
const RecurrenceDescriptor &Desc);
434+
429435
/// Create an ordered reduction intrinsic using the given recurrence
430436
/// kind \p RdxKind.
431437
Value *createOrderedReduction(IRBuilderBase &B, RecurKind RdxKind, Value *Src,

llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,11 @@ class LoopVectorizationLegality {
307307
/// Return the fixed-order recurrences found in the loop.
308308
RecurrenceSet &getFixedOrderRecurrences() { return FixedOrderRecurrences; }
309309

310+
/// Return the min/max recurrences found in the loop.
311+
const SmallDenseMap<PHINode *, PHINode *> &getMinMaxRecurrences() {
312+
return MinMaxRecurrences;
313+
}
314+
310315
/// Returns the widest induction type.
311316
IntegerType *getWidestInductionType() { return WidestIndTy; }
312317

@@ -618,7 +623,7 @@ class LoopVectorizationLegality {
618623
RecurrenceSet FixedOrderRecurrences;
619624

620625
/// Maps each min/max recurrence phi to its index reduction phi (if any).
621-
RecurrenceSet MinMaxRecurrences;
626+
SmallDenseMap<PHINode *, PHINode *> MinMaxRecurrences;
622627

623628
/// Holds the widest induction type encountered.
624629
IntegerType *WidestIndTy = nullptr;

llvm/lib/Transforms/Utils/LoopUtils.cpp

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1258,6 +1258,25 @@ Value *llvm::createFindLastIVReduction(IRBuilderBase &Builder, Value *Src,
12581258
return Builder.CreateSelect(Cmp, MaxRdx, Start, "rdx.select");
12591259
}
12601260

1261+
Value *llvm::createMinMaxIdxReduction(IRBuilderBase &Builder, Value *Src,
1262+
Value *Start,
1263+
const RecurrenceDescriptor &Desc) {
1264+
RecurKind Kind = Desc.getRecurrenceKind();
1265+
assert(RecurrenceDescriptor::isMinMaxIdxRecurrenceKind(Kind) &&
1266+
"Unexpected reduction kind");
1267+
Value *Sentinel = Desc.getSentinelValue();
1268+
Value *Rdx = Src;
1269+
if (Src->getType()->isVectorTy())
1270+
Rdx = Kind == RecurKind::MinMaxFirstIdx
1271+
? Builder.CreateIntMinReduce(Src, true)
1272+
: Builder.CreateIntMaxReduce(Src, true);
1273+
// Correct the final reduction result back to the start value if the reduction
1274+
// result is the sentinel value.
1275+
Value *Cmp =
1276+
Builder.CreateCmp(CmpInst::ICMP_NE, Rdx, Sentinel, "rdx.select.cmp");
1277+
return Builder.CreateSelect(Cmp, Rdx, Start, "rdx.select");
1278+
}
1279+
12611280
Value *llvm::getReductionIdentity(Intrinsic::ID RdxID, Type *Ty,
12621281
FastMathFlags Flags) {
12631282
bool Negative = false;
@@ -1346,7 +1365,8 @@ Value *llvm::createSimpleReduction(VectorBuilder &VBuilder, Value *Src,
13461365
RecurKind Kind) {
13471366
assert(!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) &&
13481367
!RecurrenceDescriptor::isFindLastIVRecurrenceKind(Kind) &&
1349-
"AnyOf or FindLastIV reductions are not supported.");
1368+
!RecurrenceDescriptor::isMinMaxIdxRecurrenceKind(Kind) &&
1369+
"AnyOf, FindLastIV and MinMaxIdx reductions are not supported.");
13501370
Intrinsic::ID Id = getReductionIntrinsicID(Kind);
13511371
auto *SrcTy = cast<VectorType>(Src->getType());
13521372
Type *SrcEltTy = SrcTy->getElementType();

llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -851,7 +851,7 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {
851851
if (MinMaxRecurDes.getLoopExitInstr())
852852
AllowedExit.insert(MinMaxRecurDes.getLoopExitInstr());
853853
Reductions[Phi] = MinMaxRecurDes;
854-
MinMaxRecurrences.insert(Phi);
854+
MinMaxRecurrences.try_emplace(Phi);
855855
MinMaxRecurrenceChains[Phi] = std::move(Chain);
856856
continue;
857857
}
@@ -1093,10 +1093,6 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {
10931093
if (!canVectorizeMinMaxRecurrence(Phi, Chain))
10941094
return false;
10951095
}
1096-
// FIXME: Remove this after the IR generation of min/max with index is
1097-
// supported.
1098-
if (!MinMaxRecurrences.empty())
1099-
return false;
11001096

11011097
return true;
11021098
}
@@ -1106,6 +1102,10 @@ bool LoopVectorizationLegality::canVectorizeMinMaxRecurrence(
11061102
assert(!Chain.empty() && "Unexpected empty recurrence chain");
11071103
assert(isMinMaxRecurrence(Phi) && "The PHI is not a min/max recurrence phi");
11081104

1105+
auto It = MinMaxRecurrences.find(Phi);
1106+
if (It->second)
1107+
return true;
1108+
11091109
auto IsMinMaxIdxReductionPhi = [this, Phi, &Chain](Value *Candidate) -> bool {
11101110
auto *IdxPhi = dyn_cast<PHINode>(Candidate);
11111111
if (!IdxPhi || !isReductionVariable(IdxPhi))
@@ -1150,7 +1150,17 @@ bool LoopVectorizationLegality::canVectorizeMinMaxRecurrence(
11501150

11511151
auto *TrueVal = IdxChainHead->getTrueValue();
11521152
auto *FalseVal = IdxChainHead->getFalseValue();
1153-
return IsMinMaxIdxReductionPhi(TrueVal) || IsMinMaxIdxReductionPhi(FalseVal);
1153+
PHINode *IdxPhi;
1154+
if (IsMinMaxIdxReductionPhi(TrueVal))
1155+
IdxPhi = cast<PHINode>(TrueVal);
1156+
else if (IsMinMaxIdxReductionPhi(FalseVal))
1157+
IdxPhi = cast<PHINode>(FalseVal);
1158+
else
1159+
return false;
1160+
1161+
// Record the index reduction phi that uses the min/max recurrence.
1162+
It->second = IdxPhi;
1163+
return true;
11541164
}
11551165

11561166
/// Find histogram operations that match high-level code in loops:
@@ -1973,7 +1983,8 @@ bool LoopVectorizationLegality::canFoldTailByMasking() const {
19731983
SmallPtrSet<const Value *, 8> ReductionLiveOuts;
19741984

19751985
for (const auto &Reduction : getReductionVars())
1976-
ReductionLiveOuts.insert(Reduction.second.getLoopExitInstr());
1986+
if (auto *ExitInstr = Reduction.second.getLoopExitInstr())
1987+
ReductionLiveOuts.insert(ExitInstr);
19771988

19781989
// TODO: handle non-reduction outside users when tail is folded by masking.
19791990
for (auto *AE : AllowedExit) {

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 85 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4484,6 +4484,14 @@ bool LoopVectorizationPlanner::isCandidateForEpilogueVectorization(
44844484
return false;
44854485
}
44864486

4487+
// TODO: support epilogue vectorization for min/max with index.
4488+
if (any_of(Legal->getReductionVars(), [](const auto &Reduction) {
4489+
const RecurrenceDescriptor &RdxDesc = Reduction.second;
4490+
return RecurrenceDescriptor::isMinMaxIdxRecurrenceKind(
4491+
RdxDesc.getRecurrenceKind());
4492+
}))
4493+
return false;
4494+
44874495
// Epilogue vectorization code has not been auditted to ensure it handles
44884496
// non-latch exits properly. It may be fine, but it needs auditted and
44894497
// tested.
@@ -5176,7 +5184,8 @@ LoopVectorizationCostModel::selectInterleaveCount(VPlan &Plan, ElementCount VF,
51765184
const RecurrenceDescriptor &RdxDesc = Reduction.second;
51775185
RecurKind RK = RdxDesc.getRecurrenceKind();
51785186
return RecurrenceDescriptor::isAnyOfRecurrenceKind(RK) ||
5179-
RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK);
5187+
RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK) ||
5188+
RecurrenceDescriptor::isMinMaxIdxRecurrenceKind(RK);
51805189
});
51815190
if (HasSelectCmpReductions) {
51825191
LLVM_DEBUG(dbgs() << "LV: Not interleaving select-cmp reductions.\n");
@@ -6893,6 +6902,10 @@ void LoopVectorizationCostModel::collectInLoopReductions() {
68936902

68946903
for (const auto &Reduction : Legal->getReductionVars()) {
68956904
PHINode *Phi = Reduction.first;
6905+
// TODO: support in-loop min/max with index.
6906+
if (Legal->isMinMaxRecurrence(Phi))
6907+
continue;
6908+
68966909
const RecurrenceDescriptor &RdxDesc = Reduction.second;
68976910

68986911
// We don't collect reductions that are type promoted (yet).
@@ -7552,6 +7565,8 @@ static void fixReductionScalarResumeWhenVectorizingEpilog(
75527565
EpiRedResult->getOpcode() != VPInstruction::ComputeFindLastIVResult))
75537566
return;
75547567

7568+
assert(EpiRedResult->getOpcode() != VPInstruction::ComputeMinMaxIdxResult);
7569+
75557570
auto *EpiRedHeaderPhi =
75567571
cast<VPReductionPHIRecipe>(EpiRedResult->getOperand(0));
75577572
const RecurrenceDescriptor &RdxDesc =
@@ -8464,10 +8479,9 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
84648479
// Find all possible partial reductions.
84658480
SmallVector<std::pair<PartialReductionChain, unsigned>>
84668481
PartialReductionChains;
8467-
for (const auto &[Phi, RdxDesc] : Legal->getReductionVars()) {
8468-
getScaledReductions(Phi, RdxDesc.getLoopExitInstr(), Range,
8469-
PartialReductionChains);
8470-
}
8482+
for (const auto &[Phi, RdxDesc] : Legal->getReductionVars())
8483+
if (auto *ExitInstr = RdxDesc.getLoopExitInstr())
8484+
getScaledReductions(Phi, ExitInstr, Range, PartialReductionChains);
84718485

84728486
// A partial reduction is invalid if any of its extends are used by
84738487
// something that isn't another partial reduction. This is because the
@@ -8605,8 +8619,9 @@ VPRecipeBase *VPRecipeBuilder::tryToCreateWidenRecipe(VPSingleDefRecipe *R,
86058619
Phi->getIncomingValueForBlock(OrigLoop->getLoopPreheader()));
86068620

86078621
// If the PHI is used by a partial reduction, set the scale factor.
8608-
unsigned ScaleFactor =
8609-
getScalingForReduction(RdxDesc.getLoopExitInstr()).value_or(1);
8622+
unsigned ScaleFactor = 1;
8623+
if (auto *ExitInstr = RdxDesc.getLoopExitInstr())
8624+
ScaleFactor = getScalingForReduction(ExitInstr).value_or(1);
86108625
PhiRecipe = new VPReductionPHIRecipe(
86118626
Phi, RdxDesc, *StartV, CM.isInLoopReduction(Phi),
86128627
CM.useOrderedReductions(RdxDesc), ScaleFactor);
@@ -9361,6 +9376,7 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
93619376
assert(
93629377
!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) &&
93639378
!RecurrenceDescriptor::isFindLastIVRecurrenceKind(Kind) &&
9379+
!RecurrenceDescriptor::isMinMaxIdxRecurrenceKind(Kind) &&
93649380
"AnyOf and FindLast reductions are not allowed for in-loop reductions");
93659381

93669382
// Collect the chain of "link" recipes for the reduction starting at PhiR.
@@ -9484,15 +9500,32 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
94849500
PreviousLink = RedRecipe;
94859501
}
94869502
}
9503+
9504+
// Collect all VPReductionPHIRecipes in the header block, and sort them based
9505+
// on the dependency order of the reductions. This ensures that results of
9506+
// min/max reductions are computed before their corresponding index
9507+
// reductions, since the index reduction relies on the result of the min/max
9508+
// reduction to determine which lane produced the min/max.
9509+
SmallVector<VPReductionPHIRecipe *> VPReductionPHIs;
9510+
for (VPRecipeBase &R : Header->phis())
9511+
if (auto *PhiR = dyn_cast<VPReductionPHIRecipe>(&R))
9512+
VPReductionPHIs.push_back(PhiR);
9513+
9514+
stable_sort(VPReductionPHIs, [this](const VPReductionPHIRecipe *R1,
9515+
const VPReductionPHIRecipe *R2) {
9516+
auto *Phi1 = cast<PHINode>(R1->getUnderlyingInstr());
9517+
if (!Legal->isMinMaxRecurrence(Phi1))
9518+
return false;
9519+
9520+
auto *Phi2 = cast<PHINode>(R2->getUnderlyingInstr());
9521+
return Legal->getMinMaxRecurrences().find(Phi1)->second == Phi2;
9522+
});
9523+
94879524
VPBasicBlock *LatchVPBB = VectorLoopRegion->getExitingBasicBlock();
94889525
Builder.setInsertPoint(&*std::prev(std::prev(LatchVPBB->end())));
94899526
VPBasicBlock::iterator IP = MiddleVPBB->getFirstNonPhi();
9490-
for (VPRecipeBase &R :
9491-
Plan->getVectorLoopRegion()->getEntryBasicBlock()->phis()) {
9492-
VPReductionPHIRecipe *PhiR = dyn_cast<VPReductionPHIRecipe>(&R);
9493-
if (!PhiR)
9494-
continue;
9495-
9527+
SmallDenseMap<VPReductionPHIRecipe *, VPValue *> IdxReductionMasks;
9528+
for (auto *PhiR : VPReductionPHIs) {
94969529
const RecurrenceDescriptor &RdxDesc = PhiR->getRecurrenceDescriptor();
94979530
// If tail is folded by masking, introduce selects between the phi
94989531
// and the users outside the vector region of each reduction, at the
@@ -9517,7 +9550,9 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
95179550
(cast<VPInstruction>(&U)->getOpcode() ==
95189551
VPInstruction::ComputeReductionResult ||
95199552
cast<VPInstruction>(&U)->getOpcode() ==
9520-
VPInstruction::ComputeFindLastIVResult);
9553+
VPInstruction::ComputeFindLastIVResult ||
9554+
cast<VPInstruction>(&U)->getOpcode() ==
9555+
VPInstruction::ComputeMinMaxIdxResult);
95219556
});
95229557
if (CM.usePredicatedReductionSelect())
95239558
PhiR->setOperand(1, NewExitingVPV);
@@ -9562,23 +9597,50 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
95629597
VPInstruction *FinalReductionResult;
95639598
VPBuilder::InsertPointGuard Guard(Builder);
95649599
Builder.setInsertPoint(MiddleVPBB, IP);
9565-
if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(
9566-
RdxDesc.getRecurrenceKind())) {
9600+
RecurKind Kind = RdxDesc.getRecurrenceKind();
9601+
if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(Kind)) {
95679602
VPValue *Start = PhiR->getStartValue();
95689603
FinalReductionResult =
95699604
Builder.createNaryOp(VPInstruction::ComputeFindLastIVResult,
95709605
{PhiR, Start, NewExitingVPV}, ExitDL);
9606+
} else if (RecurrenceDescriptor::isMinMaxIdxRecurrenceKind(Kind)) {
9607+
// Mask out lanes that cannot be the index of the min/max value.
9608+
VPValue *Mask = IdxReductionMasks.at(PhiR);
9609+
Value *Iden = llvm::getRecurrenceIdentity(
9610+
Kind == RecurKind::MinMaxFirstIdx ? RecurKind::SMin : RecurKind::SMax,
9611+
PhiTy, RdxDesc.getFastMathFlags());
9612+
NewExitingVPV = Builder.createSelect(Mask, NewExitingVPV,
9613+
Plan->getOrAddLiveIn(Iden), ExitDL);
9614+
9615+
VPValue *Start = PhiR->getStartValue();
9616+
FinalReductionResult =
9617+
Builder.createNaryOp(VPInstruction::ComputeMinMaxIdxResult,
9618+
{PhiR, Start, NewExitingVPV}, ExitDL);
95719619
} else {
95729620
FinalReductionResult = Builder.createNaryOp(
95739621
VPInstruction::ComputeReductionResult, {PhiR, NewExitingVPV}, ExitDL);
95749622
}
95759623
// Update all users outside the vector region.
95769624
OrigExitingVPV->replaceUsesWithIf(
9577-
FinalReductionResult, [FinalReductionResult](VPUser &User, unsigned) {
9625+
FinalReductionResult,
9626+
[FinalReductionResult, NewExitingVPV](VPUser &User, unsigned) {
95789627
auto *Parent = cast<VPRecipeBase>(&User)->getParent();
9579-
return FinalReductionResult != &User && !Parent->getParent();
9628+
return FinalReductionResult != &User &&
9629+
NewExitingVPV->getDefiningRecipe() != &User &&
9630+
!Parent->getParent();
95809631
});
95819632

9633+
// Generate a mask for the index reduction.
9634+
auto *Phi = cast<PHINode>(PhiR->getUnderlyingInstr());
9635+
if (Legal->isMinMaxRecurrence(Phi)) {
9636+
VPValue *IdxRdxMask = Builder.createICmp(CmpInst::ICMP_EQ, NewExitingVPV,
9637+
FinalReductionResult, ExitDL);
9638+
PHINode *IdxPhi = Legal->getMinMaxRecurrences().find(Phi)->second;
9639+
IdxReductionMasks.try_emplace(
9640+
cast<VPReductionPHIRecipe>(RecipeBuilder.getRecipe(IdxPhi)),
9641+
IdxRdxMask);
9642+
}
9643+
95829644
// Adjust AnyOf reductions; replace the reduction phi for the selected value
95839645
// with a boolean reduction phi node to check if the condition is true in
95849646
// any iteration. The final value is selected by the final
@@ -9613,11 +9675,11 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
96139675
continue;
96149676
}
96159677

9616-
if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(
9617-
RdxDesc.getRecurrenceKind())) {
9618-
// Adjust the start value for FindLastIV recurrences to use the sentinel
9619-
// value after generating the ResumePhi recipe, which uses the original
9620-
// start value.
9678+
if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(Kind) ||
9679+
RecurrenceDescriptor::isMinMaxIdxRecurrenceKind(Kind)) {
9680+
// Adjust the start value for FindLastIV/MinMaxIdx recurrences to use the
9681+
// sentinel value after generating the ResumePhi recipe, which uses the
9682+
// original start value.
96219683
PhiR->setOperand(0, Plan->getOrAddLiveIn(RdxDesc.getSentinelValue()));
96229684
}
96239685
}

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -899,6 +899,7 @@ class VPInstruction : public VPRecipeWithIRFlags,
899899
BranchOnCond,
900900
Broadcast,
901901
ComputeFindLastIVResult,
902+
ComputeMinMaxIdxResult,
902903
ComputeReductionResult,
903904
// Extracts the last lane from its operand if it is a vector, or the last
904905
// part if scalar. In the latter case, the recipe will be removed during

llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {
8888
"different types inferred for different operands");
8989
return IntegerType::get(Ctx, 1);
9090
case VPInstruction::ComputeFindLastIVResult:
91+
case VPInstruction::ComputeMinMaxIdxResult:
9192
case VPInstruction::ComputeReductionResult: {
9293
auto *PhiR = cast<VPReductionPHIRecipe>(R->getOperand(0));
9394
auto *OrigPhi = cast<PHINode>(PhiR->getUnderlyingValue());

0 commit comments

Comments
 (0)