Skip to content

Commit bb8c42e

Browse files
authored
[LV] Extend FindLastIV to unsigned case (#141752)
Split the FindLastIV RecurKind into SMax and UMax variants, depending on the reduction op produced.
1 parent 6d17eb5 commit bb8c42e

File tree

8 files changed

+513
-137
lines changed

8 files changed

+513
-137
lines changed

llvm/include/llvm/Analysis/IVDescriptors.h

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,12 @@ enum class RecurKind {
5454
FMulAdd, ///< Sum of float products with llvm.fmuladd(a * b + sum).
5555
AnyOf, ///< AnyOf reduction with select(cmp(),x,y) where one of (x,y) is
5656
///< loop invariant, and both x and y are integer type.
57-
FindLastIV, ///< FindLast reduction with select(cmp(),x,y) where one of
58-
///< (x,y) is increasing loop induction, and both x and y are
59-
///< integer type.
57+
FindLastIVSMax, ///< FindLast reduction with select(cmp(),x,y) where one of
58+
///< (x,y) is increasing loop induction, and both x and y
59+
///< are integer type, producing a SMax reduction.
60+
FindLastIVUMax, ///< FindLast reduction with select(cmp(),x,y) where one of
61+
///< (x,y) is increasing loop induction, and both x and y
62+
///< are integer type, producing a UMax reduction.
6063
// clang-format on
6164
// TODO: Any_of and FindLast reduction need not be restricted to integer type
6265
// only.
@@ -259,7 +262,14 @@ class RecurrenceDescriptor {
259262
/// Returns true if the recurrence kind is of the form
260263
/// select(cmp(),x,y) where one of (x,y) is increasing loop induction.
261264
static bool isFindLastIVRecurrenceKind(RecurKind Kind) {
262-
return Kind == RecurKind::FindLastIV;
265+
return Kind == RecurKind::FindLastIVSMax ||
266+
Kind == RecurKind::FindLastIVUMax;
267+
}
268+
269+
/// Returns true if recurrece kind is a signed redux kind.
270+
static bool isSignedRecurrenceKind(RecurKind Kind) {
271+
return Kind == RecurKind::SMax || Kind == RecurKind::SMin ||
272+
Kind == RecurKind::FindLastIVSMax;
263273
}
264274

265275
/// Returns the type of the recurrence. This type can be narrower than the
@@ -271,8 +281,10 @@ class RecurrenceDescriptor {
271281
Value *getSentinelValue() const {
272282
assert(isFindLastIVRecurrenceKind(Kind) && "Unexpected recurrence kind");
273283
Type *Ty = StartValue->getType();
274-
return ConstantInt::get(Ty,
275-
APInt::getSignedMinValue(Ty->getIntegerBitWidth()));
284+
unsigned BW = Ty->getIntegerBitWidth();
285+
return ConstantInt::get(Ty, isSignedRecurrenceKind(Kind)
286+
? APInt::getSignedMinValue(BW)
287+
: APInt::getMinValue(BW));
276288
}
277289

278290
/// Returns a reference to the instructions used for type-promoting the

llvm/include/llvm/Transforms/Utils/LoopUtils.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -434,7 +434,8 @@ LLVM_ABI Value *createAnyOfReduction(IRBuilderBase &B, Value *Src,
434434
/// Create a reduction of the given vector \p Src for a reduction of the
435435
/// kind RecurKind::FindLastIV.
436436
LLVM_ABI Value *createFindLastIVReduction(IRBuilderBase &B, Value *Src,
437-
Value *Start, Value *Sentinel);
437+
RecurKind RdxKind, Value *Start,
438+
Value *Sentinel);
438439

439440
/// Create an ordered reduction intrinsic using the given recurrence
440441
/// kind \p RdxKind.

llvm/lib/Analysis/IVDescriptors.cpp

Lines changed: 37 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ bool RecurrenceDescriptor::isIntegerRecurrenceKind(RecurKind Kind) {
5050
case RecurKind::UMax:
5151
case RecurKind::UMin:
5252
case RecurKind::AnyOf:
53-
case RecurKind::FindLastIV:
53+
case RecurKind::FindLastIVSMax:
54+
case RecurKind::FindLastIVUMax:
5455
return true;
5556
}
5657
return false;
@@ -700,47 +701,59 @@ RecurrenceDescriptor::isFindLastIVPattern(Loop *TheLoop, PHINode *OrigPhi,
700701
m_Value(NonRdxPhi)))))
701702
return InstDesc(false, I);
702703

703-
auto IsIncreasingLoopInduction = [&](Value *V) {
704+
// Returns a non-nullopt boolean indicating the signedness of the recurrence
705+
// when a valid FindLastIV pattern is found.
706+
auto GetRecurKind = [&](Value *V) -> std::optional<RecurKind> {
704707
Type *Ty = V->getType();
705708
if (!SE.isSCEVable(Ty))
706-
return false;
709+
return std::nullopt;
707710

708711
auto *AR = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(V));
709712
if (!AR || AR->getLoop() != TheLoop)
710-
return false;
713+
return std::nullopt;
711714

712715
const SCEV *Step = AR->getStepRecurrence(SE);
713716
if (!SE.isKnownPositive(Step))
714-
return false;
717+
return std::nullopt;
715718

716-
const ConstantRange IVRange = SE.getSignedRange(AR);
717-
unsigned NumBits = Ty->getIntegerBitWidth();
718719
// Keep the minimum value of the recurrence type as the sentinel value.
719720
// The maximum acceptable range for the increasing induction variable,
720721
// called the valid range, will be defined as
721722
// [<sentinel value> + 1, <sentinel value>)
722-
// where <sentinel value> is SignedMin(<recurrence type>)
723+
// where <sentinel value> is [Signed|Unsigned]Min(<recurrence type>)
723724
// TODO: This range restriction can be lifted by adding an additional
724725
// virtual OR reduction.
725-
const APInt Sentinel = APInt::getSignedMinValue(NumBits);
726-
const ConstantRange ValidRange =
727-
ConstantRange::getNonEmpty(Sentinel + 1, Sentinel);
728-
LLVM_DEBUG(dbgs() << "LV: FindLastIV valid range is " << ValidRange
729-
<< ", and the signed range of " << *AR << " is "
730-
<< IVRange << "\n");
731-
// Ensure the induction variable does not wrap around by verifying that its
732-
// range is fully contained within the valid range.
733-
return ValidRange.contains(IVRange);
726+
auto CheckRange = [&](bool IsSigned) {
727+
const ConstantRange IVRange =
728+
IsSigned ? SE.getSignedRange(AR) : SE.getUnsignedRange(AR);
729+
unsigned NumBits = Ty->getIntegerBitWidth();
730+
const APInt Sentinel = IsSigned ? APInt::getSignedMinValue(NumBits)
731+
: APInt::getMinValue(NumBits);
732+
const ConstantRange ValidRange =
733+
ConstantRange::getNonEmpty(Sentinel + 1, Sentinel);
734+
LLVM_DEBUG(dbgs() << "LV: FindLastIV valid range is " << ValidRange
735+
<< ", and the range of " << *AR << " is " << IVRange
736+
<< "\n");
737+
738+
// Ensure the induction variable does not wrap around by verifying that
739+
// its range is fully contained within the valid range.
740+
return ValidRange.contains(IVRange);
741+
};
742+
if (CheckRange(true))
743+
return RecurKind::FindLastIVSMax;
744+
if (CheckRange(false))
745+
return RecurKind::FindLastIVUMax;
746+
return std::nullopt;
734747
};
735748

736749
// We are looking for selects of the form:
737750
// select(cmp(), phi, increasing_loop_induction) or
738751
// select(cmp(), increasing_loop_induction, phi)
739752
// TODO: Support for monotonically decreasing induction variable
740-
if (!IsIncreasingLoopInduction(NonRdxPhi))
741-
return InstDesc(false, I);
753+
if (auto RK = GetRecurKind(NonRdxPhi))
754+
return InstDesc(I, *RK);
742755

743-
return InstDesc(I, RecurKind::FindLastIV);
756+
return InstDesc(false, I);
744757
}
745758

746759
RecurrenceDescriptor::InstDesc
@@ -985,8 +998,8 @@ bool RecurrenceDescriptor::isReductionPHI(PHINode *Phi, Loop *TheLoop,
985998
<< "\n");
986999
return true;
9871000
}
988-
if (AddReductionVar(Phi, RecurKind::FindLastIV, TheLoop, FMF, RedDes, DB, AC,
989-
DT, SE)) {
1001+
if (AddReductionVar(Phi, RecurKind::FindLastIVSMax, TheLoop, FMF, RedDes, DB,
1002+
AC, DT, SE)) {
9901003
LLVM_DEBUG(dbgs() << "Found a FindLastIV reduction PHI." << *Phi << "\n");
9911004
return true;
9921005
}
@@ -1137,7 +1150,8 @@ unsigned RecurrenceDescriptor::getOpcode(RecurKind Kind) {
11371150
case RecurKind::Mul:
11381151
return Instruction::Mul;
11391152
case RecurKind::AnyOf:
1140-
case RecurKind::FindLastIV:
1153+
case RecurKind::FindLastIVSMax:
1154+
case RecurKind::FindLastIVUMax:
11411155
case RecurKind::Or:
11421156
return Instruction::Or;
11431157
case RecurKind::And:

llvm/lib/Transforms/Utils/LoopUtils.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1224,9 +1224,11 @@ Value *llvm::createAnyOfReduction(IRBuilderBase &Builder, Value *Src,
12241224
}
12251225

12261226
Value *llvm::createFindLastIVReduction(IRBuilderBase &Builder, Value *Src,
1227-
Value *Start, Value *Sentinel) {
1227+
RecurKind RdxKind, Value *Start,
1228+
Value *Sentinel) {
1229+
bool IsSigned = RecurrenceDescriptor::isSignedRecurrenceKind(RdxKind);
12281230
Value *MaxRdx = Src->getType()->isVectorTy()
1229-
? Builder.CreateIntMaxReduce(Src, true)
1231+
? Builder.CreateIntMaxReduce(Src, IsSigned)
12301232
: Src;
12311233
// Correct the final reduction result back to the start value if the maximum
12321234
// reduction is sentinel value.

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23180,7 +23180,8 @@ class HorizontalReduction {
2318023180
case RecurKind::FMul:
2318123181
case RecurKind::FMulAdd:
2318223182
case RecurKind::AnyOf:
23183-
case RecurKind::FindLastIV:
23183+
case RecurKind::FindLastIVSMax:
23184+
case RecurKind::FindLastIVUMax:
2318423185
case RecurKind::FMaximumNum:
2318523186
case RecurKind::FMinimumNum:
2318623187
case RecurKind::None:
@@ -23314,7 +23315,8 @@ class HorizontalReduction {
2331423315
case RecurKind::FMul:
2331523316
case RecurKind::FMulAdd:
2331623317
case RecurKind::AnyOf:
23317-
case RecurKind::FindLastIV:
23318+
case RecurKind::FindLastIVSMax:
23319+
case RecurKind::FindLastIVUMax:
2331823320
case RecurKind::FMaximumNum:
2331923321
case RecurKind::FMinimumNum:
2332023322
case RecurKind::None:
@@ -23413,7 +23415,8 @@ class HorizontalReduction {
2341323415
case RecurKind::FMul:
2341423416
case RecurKind::FMulAdd:
2341523417
case RecurKind::AnyOf:
23416-
case RecurKind::FindLastIV:
23418+
case RecurKind::FindLastIVSMax:
23419+
case RecurKind::FindLastIVUMax:
2341723420
case RecurKind::FMaximumNum:
2341823421
case RecurKind::FMinimumNum:
2341923422
case RecurKind::None:

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -642,7 +642,7 @@ Value *VPInstruction::generate(VPTransformState &State) {
642642
auto *PhiR = cast<VPReductionPHIRecipe>(getOperand(0));
643643
// Get its reduction variable descriptor.
644644
const RecurrenceDescriptor &RdxDesc = PhiR->getRecurrenceDescriptor();
645-
[[maybe_unused]] RecurKind RK = RdxDesc.getRecurrenceKind();
645+
RecurKind RK = RdxDesc.getRecurrenceKind();
646646
assert(RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK) &&
647647
"Unexpected reduction kind");
648648
assert(!PhiR->isInLoop() &&
@@ -652,14 +652,17 @@ Value *VPInstruction::generate(VPTransformState &State) {
652652
// sentinel value, followed by one operand for each part of the reduction.
653653
unsigned UF = getNumOperands() - 3;
654654
Value *ReducedPartRdx = State.get(getOperand(3));
655-
for (unsigned Part = 1; Part < UF; ++Part) {
656-
ReducedPartRdx = createMinMaxOp(Builder, RecurKind::SMax, ReducedPartRdx,
655+
RecurKind MinMaxKind = RecurrenceDescriptor::isSignedRecurrenceKind(RK)
656+
? RecurKind::SMax
657+
: RecurKind::UMax;
658+
for (unsigned Part = 1; Part < UF; ++Part)
659+
ReducedPartRdx = createMinMaxOp(Builder, MinMaxKind, ReducedPartRdx,
657660
State.get(getOperand(3 + Part)));
658-
}
659661

660662
Value *Start = State.get(getOperand(1), true);
661663
Value *Sentinel = getOperand(2)->getLiveInIRValue();
662-
return createFindLastIVReduction(Builder, ReducedPartRdx, Start, Sentinel);
664+
return createFindLastIVReduction(Builder, ReducedPartRdx, RK, Start,
665+
Sentinel);
663666
}
664667
case VPInstruction::ComputeReductionResult: {
665668
// FIXME: The cross-recipe dependency on VPReductionPHIRecipe is temporary

0 commit comments

Comments
 (0)