From a9d36ec97c9d92e030a7e66afc32af2fda6dcaa8 Mon Sep 17 00:00:00 2001 From: Luke Lau Date: Mon, 30 Jun 2025 15:35:26 +0100 Subject: [PATCH] [PatternMatch][VPlan] Add std::function match overload. NFCI A relatively common use case for PatternMatch is to use match inside all_of/any_of/none_of. This patch adds an overload for match that returns a lambda so callers don't need to create a lambda themselves for both the LLVM and VPlan pattern matchers. --- llvm/include/llvm/IR/PatternMatch.h | 5 +++ llvm/lib/Analysis/InstructionSimplify.cpp | 16 ++++----- llvm/lib/Analysis/ValueTracking.cpp | 5 ++- llvm/lib/CodeGen/InterleavedAccessPass.cpp | 7 ++-- .../InstCombine/InstCombineAddSub.cpp | 8 ++--- .../InstCombine/InstCombineCalls.cpp | 4 +-- .../InstCombine/InstCombineCompares.cpp | 2 +- .../Transforms/InstCombine/InstCombinePHI.cpp | 4 +-- .../InstCombine/InstCombineSelect.cpp | 2 +- llvm/lib/Transforms/Scalar/LICM.cpp | 7 ++-- .../Transforms/Vectorize/LoopVectorize.cpp | 35 +++++++++---------- .../Transforms/Vectorize/SLPVectorizer.cpp | 7 ++-- .../Transforms/Vectorize/VPlanPatternMatch.h | 10 ++++++ 13 files changed, 55 insertions(+), 57 deletions(-) diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h index 1f86cdfd94e17..e5013c31f1f40 100644 --- a/llvm/include/llvm/IR/PatternMatch.h +++ b/llvm/include/llvm/IR/PatternMatch.h @@ -50,6 +50,11 @@ template bool match(Val *V, const Pattern &P) { return P.match(V); } +template +std::function match(const Pattern &P) { + return [&P](Val *V) { return P.match(V); }; +} + template bool match(ArrayRef Mask, const Pattern &P) { return P.match(Mask); } diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp index cb1dae92faf92..fba9dd80f02c7 100644 --- a/llvm/lib/Analysis/InstructionSimplify.cpp +++ b/llvm/lib/Analysis/InstructionSimplify.cpp @@ -5028,14 +5028,12 @@ static Value *simplifyGEPInst(Type *SrcTy, Value *Ptr, } // All-zero GEP is a no-op, unless it performs a vector splat. - if (Ptr->getType() == GEPTy && - all_of(Indices, [](const auto *V) { return match(V, m_Zero()); })) + if (Ptr->getType() == GEPTy && all_of(Indices, match(m_Zero()))) return Ptr; // getelementptr poison, idx -> poison // getelementptr baseptr, poison -> poison - if (isa(Ptr) || - any_of(Indices, [](const auto *V) { return isa(V); })) + if (isa(Ptr) || any_of(Indices, match(m_Poison()))) return PoisonValue::get(GEPTy); // getelementptr undef, idx -> undef @@ -5092,8 +5090,7 @@ static Value *simplifyGEPInst(Type *SrcTy, Value *Ptr, } if (!IsScalableVec && Q.DL.getTypeAllocSize(LastType) == 1 && - all_of(Indices.drop_back(1), - [](Value *Idx) { return match(Idx, m_Zero()); })) { + all_of(Indices.drop_back(1), match(m_Zero()))) { unsigned IdxWidth = Q.DL.getIndexSizeInBits(Ptr->getType()->getPointerAddressSpace()); if (Q.DL.getTypeSizeInBits(Indices.back()->getType()) == IdxWidth) { @@ -5123,8 +5120,7 @@ static Value *simplifyGEPInst(Type *SrcTy, Value *Ptr, } // Check to see if this is constant foldable. - if (!isa(Ptr) || - !all_of(Indices, [](Value *V) { return isa(V); })) + if (!isa(Ptr) || !all_of(Indices, match(m_Constant()))) return nullptr; if (!ConstantExpr::isSupportedGetElementPtr(SrcTy)) @@ -5649,7 +5645,7 @@ static Constant *simplifyFPOp(ArrayRef Ops, FastMathFlags FMF, RoundingMode Rounding) { // Poison is independent of anything else. It always propagates from an // operand to a math result. - if (any_of(Ops, [](Value *V) { return match(V, m_Poison()); })) + if (any_of(Ops, match(m_Poison()))) return PoisonValue::get(Ops[0]->getType()); for (Value *V : Ops) { @@ -7116,7 +7112,7 @@ static Value *simplifyInstructionWithOperands(Instruction *I, switch (I->getOpcode()) { default: - if (llvm::all_of(NewOps, [](Value *V) { return isa(V); })) { + if (all_of(NewOps, match(m_Constant()))) { SmallVector NewConstOps(NewOps.size()); transform(NewOps, NewConstOps.begin(), [](Value *V) { return cast(V); }); diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp index e576f4899810a..6cc50bf7e3ee1 100644 --- a/llvm/lib/Analysis/ValueTracking.cpp +++ b/llvm/lib/Analysis/ValueTracking.cpp @@ -251,9 +251,8 @@ bool llvm::haveNoCommonBitsSet(const WithCache &LHSCache, } bool llvm::isOnlyUsedInZeroComparison(const Instruction *I) { - return !I->user_empty() && all_of(I->users(), [](const User *U) { - return match(U, m_ICmp(m_Value(), m_Zero())); - }); + return !I->user_empty() && + all_of(I->users(), match(m_ICmp(m_Value(), m_Zero()))); } bool llvm::isOnlyUsedInZeroEqualityComparison(const Instruction *I) { diff --git a/llvm/lib/CodeGen/InterleavedAccessPass.cpp b/llvm/lib/CodeGen/InterleavedAccessPass.cpp index 9c4c86cebe7e5..2d2d48b004e77 100644 --- a/llvm/lib/CodeGen/InterleavedAccessPass.cpp +++ b/llvm/lib/CodeGen/InterleavedAccessPass.cpp @@ -294,10 +294,9 @@ bool InterleavedAccessImpl::lowerInterleavedLoad( continue; } if (auto *BI = dyn_cast(User)) { - if (!BI->user_empty() && all_of(BI->users(), [](auto *U) { - auto *SVI = dyn_cast(U); - return SVI && isa(SVI->getOperand(1)); - })) { + using namespace PatternMatch; + if (!BI->user_empty() && + all_of(BI->users(), match(m_Shuffle(m_Value(), m_Undef())))) { for (auto *SVI : BI->users()) BinOpShuffles.insert(cast(SVI)); continue; diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp index e721f0cd5f9e3..cb6dc4b5b0fc5 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -2307,12 +2307,8 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) { // and let's try to sink `(sub 0, b)` into `b` itself. But only if this isn't // a pure negation used by a select that looks like abs/nabs. bool IsNegation = match(Op0, m_ZeroInt()); - if (!IsNegation || none_of(I.users(), [&I, Op1](const User *U) { - const Instruction *UI = dyn_cast(U); - if (!UI) - return false; - return match(UI, m_c_Select(m_Specific(Op1), m_Specific(&I))); - })) { + if (!IsNegation || + none_of(I.users(), match(m_c_Select(m_Specific(Op1), m_Specific(&I))))) { if (Value *NegOp1 = Negator::Negate(IsNegation, /* IsNSW */ IsNegation && I.hasNoSignedWrap(), Op1, *this)) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp index e33d111167c04..cc6ad9bf44cd2 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -1418,9 +1418,7 @@ InstCombinerImpl::foldShuffledIntrinsicOperands(IntrinsicInst *II) { // At least 1 operand must be a shuffle with 1 use because we are creating 2 // instructions. - if (none_of(II->args(), [](Value *V) { - return isa(V) && V->hasOneUse(); - })) + if (none_of(II->args(), match(m_OneUse(m_Shuffle(m_Value(), m_Value()))))) return nullptr; // See if all arguments are shuffled with the same mask. diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index 0894ca92086f3..27b239417de04 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -1341,7 +1341,7 @@ Instruction *InstCombinerImpl::foldICmpWithConstant(ICmpInst &Cmp) { return nullptr; if (auto *Phi = dyn_cast(Op0)) - if (all_of(Phi->operands(), [](Value *V) { return isa(V); })) { + if (all_of(Phi->operands(), match(m_Constant()))) { SmallVector Ops; for (Value *V : Phi->incoming_values()) { Constant *Res = diff --git a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp index 6477141ab095f..d992e2f57a0c7 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp @@ -339,7 +339,7 @@ bool InstCombinerImpl::foldIntegerTypedPHI(PHINode &PN) { Instruction *InstCombinerImpl::foldPHIArgIntToPtrToPHI(PHINode &PN) { // convert ptr2int ( phi[ int2ptr(ptr2int(x))] ) --> ptr2int ( phi [ x ] ) // Make sure all uses of phi are ptr2int. - if (!all_of(PN.users(), [](User *U) { return isa(U); })) + if (!all_of(PN.users(), match(m_PtrToInt(m_Value())))) return nullptr; // Iterating over all operands to check presence of target pointers for @@ -1298,7 +1298,7 @@ static Value *simplifyUsingControlFlow(InstCombiner &Self, PHINode &PN, // \ / // phi [v1] [v2] // Make sure all inputs are constants. - if (!all_of(PN.operands(), [](Value *V) { return isa(V); })) + if (!all_of(PN.operands(), match(m_ConstantInt()))) return nullptr; BasicBlock *BB = PN.getParent(); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index 73ba0f78e8053..c43a8cb53e4e9 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -3142,7 +3142,7 @@ static Instruction *foldNestedSelects(SelectInst &OuterSelVal, // Profitability check - avoid increasing instruction count. if (none_of(ArrayRef({OuterSelVal.getCondition(), InnerSelVal}), - [](Value *V) { return V->hasOneUse(); })) + match(m_OneUse(m_Value())))) return nullptr; // The appropriate hand of the outermost `select` must be a select itself. diff --git a/llvm/lib/Transforms/Scalar/LICM.cpp b/llvm/lib/Transforms/Scalar/LICM.cpp index cf84366c4200b..c2edc3a33edc1 100644 --- a/llvm/lib/Transforms/Scalar/LICM.cpp +++ b/llvm/lib/Transforms/Scalar/LICM.cpp @@ -435,10 +435,9 @@ bool LoopInvariantCodeMotion::runOnLoop(Loop *L, AAResults *AA, LoopInfo *LI, // potentially happen in other passes where instructions are being moved // across that edge. bool HasCoroSuspendInst = llvm::any_of(L->getBlocks(), [](BasicBlock *BB) { - return llvm::any_of(*BB, [](Instruction &I) { - IntrinsicInst *II = dyn_cast(&I); - return II && II->getIntrinsicID() == Intrinsic::coro_suspend; - }); + using namespace PatternMatch; + return any_of(make_pointer_range(*BB), + match(m_Intrinsic())); }); MemorySSAUpdater MSSAU(MSSA); diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp index 95479373b4393..3fe9c46ac7656 100644 --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -7036,11 +7036,12 @@ static bool planContainsAdditionalSimplifications(VPlan &Plan, // Unused FOR splices are removed by VPlan transforms, so the VPlan-based // cost model won't cost it whilst the legacy will. if (auto *FOR = dyn_cast(&R)) { - if (none_of(FOR->users(), [](VPUser *U) { - auto *VPI = dyn_cast(U); - return VPI && VPI->getOpcode() == - VPInstruction::FirstOrderRecurrenceSplice; - })) + using namespace VPlanPatternMatch; + if (none_of( + FOR->users(), + match( + m_VPInstruction( + m_VPValue(), m_VPValue())))) return true; } // The VPlan-based cost model is more accurate for partial reduction and @@ -7449,13 +7450,11 @@ DenseMap LoopVectorizationPlanner::executePlan( Hints.setAlreadyVectorized(); // Check if it's EVL-vectorized and mark the corresponding metadata. + using namespace VPlanPatternMatch; bool IsEVLVectorized = - llvm::any_of(*HeaderVPBB, [](const VPRecipeBase &Recipe) { - // Looking for the ExplictVectorLength VPInstruction. - if (const auto *VI = dyn_cast(&Recipe)) - return VI->getOpcode() == VPInstruction::ExplicitVectorLength; - return false; - }); + any_of(make_pointer_range(*HeaderVPBB), + match(m_VPInstruction( + m_VPValue()))); if (IsEVLVectorized) { LLVMContext &Context = L->getHeader()->getContext(); MDNode *LoopID = L->getLoopID(); @@ -9737,10 +9736,9 @@ static void preparePlanForMainVectorLoop(VPlan &MainPlan, VPlan &EpiPlan) { // If there is a suitable resume value for the canonical induction in the // scalar (which will become vector) epilogue loop we are done. Otherwise // create it below. - if (any_of(*MainScalarPH, [VectorTC](VPRecipeBase &R) { - return match(&R, m_VPInstruction(m_Specific(VectorTC), - m_SpecificInt(0))); - })) + if (any_of(make_pointer_range(*MainScalarPH), + match(m_VPInstruction(m_Specific(VectorTC), + m_SpecificInt(0))))) return; VPBuilder ScalarPHBuilder(MainScalarPH, MainScalarPH->begin()); ScalarPHBuilder.createScalarPhi( @@ -9778,10 +9776,9 @@ preparePlanForEpilogueVectorLoop(VPlan &Plan, Loop *L, match( P.getIncomingValueForBlock(EPI.MainLoopIterationCountCheck), m_SpecificInt(0)) && - all_of(P.incoming_values(), [&EPI](Value *Inc) { - return Inc == EPI.VectorTripCount || - match(Inc, m_SpecificInt(0)); - })) + all_of(P.incoming_values(), + match(m_CombineOr(m_Specific(EPI.VectorTripCount), + m_SpecificInt(0))))) return &P; return nullptr; }); diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index 0941bf61953f1..79bf939cd591e 100644 --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -20708,10 +20708,9 @@ void BoUpSLP::computeMinimumValueSizes() { IsTruncRoot = true; } bool IsSignedCmp = false; - if (UserIgnoreList && all_of(*UserIgnoreList, [](Value *V) { - return match(V, m_SMin(m_Value(), m_Value())) || - match(V, m_SMax(m_Value(), m_Value())); - })) + if (UserIgnoreList && + all_of(*UserIgnoreList, match(m_CombineOr(m_SMin(m_Value(), m_Value()), + m_SMax(m_Value(), m_Value()))))) IsSignedCmp = true; while (NodeIdx < VectorizableTree.size()) { ArrayRef TreeRoot = VectorizableTree[NodeIdx]->Scalars; diff --git a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h index efea99f22d086..4aba5fb010559 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h +++ b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h @@ -29,11 +29,21 @@ template bool match(Val *V, const Pattern &P) { return P.match(V); } +template +std::function match(const Pattern &P) { + return [&P](Val *V) { return P.match(V); }; +} + template bool match(VPUser *U, const Pattern &P) { auto *R = dyn_cast(U); return R && match(R, P); } +template +std::function match(const Pattern &P) { + return [&P](VPUser *U) { return match(U, P); }; +} + template struct class_match { template bool match(ITy *V) const { return isa(V); } };