[PatternMatch][VPlan] Add std::function match overload. NFCI #146374
Conversation
A relatively common use case for PatternMatch is to call match inside all_of/any_of/none_of. This patch adds an overload of match, for both the LLVM and VPlan pattern matchers, that returns a lambda so callers don't need to write one themselves.
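For example, with this overload the simplifyFPOp check in the diff below goes from an explicit lambda to passing the matcher directly:

// Before:
if (any_of(Ops, [](Value *V) { return match(V, m_Poison()); }))
  return PoisonValue::get(Ops[0]->getType());

// After:
if (any_of(Ops, match(m_Poison())))
  return PoisonValue::get(Ops[0]->getType());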
@llvm/pr-subscribers-llvm-transforms @llvm/pr-subscribers-llvm-analysis

Author: Luke Lau (lukel97)

Changes

A relatively common use case for PatternMatch is to use match inside all_of/any_of/none_of. This patch adds an overload for match that returns a lambda so callers don't need to create a lambda themselves for both the LLVM and VPlan pattern matchers.

Full diff: https://github.com/llvm/llvm-project/pull/146374.diff

13 Files Affected:
diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 1f86cdfd94e17..e5013c31f1f40 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -50,6 +50,11 @@ template <typename Val, typename Pattern> bool match(Val *V, const Pattern &P) {
return P.match(V);
}
+template <typename Val = const Value, typename Pattern>
+std::function<bool(Val *)> match(const Pattern &P) {
+ return [&P](Val *V) { return P.match(V); };
+}
+
template <typename Pattern> bool match(ArrayRef<int> Mask, const Pattern &P) {
return P.match(Mask);
}
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index cb1dae92faf92..fba9dd80f02c7 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -5028,14 +5028,12 @@ static Value *simplifyGEPInst(Type *SrcTy, Value *Ptr,
}
// All-zero GEP is a no-op, unless it performs a vector splat.
- if (Ptr->getType() == GEPTy &&
- all_of(Indices, [](const auto *V) { return match(V, m_Zero()); }))
+ if (Ptr->getType() == GEPTy && all_of(Indices, match(m_Zero())))
return Ptr;
// getelementptr poison, idx -> poison
// getelementptr baseptr, poison -> poison
- if (isa<PoisonValue>(Ptr) ||
- any_of(Indices, [](const auto *V) { return isa<PoisonValue>(V); }))
+ if (isa<PoisonValue>(Ptr) || any_of(Indices, match(m_Poison())))
return PoisonValue::get(GEPTy);
// getelementptr undef, idx -> undef
@@ -5092,8 +5090,7 @@ static Value *simplifyGEPInst(Type *SrcTy, Value *Ptr,
}
if (!IsScalableVec && Q.DL.getTypeAllocSize(LastType) == 1 &&
- all_of(Indices.drop_back(1),
- [](Value *Idx) { return match(Idx, m_Zero()); })) {
+ all_of(Indices.drop_back(1), match(m_Zero()))) {
unsigned IdxWidth =
Q.DL.getIndexSizeInBits(Ptr->getType()->getPointerAddressSpace());
if (Q.DL.getTypeSizeInBits(Indices.back()->getType()) == IdxWidth) {
@@ -5123,8 +5120,7 @@ static Value *simplifyGEPInst(Type *SrcTy, Value *Ptr,
}
// Check to see if this is constant foldable.
- if (!isa<Constant>(Ptr) ||
- !all_of(Indices, [](Value *V) { return isa<Constant>(V); }))
+ if (!isa<Constant>(Ptr) || !all_of(Indices, match(m_Constant())))
return nullptr;
if (!ConstantExpr::isSupportedGetElementPtr(SrcTy))
@@ -5649,7 +5645,7 @@ static Constant *simplifyFPOp(ArrayRef<Value *> Ops, FastMathFlags FMF,
RoundingMode Rounding) {
// Poison is independent of anything else. It always propagates from an
// operand to a math result.
- if (any_of(Ops, [](Value *V) { return match(V, m_Poison()); }))
+ if (any_of(Ops, match(m_Poison())))
return PoisonValue::get(Ops[0]->getType());
for (Value *V : Ops) {
@@ -7116,7 +7112,7 @@ static Value *simplifyInstructionWithOperands(Instruction *I,
switch (I->getOpcode()) {
default:
- if (llvm::all_of(NewOps, [](Value *V) { return isa<Constant>(V); })) {
+ if (all_of(NewOps, match(m_Constant()))) {
SmallVector<Constant *, 8> NewConstOps(NewOps.size());
transform(NewOps, NewConstOps.begin(),
[](Value *V) { return cast<Constant>(V); });
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index e576f4899810a..6cc50bf7e3ee1 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -251,9 +251,8 @@ bool llvm::haveNoCommonBitsSet(const WithCache<const Value *> &LHSCache,
}
bool llvm::isOnlyUsedInZeroComparison(const Instruction *I) {
- return !I->user_empty() && all_of(I->users(), [](const User *U) {
- return match(U, m_ICmp(m_Value(), m_Zero()));
- });
+ return !I->user_empty() &&
+ all_of(I->users(), match(m_ICmp(m_Value(), m_Zero())));
}
bool llvm::isOnlyUsedInZeroEqualityComparison(const Instruction *I) {
diff --git a/llvm/lib/CodeGen/InterleavedAccessPass.cpp b/llvm/lib/CodeGen/InterleavedAccessPass.cpp
index 9c4c86cebe7e5..2d2d48b004e77 100644
--- a/llvm/lib/CodeGen/InterleavedAccessPass.cpp
+++ b/llvm/lib/CodeGen/InterleavedAccessPass.cpp
@@ -294,10 +294,9 @@ bool InterleavedAccessImpl::lowerInterleavedLoad(
continue;
}
if (auto *BI = dyn_cast<BinaryOperator>(User)) {
- if (!BI->user_empty() && all_of(BI->users(), [](auto *U) {
- auto *SVI = dyn_cast<ShuffleVectorInst>(U);
- return SVI && isa<UndefValue>(SVI->getOperand(1));
- })) {
+ using namespace PatternMatch;
+ if (!BI->user_empty() &&
+ all_of(BI->users(), match(m_Shuffle(m_Value(), m_Undef())))) {
for (auto *SVI : BI->users())
BinOpShuffles.insert(cast<ShuffleVectorInst>(SVI));
continue;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index e721f0cd5f9e3..cb6dc4b5b0fc5 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -2307,12 +2307,8 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
// and let's try to sink `(sub 0, b)` into `b` itself. But only if this isn't
// a pure negation used by a select that looks like abs/nabs.
bool IsNegation = match(Op0, m_ZeroInt());
- if (!IsNegation || none_of(I.users(), [&I, Op1](const User *U) {
- const Instruction *UI = dyn_cast<Instruction>(U);
- if (!UI)
- return false;
- return match(UI, m_c_Select(m_Specific(Op1), m_Specific(&I)));
- })) {
+ if (!IsNegation ||
+ none_of(I.users(), match(m_c_Select(m_Specific(Op1), m_Specific(&I))))) {
if (Value *NegOp1 = Negator::Negate(IsNegation, /* IsNSW */ IsNegation &&
I.hasNoSignedWrap(),
Op1, *this))
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index e33d111167c04..cc6ad9bf44cd2 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1418,9 +1418,7 @@ InstCombinerImpl::foldShuffledIntrinsicOperands(IntrinsicInst *II) {
// At least 1 operand must be a shuffle with 1 use because we are creating 2
// instructions.
- if (none_of(II->args(), [](Value *V) {
- return isa<ShuffleVectorInst>(V) && V->hasOneUse();
- }))
+ if (none_of(II->args(), match(m_OneUse(m_Shuffle(m_Value(), m_Value())))))
return nullptr;
// See if all arguments are shuffled with the same mask.
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 0894ca92086f3..27b239417de04 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -1341,7 +1341,7 @@ Instruction *InstCombinerImpl::foldICmpWithConstant(ICmpInst &Cmp) {
return nullptr;
if (auto *Phi = dyn_cast<PHINode>(Op0))
- if (all_of(Phi->operands(), [](Value *V) { return isa<Constant>(V); })) {
+ if (all_of(Phi->operands(), match(m_Constant()))) {
SmallVector<Constant *> Ops;
for (Value *V : Phi->incoming_values()) {
Constant *Res =
diff --git a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
index 6477141ab095f..d992e2f57a0c7 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
@@ -339,7 +339,7 @@ bool InstCombinerImpl::foldIntegerTypedPHI(PHINode &PN) {
Instruction *InstCombinerImpl::foldPHIArgIntToPtrToPHI(PHINode &PN) {
// convert ptr2int ( phi[ int2ptr(ptr2int(x))] ) --> ptr2int ( phi [ x ] )
// Make sure all uses of phi are ptr2int.
- if (!all_of(PN.users(), [](User *U) { return isa<PtrToIntInst>(U); }))
+ if (!all_of(PN.users(), match(m_PtrToInt(m_Value()))))
return nullptr;
// Iterating over all operands to check presence of target pointers for
@@ -1298,7 +1298,7 @@ static Value *simplifyUsingControlFlow(InstCombiner &Self, PHINode &PN,
// \ /
// phi [v1] [v2]
// Make sure all inputs are constants.
- if (!all_of(PN.operands(), [](Value *V) { return isa<ConstantInt>(V); }))
+ if (!all_of(PN.operands(), match(m_ConstantInt())))
return nullptr;
BasicBlock *BB = PN.getParent();
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 73ba0f78e8053..c43a8cb53e4e9 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -3142,7 +3142,7 @@ static Instruction *foldNestedSelects(SelectInst &OuterSelVal,
// Profitability check - avoid increasing instruction count.
if (none_of(ArrayRef<Value *>({OuterSelVal.getCondition(), InnerSelVal}),
- [](Value *V) { return V->hasOneUse(); }))
+ match(m_OneUse(m_Value()))))
return nullptr;
// The appropriate hand of the outermost `select` must be a select itself.
diff --git a/llvm/lib/Transforms/Scalar/LICM.cpp b/llvm/lib/Transforms/Scalar/LICM.cpp
index cf84366c4200b..c2edc3a33edc1 100644
--- a/llvm/lib/Transforms/Scalar/LICM.cpp
+++ b/llvm/lib/Transforms/Scalar/LICM.cpp
@@ -435,10 +435,9 @@ bool LoopInvariantCodeMotion::runOnLoop(Loop *L, AAResults *AA, LoopInfo *LI,
// potentially happen in other passes where instructions are being moved
// across that edge.
bool HasCoroSuspendInst = llvm::any_of(L->getBlocks(), [](BasicBlock *BB) {
- return llvm::any_of(*BB, [](Instruction &I) {
- IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
- return II && II->getIntrinsicID() == Intrinsic::coro_suspend;
- });
+ using namespace PatternMatch;
+ return any_of(make_pointer_range(*BB),
+ match(m_Intrinsic<Intrinsic::coro_suspend>()));
});
MemorySSAUpdater MSSAU(MSSA);
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 95479373b4393..3fe9c46ac7656 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7036,11 +7036,12 @@ static bool planContainsAdditionalSimplifications(VPlan &Plan,
// Unused FOR splices are removed by VPlan transforms, so the VPlan-based
// cost model won't cost it whilst the legacy will.
if (auto *FOR = dyn_cast<VPFirstOrderRecurrencePHIRecipe>(&R)) {
- if (none_of(FOR->users(), [](VPUser *U) {
- auto *VPI = dyn_cast<VPInstruction>(U);
- return VPI && VPI->getOpcode() ==
- VPInstruction::FirstOrderRecurrenceSplice;
- }))
+ using namespace VPlanPatternMatch;
+ if (none_of(
+ FOR->users(),
+ match(
+ m_VPInstruction<VPInstruction::FirstOrderRecurrenceSplice>(
+ m_VPValue(), m_VPValue()))))
return true;
}
// The VPlan-based cost model is more accurate for partial reduction and
@@ -7449,13 +7450,11 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
Hints.setAlreadyVectorized();
// Check if it's EVL-vectorized and mark the corresponding metadata.
+ using namespace VPlanPatternMatch;
bool IsEVLVectorized =
- llvm::any_of(*HeaderVPBB, [](const VPRecipeBase &Recipe) {
- // Looking for the ExplictVectorLength VPInstruction.
- if (const auto *VI = dyn_cast<VPInstruction>(&Recipe))
- return VI->getOpcode() == VPInstruction::ExplicitVectorLength;
- return false;
- });
+ any_of(make_pointer_range(*HeaderVPBB),
+ match(m_VPInstruction<VPInstruction::ExplicitVectorLength>(
+ m_VPValue())));
if (IsEVLVectorized) {
LLVMContext &Context = L->getHeader()->getContext();
MDNode *LoopID = L->getLoopID();
@@ -9737,10 +9736,9 @@ static void preparePlanForMainVectorLoop(VPlan &MainPlan, VPlan &EpiPlan) {
// If there is a suitable resume value for the canonical induction in the
// scalar (which will become vector) epilogue loop we are done. Otherwise
// create it below.
- if (any_of(*MainScalarPH, [VectorTC](VPRecipeBase &R) {
- return match(&R, m_VPInstruction<Instruction::PHI>(m_Specific(VectorTC),
- m_SpecificInt(0)));
- }))
+ if (any_of(make_pointer_range(*MainScalarPH),
+ match(m_VPInstruction<Instruction::PHI>(m_Specific(VectorTC),
+ m_SpecificInt(0)))))
return;
VPBuilder ScalarPHBuilder(MainScalarPH, MainScalarPH->begin());
ScalarPHBuilder.createScalarPhi(
@@ -9778,10 +9776,9 @@ preparePlanForEpilogueVectorLoop(VPlan &Plan, Loop *L,
match(
P.getIncomingValueForBlock(EPI.MainLoopIterationCountCheck),
m_SpecificInt(0)) &&
- all_of(P.incoming_values(), [&EPI](Value *Inc) {
- return Inc == EPI.VectorTripCount ||
- match(Inc, m_SpecificInt(0));
- }))
+ all_of(P.incoming_values(),
+ match(m_CombineOr(m_Specific(EPI.VectorTripCount),
+ m_SpecificInt(0)))))
return &P;
return nullptr;
});
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 0941bf61953f1..79bf939cd591e 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -20708,10 +20708,9 @@ void BoUpSLP::computeMinimumValueSizes() {
IsTruncRoot = true;
}
bool IsSignedCmp = false;
- if (UserIgnoreList && all_of(*UserIgnoreList, [](Value *V) {
- return match(V, m_SMin(m_Value(), m_Value())) ||
- match(V, m_SMax(m_Value(), m_Value()));
- }))
+ if (UserIgnoreList &&
+ all_of(*UserIgnoreList, match(m_CombineOr(m_SMin(m_Value(), m_Value()),
+ m_SMax(m_Value(), m_Value())))))
IsSignedCmp = true;
while (NodeIdx < VectorizableTree.size()) {
ArrayRef<Value *> TreeRoot = VectorizableTree[NodeIdx]->Scalars;
diff --git a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
index efea99f22d086..4aba5fb010559 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
@@ -29,11 +29,21 @@ template <typename Val, typename Pattern> bool match(Val *V, const Pattern &P) {
return P.match(V);
}
+template <typename Val, typename Pattern>
+std::function<bool(Val *)> match(const Pattern &P) {
+ return [&P](Val *V) { return P.match(V); };
+}
+
template <typename Pattern> bool match(VPUser *U, const Pattern &P) {
auto *R = dyn_cast<VPRecipeBase>(U);
return R && match(R, P);
}
+template <typename Pattern>
+std::function<bool(VPUser *)> match(const Pattern &P) {
+ return [&P](VPUser *U) { return match(U, P); };
+}
+
template <typename Class> struct class_match {
template <typename ITy> bool match(ITy *V) const { return isa<Class>(V); }
};
@llvm/pr-subscribers-llvm-ir

Author: Luke Lau (lukel97)

Changes

(Same description and full diff as in the comment above.)
@@ -50,6 +50,11 @@ template <typename Val, typename Pattern> bool match(Val *V, const Pattern &P) {
return P.match(V);
}
+template <typename Val = const Value, typename Pattern>
+std::function<bool(Val *)> match(const Pattern &P) {
Needs to be implemented in a way that does not use std::function. You probably need to create a dedicated functor class for this.
I'm also not sure this should be just an overload of match(). Maybe something more explicit like match_fn()?
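A minimal sketch of that suggestion, assuming a dedicated functor class and a match_fn() entry point (MatchFn and match_fn are illustrative names, not existing PatternMatch API):

// Hypothetical functor-based alternative to the std::function overload.
template <typename Pattern> struct MatchFn {
  Pattern P; // Stored by value so the functor remains valid if it outlives the call site.
  template <typename Val> bool operator()(Val *V) const { return P.match(V); }
};

template <typename Pattern> MatchFn<Pattern> match_fn(const Pattern &P) {
  return MatchFn<Pattern>{P};
}

// Usage, mirroring the simplifyFPOp hunk in the diff:
//   if (any_of(Ops, match_fn(m_Poison())))
//     return PoisonValue::get(Ops[0]->getType());

This would keep call sites as terse as the proposed overload while avoiding the type erasure and potential heap allocation of std::function.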