diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h index de74524c4b6fe..7879622473ad8 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolution.h +++ b/llvm/include/llvm/Analysis/ScalarEvolution.h @@ -241,7 +241,7 @@ class SCEVPredicate : public FoldingSetNode { virtual bool isAlwaysTrue() const = 0; /// Returns true if this predicate implies \p N. - virtual bool implies(const SCEVPredicate *N) const = 0; + virtual bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const = 0; /// Prints a textual representation of this predicate with an indentation of /// \p Depth. @@ -286,7 +286,7 @@ class SCEVComparePredicate final : public SCEVPredicate { const SCEV *LHS, const SCEV *RHS); /// Implementation of the SCEVPredicate interface - bool implies(const SCEVPredicate *N) const override; + bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override; void print(raw_ostream &OS, unsigned Depth = 0) const override; bool isAlwaysTrue() const override; @@ -393,7 +393,7 @@ class SCEVWrapPredicate final : public SCEVPredicate { /// Implementation of the SCEVPredicate interface const SCEVAddRecExpr *getExpr() const; - bool implies(const SCEVPredicate *N) const override; + bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override; void print(raw_ostream &OS, unsigned Depth = 0) const override; bool isAlwaysTrue() const override; @@ -418,16 +418,17 @@ class SCEVUnionPredicate final : public SCEVPredicate { SmallVector Preds; /// Adds a predicate to this union. - void add(const SCEVPredicate *N); + void add(const SCEVPredicate *N, ScalarEvolution &SE); public: - SCEVUnionPredicate(ArrayRef Preds); + SCEVUnionPredicate(ArrayRef Preds, + ScalarEvolution &SE); ArrayRef getPredicates() const { return Preds; } /// Implementation of the SCEVPredicate interface bool isAlwaysTrue() const override; - bool implies(const SCEVPredicate *N) const override; + bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override; void print(raw_ostream &OS, unsigned Depth) const override; /// We estimate the complexity of a union predicate as the size number of diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index e18133971f5bf..e2c2500052e7d 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -5706,8 +5706,9 @@ bool PredicatedScalarEvolution::areAddRecsEqualWithPreds( return true; auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool { - if (Expr1 != Expr2 && !Preds->implies(SE.getEqualPredicate(Expr1, Expr2)) && - !Preds->implies(SE.getEqualPredicate(Expr2, Expr1))) + if (Expr1 != Expr2 && + !Preds->implies(SE.getEqualPredicate(Expr1, Expr2), SE) && + !Preds->implies(SE.getEqualPredicate(Expr2, Expr1), SE)) return false; return true; }; @@ -14857,7 +14858,7 @@ class SCEVPredicateRewriter : public SCEVRewriteVisitor { bool addOverflowAssumption(const SCEVPredicate *P) { if (!NewPreds) { // Check if we've already made this assumption. - return Pred && Pred->implies(P); + return Pred && Pred->implies(P, SE); } NewPreds->push_back(P); return true; @@ -14938,7 +14939,8 @@ SCEVComparePredicate::SCEVComparePredicate(const FoldingSetNodeIDRef ID, assert(LHS != RHS && "LHS and RHS are the same SCEV"); } -bool SCEVComparePredicate::implies(const SCEVPredicate *N) const { +bool SCEVComparePredicate::implies(const SCEVPredicate *N, + ScalarEvolution &SE) const { const auto *Op = dyn_cast(N); if (!Op) @@ -14968,10 +14970,40 @@ SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID, const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; } -bool SCEVWrapPredicate::implies(const SCEVPredicate *N) const { +bool SCEVWrapPredicate::implies(const SCEVPredicate *N, + ScalarEvolution &SE) const { const auto *Op = dyn_cast(N); + if (!Op || setFlags(Flags, Op->Flags) != Flags) + return false; + + if (Op->AR == AR) + return true; + + if (Flags != SCEVWrapPredicate::IncrementNSSW && + Flags != SCEVWrapPredicate::IncrementNUSW) + return false; - return Op && Op->AR == AR && setFlags(Flags, Op->Flags) == Flags; + const SCEV *Step = AR->getStepRecurrence(SE); + const SCEV *OpStep = Op->AR->getStepRecurrence(SE); + if (!SE.isKnownPositive(Step) || !SE.isKnownPositive(OpStep)) + return false; + + // If both steps are positive, this implies N, if N's start and step are + // ULE/SLE (for NSUW/NSSW) than this'. + Type *WiderTy = SE.getWiderType(Step->getType(), OpStep->getType()); + Step = SE.getNoopOrZeroExtend(Step, WiderTy); + OpStep = SE.getNoopOrZeroExtend(OpStep, WiderTy); + + bool IsNUW = Flags == SCEVWrapPredicate::IncrementNUSW; + const SCEV *OpStart = Op->AR->getStart(); + const SCEV *Start = AR->getStart(); + OpStart = IsNUW ? SE.getNoopOrZeroExtend(OpStart, WiderTy) + : SE.getNoopOrSignExtend(OpStart, WiderTy); + Start = IsNUW ? SE.getNoopOrZeroExtend(Start, WiderTy) + : SE.getNoopOrSignExtend(Start, WiderTy); + CmpInst::Predicate Pred = IsNUW ? CmpInst::ICMP_ULE : CmpInst::ICMP_SLE; + return SE.isKnownPredicate(Pred, OpStep, Step) && + SE.isKnownPredicate(Pred, OpStart, Start); } bool SCEVWrapPredicate::isAlwaysTrue() const { @@ -15015,10 +15047,11 @@ SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR, } /// Union predicates don't get cached so create a dummy set ID for it. -SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef Preds) - : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) { +SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef Preds, + ScalarEvolution &SE) + : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) { for (const auto *P : Preds) - add(P); + add(P, SE); } bool SCEVUnionPredicate::isAlwaysTrue() const { @@ -15026,13 +15059,15 @@ bool SCEVUnionPredicate::isAlwaysTrue() const { [](const SCEVPredicate *I) { return I->isAlwaysTrue(); }); } -bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const { +bool SCEVUnionPredicate::implies(const SCEVPredicate *N, + ScalarEvolution &SE) const { if (const auto *Set = dyn_cast(N)) - return all_of(Set->Preds, - [this](const SCEVPredicate *I) { return this->implies(I); }); + return all_of(Set->Preds, [this, &SE](const SCEVPredicate *I) { + return this->implies(I, SE); + }); return any_of(Preds, - [N](const SCEVPredicate *I) { return I->implies(N); }); + [N, &SE](const SCEVPredicate *I) { return I->implies(N, SE); }); } void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const { @@ -15040,15 +15075,15 @@ void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const { Pred->print(OS, Depth); } -void SCEVUnionPredicate::add(const SCEVPredicate *N) { +void SCEVUnionPredicate::add(const SCEVPredicate *N, ScalarEvolution &SE) { if (const auto *Set = dyn_cast(N)) { for (const auto *Pred : Set->Preds) - add(Pred); + add(Pred, SE); return; } // Only add predicate if it is not already implied by this union predicate. - if (!implies(N)) + if (!implies(N, SE)) Preds.push_back(N); } @@ -15056,7 +15091,7 @@ PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE, Loop &L) : SE(SE), L(L) { SmallVector Empty; - Preds = std::make_unique(Empty); + Preds = std::make_unique(Empty, SE); } void ScalarEvolution::registerUser(const SCEV *User, @@ -15120,12 +15155,12 @@ unsigned PredicatedScalarEvolution::getSmallConstantMaxTripCount() { } void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) { - if (Preds->implies(&Pred)) + if (Preds->implies(&Pred, SE)) return; SmallVector NewPreds(Preds->getPredicates()); NewPreds.push_back(&Pred); - Preds = std::make_unique(NewPreds); + Preds = std::make_unique(NewPreds, SE); updateGeneration(); } @@ -15192,9 +15227,10 @@ const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) { PredicatedScalarEvolution::PredicatedScalarEvolution( const PredicatedScalarEvolution &Init) - : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L), - Preds(std::make_unique(Init.Preds->getPredicates())), - Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) { + : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L), + Preds(std::make_unique(Init.Preds->getPredicates(), + SE)), + Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) { for (auto I : Init.FlagsMap) FlagsMap.insert(I); } diff --git a/llvm/test/Analysis/LoopAccessAnalysis/memcheck-wrapping-pointers.ll b/llvm/test/Analysis/LoopAccessAnalysis/memcheck-wrapping-pointers.ll index 6dbb4a0c0129a..ae10ab841420f 100644 --- a/llvm/test/Analysis/LoopAccessAnalysis/memcheck-wrapping-pointers.ll +++ b/llvm/test/Analysis/LoopAccessAnalysis/memcheck-wrapping-pointers.ll @@ -29,20 +29,19 @@ target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128" ; CHECK-NEXT: Run-time memory checks: ; CHECK-NEXT: Check 0: ; CHECK-NEXT: Comparing group -; CHECK-NEXT: %arrayidx = getelementptr inbounds i32, ptr %a, i64 %idxprom -; CHECK-NEXT: Against group ; CHECK-NEXT: %arrayidx4 = getelementptr inbounds i32, ptr %b, i64 %conv11 +; CHECK-NEXT: Against group +; CHECK-NEXT: %arrayidx = getelementptr inbounds i32, ptr %a, i64 %idxprom ; CHECK-NEXT: Grouped accesses: ; CHECK-NEXT: Group -; CHECK-NEXT: (Low: (4 + %a) High: (4 + (4 * (1 umax %x)) + %a)) -; CHECK-NEXT: Member: {(4 + %a),+,4}<%for.body> -; CHECK-NEXT: Group ; CHECK-NEXT: (Low: %b High: ((4 * (1 umax %x)) + %b)) ; CHECK-NEXT: Member: {%b,+,4}<%for.body> +; CHECK-NEXT: Group +; CHECK-NEXT: (Low: (4 + %a) High: (4 + (4 * (1 umax %x)) + %a)) +; CHECK-NEXT: Member: {(4 + %a),+,4}<%for.body> ; CHECK: Non vectorizable stores to invariant address were not found in loop. ; CHECK-NEXT: SCEV assumptions: ; CHECK-NEXT: {1,+,1}<%for.body> Added Flags: -; CHECK-NEXT: {0,+,1}<%for.body> Added Flags: ; CHECK: Expressions re-written: ; CHECK-NEXT: [PSE] %arrayidx = getelementptr inbounds i32, ptr %a, i64 %idxprom: ; CHECK-NEXT: ((4 * (zext i32 {1,+,1}<%for.body> to i64)) + %a) @@ -85,7 +84,6 @@ exit: ; CHECK: Memory dependences are safe ; CHECK: SCEV assumptions: ; CHECK-NEXT: {1,+,1}<%for.body> Added Flags: -; CHECK-NEXT: {0,+,1}<%for.body> Added Flags: define void @test2(i64 %x, ptr %a) { entry: br label %for.body diff --git a/llvm/test/Analysis/LoopAccessAnalysis/nssw-predicate-implied.ll b/llvm/test/Analysis/LoopAccessAnalysis/nssw-predicate-implied.ll index 1a07805c2614f..4f595b44ae5fd 100644 --- a/llvm/test/Analysis/LoopAccessAnalysis/nssw-predicate-implied.ll +++ b/llvm/test/Analysis/LoopAccessAnalysis/nssw-predicate-implied.ll @@ -3,7 +3,7 @@ target datalayout = "e-m:o-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-n32:64-S128-Fn32" -; FIXME: {0,+,3} implies {0,+,2}. +; {0,+,3} [nssw] implies {0,+,2} [nssw] define void @wrap_check_iv.3_implies_iv.2(i32 noundef %N, ptr %dst, ptr %src) { ; CHECK-LABEL: 'wrap_check_iv.3_implies_iv.2' ; CHECK-NEXT: loop: @@ -26,7 +26,6 @@ define void @wrap_check_iv.3_implies_iv.2(i32 noundef %N, ptr %dst, ptr %src) { ; CHECK-NEXT: Non vectorizable stores to invariant address were not found in loop. ; CHECK-NEXT: SCEV assumptions: ; CHECK-NEXT: {0,+,3}<%loop> Added Flags: -; CHECK-NEXT: {0,+,2}<%loop> Added Flags: ; CHECK-EMPTY: ; CHECK-NEXT: Expressions re-written: ; CHECK-NEXT: [PSE] %gep.iv.2 = getelementptr inbounds i32, ptr %src, i64 %ext.iv.2: @@ -59,7 +58,7 @@ exit: ret void } -; FIXME: {2,+,2} implies {0,+,2}. +; {2,+,2} [nssw] implies {0,+,2} [nssw]. define void @wrap_check_iv.3_implies_iv.2_different_start(i32 noundef %N, ptr %dst, ptr %src) { ; CHECK-LABEL: 'wrap_check_iv.3_implies_iv.2_different_start' ; CHECK-NEXT: loop: @@ -82,7 +81,6 @@ define void @wrap_check_iv.3_implies_iv.2_different_start(i32 noundef %N, ptr %d ; CHECK-NEXT: Non vectorizable stores to invariant address were not found in loop. ; CHECK-NEXT: SCEV assumptions: ; CHECK-NEXT: {2,+,2}<%loop> Added Flags: -; CHECK-NEXT: {0,+,2}<%loop> Added Flags: ; CHECK-EMPTY: ; CHECK-NEXT: Expressions re-written: ; CHECK-NEXT: [PSE] %gep.iv.2 = getelementptr inbounds i32, ptr %src, i64 %ext.iv.2: