[SLP] Support ordered FAdd reductions in SLPVectorizer #146570
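This patch teaches the SLP vectorizer to vectorize strict (non-reassoc) fadd reduction chains by emitting the ordered llvm.vector.reduce.fadd intrinsic, gated by the new -slp-ordered-fp-reds option. The sketch below is a hand-written condensation of the dot-product test update further down; the names are illustrative, and whether a given chain actually vectorizes still depends on the usual SLP legality and cost checks.

; Before: a strict fadd chain (no 'reassoc' flag on any of the adds).
define float @sum4(ptr %p) {
  %p1 = getelementptr inbounds float, ptr %p, i64 1
  %p2 = getelementptr inbounds float, ptr %p, i64 2
  %p3 = getelementptr inbounds float, ptr %p, i64 3
  %a0 = load float, ptr %p, align 4
  %a1 = load float, ptr %p1, align 4
  %a2 = load float, ptr %p2, align 4
  %a3 = load float, ptr %p3, align 4
  %s01 = fadd float %a0, %a1
  %s012 = fadd float %s01, %a2
  %s0123 = fadd float %s012, %a3
  ret float %s0123
}

; After: an ordered reduction seeded with -0.0 (the fadd identity), which
; preserves the original left-to-right evaluation order.
define float @sum4.vectorized(ptr %p) {
  %v = load <4 x float>, ptr %p, align 4
  %s0123 = call float @llvm.vector.reduce.fadd.v4f32(float -0.000000e+00, <4 x float> %v)
  ret float %s0123
}

declare float @llvm.vector.reduce.fadd.v4f32(float, <4 x float>)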

Open · wants to merge 1 commit into main
189 changes: 160 additions & 29 deletions llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -110,11 +110,16 @@ using namespace std::placeholders;
#define SV_NAME "slp-vectorizer"
#define DEBUG_TYPE "SLP"

STATISTIC(NumFaddVectorized, "Number of vectorized fadd reductions");
STATISTIC(NumVectorInstructions, "Number of vector instructions generated");

DEBUG_COUNTER(VectorizedGraphs, "slp-vectorized",
"Controls which SLP graphs should be vectorized.");

static cl::opt<bool> SLPEnableOrderedFPReductions(
"slp-ordered-fp-reds", cl::init(true), cl::Hidden,
cl::desc("Enable vectorization of ordered floating point reductions"));

static cl::opt<bool>
RunSLPVectorization("vectorize-slp", cl::init(true), cl::Hidden,
cl::desc("Run the SLP vectorization passes"));
@@ -1850,6 +1855,11 @@ class BoUpSLP {
return VectorizableTree.front()->Scalars;
}

bool areAllEntriesIdentityOrdered() const {
return all_of(VectorizableTree,
[&](auto &Entry) { return Entry->ReorderIndices.empty(); });
}

/// Returns the type/is-signed info for the root node in the graph without
/// casting.
std::optional<std::pair<Type *, bool>> getRootNodeTypeWithNoCast() const {
@@ -21774,6 +21784,8 @@ class HorizontalReduction {
/// signedness.
SmallVector<std::tuple<Value *, unsigned, bool>> VectorValuesAndScales;

SmallVector<Value *, 2> InitialFAddValues;

static bool isCmpSelMinMax(Instruction *I) {
return match(I, m_Select(m_Cmp(), m_Value(), m_Value())) &&
RecurrenceDescriptor::isMinMaxRecurrenceKind(getRdxKind(I));
@@ -21787,6 +21799,14 @@
(match(I, m_LogicalAnd()) || match(I, m_LogicalOr()));
}

bool isOrderedFaddReduction() const {
if (!isa<Instruction>(ReductionRoot))
return false;
auto *I = cast<Instruction>(ReductionRoot);
return (RdxKind == RecurKind::FAdd) &&
!I->getFastMathFlags().allowReassoc();
}

/// Checks if instruction is associative and can be vectorized.
static bool isVectorizable(RecurKind Kind, Instruction *I) {
if (Kind == RecurKind::None)
@@ -21807,6 +21827,9 @@
if (Kind == RecurKind::FMaximum || Kind == RecurKind::FMinimum)
return true;

if (Kind == RecurKind::FAdd && SLPEnableOrderedFPReductions)
return true;

return I->isAssociative();
}

@@ -22066,6 +22089,37 @@ class HorizontalReduction {
(I && !isa<LoadInst>(I) && isValidForAlternation(I->getOpcode()));
}

bool checkOperandsOrder() const {
auto OpsVec = reverse(ReductionOps[0]);
if (!isOrderedFaddReduction() || empty(OpsVec))
return false;
Value *PrevOperand = *OpsVec.begin();
for (auto *I : drop_begin(OpsVec)) {
Value *Op1 = cast<BinaryOperator>(I)->getOperand(0);
if (Op1 != PrevOperand)
return false;
PrevOperand = I;
}
return true;
}

bool checkFastMathFlags() const {
for (auto OpsVec : ReductionOps) {
if (OpsVec.size() <= 1)
continue;
Value *V = *OpsVec.begin();
if (!isa<FPMathOperator>(V))
continue;
bool Flag = cast<Instruction>(V)->getFastMathFlags().allowReassoc();
auto It = find_if(drop_begin(OpsVec), [&](Value *I) {
auto CurFlag = cast<Instruction>(I)->getFastMathFlags().allowReassoc();
return (Flag != CurFlag);
});
if (It != OpsVec.end())
return false;
}
return true;
}
public:
HorizontalReduction() = default;

@@ -22180,9 +22234,10 @@ class HorizontalReduction {
// Add reduction values. The values are sorted for better vectorization
// results.
for (Value *V : PossibleRedVals) {
size_t Key, Idx;
std::tie(Key, Idx) = generateKeySubkey(V, &TLI, GenerateLoadsSubkey,
/*AllowAlternate=*/false);
size_t Key = 0, Idx = 0;

Reviewer comment (Member): I think it is better to separate this from the "associative" analysis and implement it in a separate function, for the case where the associative path does not work. It would be much easier to read and maintain, and most of the logic for associative reductions could be dropped.

if (!isOrderedFaddReduction())
std::tie(Key, Idx) = generateKeySubkey(V, &TLI, GenerateLoadsSubkey,
/*AllowAlternate=*/false);
++PossibleReducedVals[Key][Idx]
.insert(std::make_pair(V, 0))
.first->second;
@@ -22200,13 +22255,15 @@
It != E; ++It) {
PossibleRedValsVect.emplace_back();
auto RedValsVect = It->second.takeVector();
stable_sort(RedValsVect, llvm::less_second());
if (!isOrderedFaddReduction())
stable_sort(RedValsVect, llvm::less_second());
for (const std::pair<Value *, unsigned> &Data : RedValsVect)
PossibleRedValsVect.back().append(Data.second, Data.first);
}
stable_sort(PossibleRedValsVect, [](const auto &P1, const auto &P2) {
return P1.size() > P2.size();
});
if (!isOrderedFaddReduction())
stable_sort(PossibleRedValsVect, [](const auto &P1, const auto &P2) {
return P1.size() > P2.size();
});
int NewIdx = -1;
for (ArrayRef<Value *> Data : PossibleRedValsVect) {
if (NewIdx < 0 ||
@@ -22226,9 +22283,19 @@
}
// Sort the reduced values by number of same/alternate opcode and/or pointer
// operand.
stable_sort(ReducedVals, [](ArrayRef<Value *> P1, ArrayRef<Value *> P2) {
return P1.size() > P2.size();
});
if (!isOrderedFaddReduction())
stable_sort(ReducedVals, [](ArrayRef<Value *> P1, ArrayRef<Value *> P2) {
return P1.size() > P2.size();
});

if (isOrderedFaddReduction() &&
(ReducedVals.size() != 1 || ReducedVals[0].size() == 2 ||
!checkOperandsOrder()))
return false;

if (!checkFastMathFlags())
return false;

return true;
}

@@ -22423,7 +22490,7 @@ class HorizontalReduction {
// original scalar identity operations on matched horizontal reductions).
IsSupportedHorRdxIdentityOp = RdxKind != RecurKind::Mul &&
RdxKind != RecurKind::FMul &&
RdxKind != RecurKind::FMulAdd;
RdxKind != RecurKind::FMulAdd && !isOrderedFaddReduction();
// Gather same values.
SmallMapVector<Value *, unsigned, 16> SameValuesCounter;
if (IsSupportedHorRdxIdentityOp)
@@ -22524,6 +22591,8 @@ class HorizontalReduction {
return IsAnyRedOpGathered;
};
bool AnyVectorized = false;
Instruction *RdxRootInst = cast<Instruction>(ReductionRoot);
Instruction *InsertPt = RdxRootInst;
SmallDenseSet<std::pair<unsigned, unsigned>, 8> IgnoredCandidates;
while (Pos < NumReducedVals - ReduxWidth + 1 &&
ReduxWidth >= ReductionLimit) {
@@ -22684,8 +22753,6 @@

// Emit a reduction. If the root is a select (min/max idiom), the insert
// point is the compare condition of that select.
Instruction *RdxRootInst = cast<Instruction>(ReductionRoot);
Instruction *InsertPt = RdxRootInst;
if (IsCmpSelMinMax)
InsertPt = GetCmpForMinMaxReduction(RdxRootInst);

@@ -22738,6 +22805,41 @@
if (!V.isVectorized(RdxVal))
RequiredExtract.insert(RdxVal);
}

auto FirstIt = find_if(ReducedVals[0], [&](Value *RdxVal) {
return VectorizedVals.lookup(RdxVal);
});
auto LastIt = find_if(reverse(ReducedVals[0]), [&](Value *RdxVal) {
return VectorizedVals.lookup(RdxVal);
});
if (isOrderedFaddReduction()) {
// [FirstIt, LastIt] - range of vectorized Vals; we need it to get the
// last non-vectorized Val at the beginning and its ReductionOp, and the
// first non-vectorized Val at the end and its ReductionOp.
// fadd - initial value for reduction
// fadd - v
// fadd - v
// fadd - v
// fadd - v
// fadd - scalar remainder
if (LastIt != ReducedVals[0].rend())
ReductionRoot =
cast<Instruction>(ReducedValsToOps.find(*LastIt)->second[0]);

if (InitialFAddValues.empty()) {
auto *FAddBinOp = cast<BinaryOperator>(
ReducedValsToOps.find(*FirstIt)->second[0]);
Value *InitialFAddValue = ConstantExpr::getBinOpIdentity(
FAddBinOp->getOpcode(), FAddBinOp->getType());
if (FirstIt != ReducedVals[0].end()) {
auto *Op1 = FAddBinOp->getOperand(0);
if (!isa<PoisonValue>(Op1))
InitialFAddValue = Op1;
}
InitialFAddValues.push_back(InitialFAddValue);
}
}

Pos += ReduxWidth;
Start = Pos;
ReduxWidth = NumReducedVals - Pos;
@@ -22755,10 +22857,27 @@
continue;
}
}
if (!VectorValuesAndScales.empty())
VectorizedTree = GetNewVectorizedTree(
VectorizedTree,
emitReduction(Builder, *TTI, ReductionRoot->getType()));
if (!VectorValuesAndScales.empty()) {
if (!isOrderedFaddReduction()) {
VectorizedTree = GetNewVectorizedTree(
VectorizedTree,
emitReduction(Builder, *TTI, ReductionRoot->getType()));
} else {
for (auto V : VectorValuesAndScales) {
Value *InitialFAddValue = InitialFAddValues.back();
VectorizedTree = Builder.CreateFAddReduce(InitialFAddValue, std::get<0>(V));
InitialFAddValues.push_back(VectorizedTree);
}
auto LastIt = find_if(reverse(ReducedVals[0]), [&](Value *RdxVal) {
return VectorizedVals.lookup(RdxVal);
});
for_each(reverse(make_range(LastIt.base(), ReducedVals[0].end())),
[&](Value *V) {
ReducedValsToOps.find(V)->second[0]->moveAfter(
cast<Instruction>(VectorizedTree));
});
}
}
if (VectorizedTree) {
// Reorder operands of bool logical op in the natural order to avoid
// possible problem with poison propagation. If not possible to reorder
@@ -22846,15 +22965,18 @@ class HorizontalReduction {
ExtraReductions.emplace_back(RedOp, RdxVal);
}
}
// Iterate through all not-vectorized reduction values/extra arguments.
bool InitStep = true;
while (ExtraReductions.size() > 1) {
SmallVector<std::pair<Instruction *, Value *>> NewReds =
FinalGen(ExtraReductions, InitStep);
ExtraReductions.swap(NewReds);
InitStep = false;

if (!isOrderedFaddReduction()) {
// Iterate through all not-vectorized reduction values/extra arguments.
bool InitStep = true;
while (ExtraReductions.size() > 1) {
SmallVector<std::pair<Instruction *, Value *>> NewReds =
FinalGen(ExtraReductions, InitStep);
ExtraReductions.swap(NewReds);
InitStep = false;
}
VectorizedTree = ExtraReductions.front().second;
}
VectorizedTree = ExtraReductions.front().second;

ReductionRoot->replaceAllUsesWith(VectorizedTree);

@@ -22868,21 +22990,28 @@
IgnoreSet.insert_range(RdxOps);
#endif
for (ArrayRef<Value *> RdxOps : ReductionOps) {
SmallVector<Value *, 4> RdxOpsForDeletion;
for (Value *Ignore : RdxOps) {
if (!Ignore)
if (!Ignore || (isOrderedFaddReduction() && !Ignore->use_empty() &&
!any_of(cast<Instruction>(Ignore)->operands(),
[](const Value *Val) {
return isa<PoisonValue>(Val);
})))
continue;
#ifndef NDEBUG
for (auto *U : Ignore->users()) {
assert(IgnoreSet.count(U) &&
"All users must be either in the reduction ops list.");
assert((IgnoreSet.count(U) ||
isOrderedFaddReduction()) &&
"All users must be either in the reduction ops list.");
}
#endif
if (!Ignore->use_empty()) {
Value *P = PoisonValue::get(Ignore->getType());
Ignore->replaceAllUsesWith(P);
}
RdxOpsForDeletion.push_back(Ignore);
}
V.removeInstructionsAndOperands(RdxOps, VectorValuesAndScales);
V.removeInstructionsAndOperands(ArrayRef(RdxOpsForDeletion), VectorValuesAndScales);
}
} else if (!CheckForReusedReductionOps) {
for (ReductionOpsType &RdxOps : ReductionOps)
@@ -22961,6 +23090,8 @@ class HorizontalReduction {
continue;
}
InstructionCost ScalarCost = 0;
if (RdxVal->use_empty())
continue;
for (User *U : RdxVal->users()) {
auto *RdxOp = cast<Instruction>(U);
if (hasRequiredNumberOfUses(IsCmpSelMinMax, RdxOp)) {
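For reference, the ordered path only fires on a single strict chain: checkOperandsOrder() above walks the reduction ops and requires that each fadd's first operand is the fadd immediately preceding it in the chain. Below is a minimal hand-written sketch of the shape that passes this check as written; a chain that feeds the running sum through the second operand would be rejected.

; Accepted: the running sum always flows through operand 0.
%r1 = fadd float %init, %v0
%r2 = fadd float %r1, %v1
%r3 = fadd float %r2, %v2

; Rejected by checkOperandsOrder() as written: the running sum flows
; through operand 1.
%q1 = fadd float %v0, %init
%q2 = fadd float %v1, %q1
%q3 = fadd float %v2, %q2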
38 changes: 8 additions & 30 deletions llvm/test/Transforms/SLPVectorizer/X86/dot-product.ll
@@ -10,21 +10,10 @@

define double @dot4f64(ptr dereferenceable(32) %ptrx, ptr dereferenceable(32) %ptry) {
; CHECK-LABEL: @dot4f64(
; CHECK-NEXT: [[PTRX2:%.*]] = getelementptr inbounds double, ptr [[PTRX:%.*]], i64 2
; CHECK-NEXT: [[PTRY2:%.*]] = getelementptr inbounds double, ptr [[PTRY:%.*]], i64 2
; CHECK-NEXT: [[TMP1:%.*]] = load <2 x double>, ptr [[PTRX]], align 4
; CHECK-NEXT: [[TMP2:%.*]] = load <2 x double>, ptr [[PTRY]], align 4
; CHECK-NEXT: [[TMP3:%.*]] = fmul <2 x double> [[TMP1]], [[TMP2]]
; CHECK-NEXT: [[TMP4:%.*]] = load <2 x double>, ptr [[PTRX2]], align 4
; CHECK-NEXT: [[TMP5:%.*]] = load <2 x double>, ptr [[PTRY2]], align 4
; CHECK-NEXT: [[TMP6:%.*]] = fmul <2 x double> [[TMP4]], [[TMP5]]
; CHECK-NEXT: [[TMP7:%.*]] = extractelement <2 x double> [[TMP3]], i32 0
; CHECK-NEXT: [[TMP8:%.*]] = extractelement <2 x double> [[TMP3]], i32 1
; CHECK-NEXT: [[DOT01:%.*]] = fadd double [[TMP7]], [[TMP8]]
; CHECK-NEXT: [[TMP9:%.*]] = extractelement <2 x double> [[TMP6]], i32 0
; CHECK-NEXT: [[DOT012:%.*]] = fadd double [[DOT01]], [[TMP9]]
; CHECK-NEXT: [[TMP10:%.*]] = extractelement <2 x double> [[TMP6]], i32 1
; CHECK-NEXT: [[DOT0123:%.*]] = fadd double [[DOT012]], [[TMP10]]
; CHECK-NEXT: [[TMP1:%.*]] = load <4 x double>, ptr [[PTRX:%.*]], align 4
; CHECK-NEXT: [[TMP2:%.*]] = load <4 x double>, ptr [[PTRY:%.*]], align 4
; CHECK-NEXT: [[TMP3:%.*]] = fmul <4 x double> [[TMP1]], [[TMP2]]
; CHECK-NEXT: [[DOT0123:%.*]] = call double @llvm.vector.reduce.fadd.v4f64(double -0.000000e+00, <4 x double> [[TMP3]])
; CHECK-NEXT: ret double [[DOT0123]]
;
%ptrx1 = getelementptr inbounds double, ptr %ptrx, i64 1
@@ -53,21 +42,10 @@ define double @dot4f64(ptr dereferenceable(32) %ptrx, ptr dereferenceable(32) %p

define float @dot4f32(ptr dereferenceable(16) %ptrx, ptr dereferenceable(16) %ptry) {
; CHECK-LABEL: @dot4f32(
; CHECK-NEXT: [[PTRX2:%.*]] = getelementptr inbounds float, ptr [[PTRX:%.*]], i64 2
; CHECK-NEXT: [[PTRY2:%.*]] = getelementptr inbounds float, ptr [[PTRY:%.*]], i64 2
; CHECK-NEXT: [[TMP1:%.*]] = load <2 x float>, ptr [[PTRX]], align 4
; CHECK-NEXT: [[TMP2:%.*]] = load <2 x float>, ptr [[PTRY]], align 4
; CHECK-NEXT: [[TMP3:%.*]] = fmul <2 x float> [[TMP1]], [[TMP2]]
; CHECK-NEXT: [[TMP4:%.*]] = load <2 x float>, ptr [[PTRX2]], align 4
; CHECK-NEXT: [[TMP5:%.*]] = load <2 x float>, ptr [[PTRY2]], align 4
; CHECK-NEXT: [[TMP6:%.*]] = fmul <2 x float> [[TMP4]], [[TMP5]]
; CHECK-NEXT: [[TMP7:%.*]] = extractelement <2 x float> [[TMP3]], i32 0
; CHECK-NEXT: [[TMP8:%.*]] = extractelement <2 x float> [[TMP3]], i32 1
; CHECK-NEXT: [[DOT01:%.*]] = fadd float [[TMP7]], [[TMP8]]
; CHECK-NEXT: [[TMP9:%.*]] = extractelement <2 x float> [[TMP6]], i32 0
; CHECK-NEXT: [[DOT012:%.*]] = fadd float [[DOT01]], [[TMP9]]
; CHECK-NEXT: [[TMP10:%.*]] = extractelement <2 x float> [[TMP6]], i32 1
; CHECK-NEXT: [[DOT0123:%.*]] = fadd float [[DOT012]], [[TMP10]]
; CHECK-NEXT: [[TMP1:%.*]] = load <4 x float>, ptr [[PTRX:%.*]], align 4
; CHECK-NEXT: [[TMP2:%.*]] = load <4 x float>, ptr [[PTRY:%.*]], align 4
; CHECK-NEXT: [[TMP3:%.*]] = fmul <4 x float> [[TMP1]], [[TMP2]]
; CHECK-NEXT: [[DOT0123:%.*]] = call float @llvm.vector.reduce.fadd.v4f32(float -0.000000e+00, <4 x float> [[TMP3]])
; CHECK-NEXT: ret float [[DOT0123]]
;
%ptrx1 = getelementptr inbounds float, ptr %ptrx, i64 1
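The [FirstIt, LastIt] comment in the SLPVectorizer.cpp hunk describes how a partially vectorized chain is stitched back together: the first operand of the fadd that begins the vectorized region becomes the initial value of the ordered reduction (falling back to -0.0 when it is poison), and any trailing scalar fadds are moved to follow the reduction call. A hand-written sketch of that case, assuming the middle four addends %t0..%t3 were packed into a vector %vt (this example is not taken from the PR's tests):

; Original strict chain: %acc feeds in at the top, %tail is a scalar remainder.
%s0 = fadd float %acc, %t0
%s1 = fadd float %s0, %t1
%s2 = fadd float %s1, %t2
%s3 = fadd float %s2, %t3
%res = fadd float %s3, %tail

; After the ordered-FAdd path: the reduction is seeded with %acc and the
; remaining scalar fadd is re-emitted after the reduction call.
%red = call float @llvm.vector.reduce.fadd.v4f32(float %acc, <4 x float> %vt)
%res.v = fadd float %red, %tail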