From 632fe58f2b39972659878dd9bf2ac865f75a6880 Mon Sep 17 00:00:00 2001 From: Florian Hahn Date: Fri, 29 Nov 2024 20:30:45 +0000 Subject: [PATCH 1/2] [SCEV] Use Step and Start to check if SCEVWrapPredicate is implied. A SCEVWrapPredicate A implies B, if * they have the same flag, * both steps are positive and * B's start and step are ULE/SLE (for NSUW/NSSW) than A's. See https://alive2.llvm.org/ce/z/n2T4ss (first pair with known constants as strides, second pair with variable strides). Note that this is limited to steps of the same size, due to NSUW having slightly different semantics than regular NUW. We should be able to remove this restriction for NSSW (which matches NSW) in the future. --- llvm/include/llvm/Analysis/ScalarEvolution.h | 13 +-- llvm/lib/Analysis/ScalarEvolution.cpp | 92 ++++++++++++++----- .../memcheck-wrapping-pointers.ll | 12 +-- .../nssw-predicate-implied.ll | 6 +- 4 files changed, 84 insertions(+), 39 deletions(-) diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h index de74524c4b6fe..7879622473ad8 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolution.h +++ b/llvm/include/llvm/Analysis/ScalarEvolution.h @@ -241,7 +241,7 @@ class SCEVPredicate : public FoldingSetNode { virtual bool isAlwaysTrue() const = 0; /// Returns true if this predicate implies \p N. - virtual bool implies(const SCEVPredicate *N) const = 0; + virtual bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const = 0; /// Prints a textual representation of this predicate with an indentation of /// \p Depth. @@ -286,7 +286,7 @@ class SCEVComparePredicate final : public SCEVPredicate { const SCEV *LHS, const SCEV *RHS); /// Implementation of the SCEVPredicate interface - bool implies(const SCEVPredicate *N) const override; + bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override; void print(raw_ostream &OS, unsigned Depth = 0) const override; bool isAlwaysTrue() const override; @@ -393,7 +393,7 @@ class SCEVWrapPredicate final : public SCEVPredicate { /// Implementation of the SCEVPredicate interface const SCEVAddRecExpr *getExpr() const; - bool implies(const SCEVPredicate *N) const override; + bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override; void print(raw_ostream &OS, unsigned Depth = 0) const override; bool isAlwaysTrue() const override; @@ -418,16 +418,17 @@ class SCEVUnionPredicate final : public SCEVPredicate { SmallVector Preds; /// Adds a predicate to this union. - void add(const SCEVPredicate *N); + void add(const SCEVPredicate *N, ScalarEvolution &SE); public: - SCEVUnionPredicate(ArrayRef Preds); + SCEVUnionPredicate(ArrayRef Preds, + ScalarEvolution &SE); ArrayRef getPredicates() const { return Preds; } /// Implementation of the SCEVPredicate interface bool isAlwaysTrue() const override; - bool implies(const SCEVPredicate *N) const override; + bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override; void print(raw_ostream &OS, unsigned Depth) const override; /// We estimate the complexity of a union predicate as the size number of diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index e18133971f5bf..decf55003033c 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -5706,8 +5706,9 @@ bool PredicatedScalarEvolution::areAddRecsEqualWithPreds( return true; auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool { - if (Expr1 != Expr2 && !Preds->implies(SE.getEqualPredicate(Expr1, Expr2)) && - !Preds->implies(SE.getEqualPredicate(Expr2, Expr1))) + if (Expr1 != Expr2 && + !Preds->implies(SE.getEqualPredicate(Expr1, Expr2), SE) && + !Preds->implies(SE.getEqualPredicate(Expr2, Expr1), SE)) return false; return true; }; @@ -14857,7 +14858,7 @@ class SCEVPredicateRewriter : public SCEVRewriteVisitor { bool addOverflowAssumption(const SCEVPredicate *P) { if (!NewPreds) { // Check if we've already made this assumption. - return Pred && Pred->implies(P); + return Pred && Pred->implies(P, SE); } NewPreds->push_back(P); return true; @@ -14938,7 +14939,8 @@ SCEVComparePredicate::SCEVComparePredicate(const FoldingSetNodeIDRef ID, assert(LHS != RHS && "LHS and RHS are the same SCEV"); } -bool SCEVComparePredicate::implies(const SCEVPredicate *N) const { +bool SCEVComparePredicate::implies(const SCEVPredicate *N, + ScalarEvolution &SE) const { const auto *Op = dyn_cast(N); if (!Op) @@ -14968,10 +14970,52 @@ SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID, const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; } -bool SCEVWrapPredicate::implies(const SCEVPredicate *N) const { +bool SCEVWrapPredicate::implies(const SCEVPredicate *N, + ScalarEvolution &SE) const { const auto *Op = dyn_cast(N); + if (!Op) + return false; + + if (setFlags(Flags, Op->Flags) != Flags) + return false; + + if (Op->AR == AR) + return true; + + if (Flags != SCEVWrapPredicate::IncrementNSSW && + Flags != SCEVWrapPredicate::IncrementNUSW) + return false; - return Op && Op->AR == AR && setFlags(Flags, Op->Flags) == Flags; + bool IsNUW = Flags == SCEVWrapPredicate::IncrementNUSW; + const SCEV *Step = AR->getStepRecurrence(SE); + const SCEV *OpStep = Op->AR->getStepRecurrence(SE); + + // If both steps are positive, this implies N, if N's start and step are + // ULE/SLE (for NSUW/NSSW) than this'. + if (SE.isKnownPositive(Step) && SE.isKnownPositive(OpStep)) { + const SCEV *OpStart = Op->AR->getStart(); + const SCEV *Start = AR->getStart(); + if (SE.getTypeSizeInBits(Step->getType()) > + SE.getTypeSizeInBits(OpStep->getType())) { + OpStep = SE.getZeroExtendExpr(OpStep, Step->getType()); + } else { + Step = IsNUW ? SE.getNoopOrZeroExtend(Step, OpStep->getType()) + : SE.getNoopOrSignExtend(Step, OpStep->getType()); + } + if (SE.getTypeSizeInBits(Start->getType()) > + SE.getTypeSizeInBits(OpStart->getType())) { + OpStart = IsNUW ? SE.getZeroExtendExpr(OpStart, Start->getType()) + : SE.getSignExtendExpr(OpStart, Start->getType()); + } else { + Start = IsNUW ? SE.getNoopOrZeroExtend(Start, OpStart->getType()) + : SE.getNoopOrSignExtend(Start, OpStart->getType()); + } + + CmpInst::Predicate Pred = IsNUW ? CmpInst::ICMP_ULE : CmpInst::ICMP_SLE; + return SE.isKnownPredicate(Pred, OpStep, Step) && + SE.isKnownPredicate(Pred, OpStart, Start); + } + return false; } bool SCEVWrapPredicate::isAlwaysTrue() const { @@ -15015,10 +15059,11 @@ SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR, } /// Union predicates don't get cached so create a dummy set ID for it. -SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef Preds) - : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) { +SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef Preds, + ScalarEvolution &SE) + : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) { for (const auto *P : Preds) - add(P); + add(P, SE); } bool SCEVUnionPredicate::isAlwaysTrue() const { @@ -15026,13 +15071,15 @@ bool SCEVUnionPredicate::isAlwaysTrue() const { [](const SCEVPredicate *I) { return I->isAlwaysTrue(); }); } -bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const { +bool SCEVUnionPredicate::implies(const SCEVPredicate *N, + ScalarEvolution &SE) const { if (const auto *Set = dyn_cast(N)) - return all_of(Set->Preds, - [this](const SCEVPredicate *I) { return this->implies(I); }); + return all_of(Set->Preds, [this, &SE](const SCEVPredicate *I) { + return this->implies(I, SE); + }); return any_of(Preds, - [N](const SCEVPredicate *I) { return I->implies(N); }); + [N, &SE](const SCEVPredicate *I) { return I->implies(N, SE); }); } void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const { @@ -15040,15 +15087,15 @@ void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const { Pred->print(OS, Depth); } -void SCEVUnionPredicate::add(const SCEVPredicate *N) { +void SCEVUnionPredicate::add(const SCEVPredicate *N, ScalarEvolution &SE) { if (const auto *Set = dyn_cast(N)) { for (const auto *Pred : Set->Preds) - add(Pred); + add(Pred, SE); return; } // Only add predicate if it is not already implied by this union predicate. - if (!implies(N)) + if (!implies(N, SE)) Preds.push_back(N); } @@ -15056,7 +15103,7 @@ PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE, Loop &L) : SE(SE), L(L) { SmallVector Empty; - Preds = std::make_unique(Empty); + Preds = std::make_unique(Empty, SE); } void ScalarEvolution::registerUser(const SCEV *User, @@ -15120,12 +15167,12 @@ unsigned PredicatedScalarEvolution::getSmallConstantMaxTripCount() { } void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) { - if (Preds->implies(&Pred)) + if (Preds->implies(&Pred, SE)) return; SmallVector NewPreds(Preds->getPredicates()); NewPreds.push_back(&Pred); - Preds = std::make_unique(NewPreds); + Preds = std::make_unique(NewPreds, SE); updateGeneration(); } @@ -15192,9 +15239,10 @@ const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) { PredicatedScalarEvolution::PredicatedScalarEvolution( const PredicatedScalarEvolution &Init) - : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L), - Preds(std::make_unique(Init.Preds->getPredicates())), - Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) { + : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L), + Preds(std::make_unique(Init.Preds->getPredicates(), + SE)), + Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) { for (auto I : Init.FlagsMap) FlagsMap.insert(I); } diff --git a/llvm/test/Analysis/LoopAccessAnalysis/memcheck-wrapping-pointers.ll b/llvm/test/Analysis/LoopAccessAnalysis/memcheck-wrapping-pointers.ll index 6dbb4a0c0129a..ae10ab841420f 100644 --- a/llvm/test/Analysis/LoopAccessAnalysis/memcheck-wrapping-pointers.ll +++ b/llvm/test/Analysis/LoopAccessAnalysis/memcheck-wrapping-pointers.ll @@ -29,20 +29,19 @@ target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128" ; CHECK-NEXT: Run-time memory checks: ; CHECK-NEXT: Check 0: ; CHECK-NEXT: Comparing group -; CHECK-NEXT: %arrayidx = getelementptr inbounds i32, ptr %a, i64 %idxprom -; CHECK-NEXT: Against group ; CHECK-NEXT: %arrayidx4 = getelementptr inbounds i32, ptr %b, i64 %conv11 +; CHECK-NEXT: Against group +; CHECK-NEXT: %arrayidx = getelementptr inbounds i32, ptr %a, i64 %idxprom ; CHECK-NEXT: Grouped accesses: ; CHECK-NEXT: Group -; CHECK-NEXT: (Low: (4 + %a) High: (4 + (4 * (1 umax %x)) + %a)) -; CHECK-NEXT: Member: {(4 + %a),+,4}<%for.body> -; CHECK-NEXT: Group ; CHECK-NEXT: (Low: %b High: ((4 * (1 umax %x)) + %b)) ; CHECK-NEXT: Member: {%b,+,4}<%for.body> +; CHECK-NEXT: Group +; CHECK-NEXT: (Low: (4 + %a) High: (4 + (4 * (1 umax %x)) + %a)) +; CHECK-NEXT: Member: {(4 + %a),+,4}<%for.body> ; CHECK: Non vectorizable stores to invariant address were not found in loop. ; CHECK-NEXT: SCEV assumptions: ; CHECK-NEXT: {1,+,1}<%for.body> Added Flags: -; CHECK-NEXT: {0,+,1}<%for.body> Added Flags: ; CHECK: Expressions re-written: ; CHECK-NEXT: [PSE] %arrayidx = getelementptr inbounds i32, ptr %a, i64 %idxprom: ; CHECK-NEXT: ((4 * (zext i32 {1,+,1}<%for.body> to i64)) + %a) @@ -85,7 +84,6 @@ exit: ; CHECK: Memory dependences are safe ; CHECK: SCEV assumptions: ; CHECK-NEXT: {1,+,1}<%for.body> Added Flags: -; CHECK-NEXT: {0,+,1}<%for.body> Added Flags: define void @test2(i64 %x, ptr %a) { entry: br label %for.body diff --git a/llvm/test/Analysis/LoopAccessAnalysis/nssw-predicate-implied.ll b/llvm/test/Analysis/LoopAccessAnalysis/nssw-predicate-implied.ll index 1a07805c2614f..4f595b44ae5fd 100644 --- a/llvm/test/Analysis/LoopAccessAnalysis/nssw-predicate-implied.ll +++ b/llvm/test/Analysis/LoopAccessAnalysis/nssw-predicate-implied.ll @@ -3,7 +3,7 @@ target datalayout = "e-m:o-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-n32:64-S128-Fn32" -; FIXME: {0,+,3} implies {0,+,2}. +; {0,+,3} [nssw] implies {0,+,2} [nssw] define void @wrap_check_iv.3_implies_iv.2(i32 noundef %N, ptr %dst, ptr %src) { ; CHECK-LABEL: 'wrap_check_iv.3_implies_iv.2' ; CHECK-NEXT: loop: @@ -26,7 +26,6 @@ define void @wrap_check_iv.3_implies_iv.2(i32 noundef %N, ptr %dst, ptr %src) { ; CHECK-NEXT: Non vectorizable stores to invariant address were not found in loop. ; CHECK-NEXT: SCEV assumptions: ; CHECK-NEXT: {0,+,3}<%loop> Added Flags: -; CHECK-NEXT: {0,+,2}<%loop> Added Flags: ; CHECK-EMPTY: ; CHECK-NEXT: Expressions re-written: ; CHECK-NEXT: [PSE] %gep.iv.2 = getelementptr inbounds i32, ptr %src, i64 %ext.iv.2: @@ -59,7 +58,7 @@ exit: ret void } -; FIXME: {2,+,2} implies {0,+,2}. +; {2,+,2} [nssw] implies {0,+,2} [nssw]. define void @wrap_check_iv.3_implies_iv.2_different_start(i32 noundef %N, ptr %dst, ptr %src) { ; CHECK-LABEL: 'wrap_check_iv.3_implies_iv.2_different_start' ; CHECK-NEXT: loop: @@ -82,7 +81,6 @@ define void @wrap_check_iv.3_implies_iv.2_different_start(i32 noundef %N, ptr %d ; CHECK-NEXT: Non vectorizable stores to invariant address were not found in loop. ; CHECK-NEXT: SCEV assumptions: ; CHECK-NEXT: {2,+,2}<%loop> Added Flags: -; CHECK-NEXT: {0,+,2}<%loop> Added Flags: ; CHECK-EMPTY: ; CHECK-NEXT: Expressions re-written: ; CHECK-NEXT: [PSE] %gep.iv.2 = getelementptr inbounds i32, ptr %src, i64 %ext.iv.2: From 4a10b2b932077fa519b2f7d38d6f0a9258bd7915 Mon Sep 17 00:00:00 2001 From: Florian Hahn Date: Mon, 16 Dec 2024 14:00:28 +0000 Subject: [PATCH 2/2] !fixup adress comments, reorder code --- llvm/lib/Analysis/ScalarEvolution.cpp | 44 ++++++++++----------------- 1 file changed, 16 insertions(+), 28 deletions(-) diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index decf55003033c..e2c2500052e7d 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -14973,10 +14973,7 @@ const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; } bool SCEVWrapPredicate::implies(const SCEVPredicate *N, ScalarEvolution &SE) const { const auto *Op = dyn_cast(N); - if (!Op) - return false; - - if (setFlags(Flags, Op->Flags) != Flags) + if (!Op || setFlags(Flags, Op->Flags) != Flags) return false; if (Op->AR == AR) @@ -14986,36 +14983,27 @@ bool SCEVWrapPredicate::implies(const SCEVPredicate *N, Flags != SCEVWrapPredicate::IncrementNUSW) return false; - bool IsNUW = Flags == SCEVWrapPredicate::IncrementNUSW; const SCEV *Step = AR->getStepRecurrence(SE); const SCEV *OpStep = Op->AR->getStepRecurrence(SE); + if (!SE.isKnownPositive(Step) || !SE.isKnownPositive(OpStep)) + return false; // If both steps are positive, this implies N, if N's start and step are // ULE/SLE (for NSUW/NSSW) than this'. - if (SE.isKnownPositive(Step) && SE.isKnownPositive(OpStep)) { - const SCEV *OpStart = Op->AR->getStart(); - const SCEV *Start = AR->getStart(); - if (SE.getTypeSizeInBits(Step->getType()) > - SE.getTypeSizeInBits(OpStep->getType())) { - OpStep = SE.getZeroExtendExpr(OpStep, Step->getType()); - } else { - Step = IsNUW ? SE.getNoopOrZeroExtend(Step, OpStep->getType()) - : SE.getNoopOrSignExtend(Step, OpStep->getType()); - } - if (SE.getTypeSizeInBits(Start->getType()) > - SE.getTypeSizeInBits(OpStart->getType())) { - OpStart = IsNUW ? SE.getZeroExtendExpr(OpStart, Start->getType()) - : SE.getSignExtendExpr(OpStart, Start->getType()); - } else { - Start = IsNUW ? SE.getNoopOrZeroExtend(Start, OpStart->getType()) - : SE.getNoopOrSignExtend(Start, OpStart->getType()); - } + Type *WiderTy = SE.getWiderType(Step->getType(), OpStep->getType()); + Step = SE.getNoopOrZeroExtend(Step, WiderTy); + OpStep = SE.getNoopOrZeroExtend(OpStep, WiderTy); - CmpInst::Predicate Pred = IsNUW ? CmpInst::ICMP_ULE : CmpInst::ICMP_SLE; - return SE.isKnownPredicate(Pred, OpStep, Step) && - SE.isKnownPredicate(Pred, OpStart, Start); - } - return false; + bool IsNUW = Flags == SCEVWrapPredicate::IncrementNUSW; + const SCEV *OpStart = Op->AR->getStart(); + const SCEV *Start = AR->getStart(); + OpStart = IsNUW ? SE.getNoopOrZeroExtend(OpStart, WiderTy) + : SE.getNoopOrSignExtend(OpStart, WiderTy); + Start = IsNUW ? SE.getNoopOrZeroExtend(Start, WiderTy) + : SE.getNoopOrSignExtend(Start, WiderTy); + CmpInst::Predicate Pred = IsNUW ? CmpInst::ICMP_ULE : CmpInst::ICMP_SLE; + return SE.isKnownPredicate(Pred, OpStep, Step) && + SE.isKnownPredicate(Pred, OpStart, Start); } bool SCEVWrapPredicate::isAlwaysTrue() const {