From 2a1c4712110cb6617517bbad39c6cb12e8c4a3d3 Mon Sep 17 00:00:00 2001 From: Florian Hahn Date: Mon, 4 Mar 2024 10:59:21 +0000 Subject: [PATCH 1/3] [VPlan] Replace disjoint or with add instead of dropping disjoint. Dropping disjoint from an OR may yield incorrect results, as some analysis may have converted it to an Add implicitly (e.g. SCEV used for dependence analysis). Instead, replace it with an equivalent Add. This is possible as all users of the disjoint OR only access lanes where the operands are disjoint or poison otherwise. Note that replacing all disjoint ORs with ADDs instead of dropping the flags is not strictly necessary. It is only needed for disjoint ORs that SCEV treated as ADDs, but those are not tracked. --- .../Vectorize/LoopVectorizationPlanner.h | 3 +++ llvm/lib/Transforms/Vectorize/VPlan.h | 8 ++++++ .../Transforms/Vectorize/VPlanPatternMatch.h | 27 ++++++++++++++++--- .../Transforms/Vectorize/VPlanTransforms.cpp | 18 +++++++++++++ 4 files changed, 52 insertions(+), 4 deletions(-) diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h index a7ebf78e54ceb..b94859864fff3 100644 --- a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h +++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h @@ -68,6 +68,9 @@ class VPBuilder { public: VPBuilder() = default; VPBuilder(VPBasicBlock *InsertBB) { setInsertPoint(InsertBB); } + VPBuilder(VPRecipeBase *InsertPt) { + setInsertPoint(InsertPt->getParent(), InsertPt->getIterator()); + } /// Clear the insertion point: created instructions will not be inserted into /// a block. 
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h index 16c09a83e777d..b565b4351e16d 100644 --- a/llvm/lib/Transforms/Vectorize/VPlan.h +++ b/llvm/lib/Transforms/Vectorize/VPlan.h @@ -1127,6 +1127,12 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe { return WrapFlags.HasNSW; } + bool isDisjoint() const { + assert(OpType == OperationType::DisjointOp && + "recipe cannot have a disjoint flag"); + return DisjointFlags.IsDisjoint; + } + #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) void printFlags(raw_ostream &O) const; #endif @@ -2136,6 +2142,8 @@ class VPReplicateRecipe : public VPRecipeWithIRFlags { assert(isPredicated() && "Trying to get the mask of a unpredicated recipe"); return getOperand(getNumOperands() - 1); } + + unsigned getOpcode() const { return getUnderlyingInstr()->getOpcode(); } }; /// A recipe for generating conditional branches on the bits of a mask. diff --git a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h index b90c588b60756..4b5b6b8cc3dbc 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h +++ b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h @@ -73,12 +73,12 @@ template struct UnaryVPInstruction_match { } }; -template -struct BinaryVPInstruction_match { +template +struct BinaryRecipe_match { Op0_t Op0; Op1_t Op1; - BinaryVPInstruction_match(Op0_t Op0, Op1_t Op1) : Op0(Op0), Op1(Op1) {} + BinaryRecipe_match(Op0_t Op0, Op1_t Op1) : Op0(Op0), Op1(Op1) {} bool match(const VPValue *V) { auto *DefR = V->getDefiningRecipe(); @@ -86,15 +86,27 @@ struct BinaryVPInstruction_match { } bool match(const VPRecipeBase *R) { - auto *DefR = dyn_cast(R); + auto *DefR = dyn_cast(R); if (!DefR || DefR->getOpcode() != Opcode) return false; assert(DefR->getNumOperands() == 2 && "recipe with matched opcode does not have 2 operands"); return Op0.match(DefR->getOperand(0)) && Op1.match(DefR->getOperand(1)); } + + bool match(const VPSingleDefRecipe *R) 
{ + return match(static_cast(R)); + } }; +template +using BinaryVPInstruction_match = + BinaryRecipe_match; + +template +using BinaryVPReplicate_match = + BinaryRecipe_match; + template inline UnaryVPInstruction_match m_VPInstruction(const Op0_t &Op0) { @@ -130,6 +142,13 @@ inline BinaryVPInstruction_match m_BranchOnCount(const Op0_t &Op0, const Op1_t &Op1) { return m_VPInstruction(Op0, Op1); } + +template +inline BinaryVPReplicate_match +m_VPReplicate(const Op0_t &Op0, const Op1_t &Op1) { + return BinaryVPReplicate_match(Op0, Op1); +} + } // namespace VPlanPatternMatch } // namespace llvm diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp index 9d6deb802e209..818647d5ea1ba 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp @@ -1249,6 +1249,24 @@ void VPlanTransforms::dropPoisonGeneratingRecipes( // load/store. If the underlying instruction has poison-generating flags, // drop them directly. if (auto *RecWithFlags = dyn_cast(CurRec)) { + VPValue *A, *B; + using namespace llvm::VPlanPatternMatch; + // Dropping disjoint from an OR may yield incorrect results, as some + // analysis may have converted it to an Add implicitly (e.g. SCEV used + // for dependence analysis). Instead, replace it with an equivalent Add. + // This is possible as all users of the disjoint OR only access lanes + // where the operands are disjoint or poison otherwise. 
+ if (match(RecWithFlags, + m_VPReplicate(m_VPValue(A), m_VPValue(B))) && + RecWithFlags->isDisjoint()) { + VPBuilder Builder(RecWithFlags); + VPInstruction *New = Builder.createOverflowingOp( + Instruction::Add, {A, B}, {false, false}, + RecWithFlags->getDebugLoc()); + RecWithFlags->replaceAllUsesWith(New); + RecWithFlags->eraseFromParent(); + CurRec = New; + } else // RecWithFlags is erased above; only drop flags if it is kept. - RecWithFlags->dropPoisonGeneratingFlags(); + RecWithFlags->dropPoisonGeneratingFlags(); } else { Instruction *Instr = dyn_cast_or_null( From f4202fc25fe281b91c960cf38fc94a8814a0267c Mon Sep 17 00:00:00 2001 From: Florian Hahn Date: Mon, 18 Mar 2024 15:10:36 +0000 Subject: [PATCH 2/3] !fixup fix merge, update test. --- .../Transforms/Vectorize/VPlanPatternMatch.h | 47 ++----------------- .../Transforms/Vectorize/VPlanTransforms.cpp | 2 +- .../Transforms/LoopVectorize/X86/pr81872.ll | 2 +- 3 files changed, 7 insertions(+), 44 deletions(-) diff --git a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h index 158c2aca5fbb4..a03a408686ef1 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h +++ b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h @@ -176,33 +176,6 @@ struct BinaryRecipe_match { } }; -<<<<<<< HEAD -template -struct BinaryRecipe_match { - Op0_t Op0; - Op1_t Op1; - - BinaryRecipe_match(Op0_t Op0, Op1_t Op1) : Op0(Op0), Op1(Op1) {} - - bool match(const VPValue *V) { - auto *DefR = V->getDefiningRecipe(); - return DefR && match(DefR); - } - - bool match(const VPRecipeBase *R) { - auto *DefR = dyn_cast(R); - if (!DefR || DefR->getOpcode() != Opcode) - return false; - assert(DefR->getNumOperands() == 2 && - "recipe with matched opcode does not have 2 operands"); - return Op0.match(DefR->getOperand(0)) && Op1.match(DefR->getOperand(1)); - } - - bool match(const VPSingleDefRecipe *R) { - return match(static_cast(R)); - } -}; -======= template using BinaryVPInstruction_match = BinaryRecipe_match; @@ -211,15 +184,6 @@ template using AllBinaryRecipe_match = BinaryRecipe_match; ->>>>>>> 
origin/main - -template -using BinaryVPInstruction_match = - BinaryRecipe_match; - -template -using BinaryVPReplicate_match = - BinaryRecipe_match; template inline UnaryVPInstruction_match @@ -257,12 +221,6 @@ m_BranchOnCount(const Op0_t &Op0, const Op1_t &Op1) { return m_VPInstruction(Op0, Op1); } -template -inline BinaryVPReplicate_match -m_VPReplicate(const Op0_t &Op0, const Op1_t &Op1) { - return BinaryVPReplicate_match(Op0, Op1); -} - template inline AllUnaryRecipe_match m_Unary(const Op0_t &Op0) { return AllUnaryRecipe_match(Op0); @@ -303,6 +261,11 @@ m_Mul(const Op0_t &Op0, const Op1_t &Op1) { return m_Binary(Op0, Op1); } +template +inline AllBinaryRecipe_match +m_Or(const Op0_t &Op0, const Op1_t &Op1) { + return m_Binary(Op0, Op1); +} } // namespace VPlanPatternMatch } // namespace llvm diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp index 213abe5a8d8ef..609e1538e7cd9 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp @@ -1224,7 +1224,7 @@ void VPlanTransforms::dropPoisonGeneratingRecipes( // This is possible as all users of the disjoint OR only access lanes // where the operands are disjoint or poison otherwise. 
if (match(RecWithFlags, - m_VPReplicate(m_VPValue(A), m_VPValue(B))) && + m_Or(m_VPValue(A), m_VPValue(B))) && RecWithFlags->isDisjoint()) { VPBuilder Builder(RecWithFlags); VPInstruction *New = Builder.createOverflowingOp( diff --git a/llvm/test/Transforms/LoopVectorize/X86/pr81872.ll b/llvm/test/Transforms/LoopVectorize/X86/pr81872.ll index 14acb6f57aa0c..3f38abc75a583 100644 --- a/llvm/test/Transforms/LoopVectorize/X86/pr81872.ll +++ b/llvm/test/Transforms/LoopVectorize/X86/pr81872.ll @@ -29,7 +29,7 @@ define void @test(ptr noundef align 8 dereferenceable_or_null(16) %arr) #0 { ; CHECK-NEXT: [[TMP2:%.*]] = and <4 x i64> [[VEC_IND]], ; CHECK-NEXT: [[TMP3:%.*]] = icmp eq <4 x i64> [[TMP2]], zeroinitializer ; CHECK-NEXT: [[TMP4:%.*]] = select <4 x i1> [[TMP1]], <4 x i1> [[TMP3]], <4 x i1> zeroinitializer -; CHECK-NEXT: [[TMP5:%.*]] = or i64 [[TMP0]], 1 +; CHECK-NEXT: [[TMP5:%.*]] = add i64 [[TMP0]], 1 ; CHECK-NEXT: [[TMP6:%.*]] = getelementptr i64, ptr [[ARR]], i64 [[TMP5]] ; CHECK-NEXT: [[TMP7:%.*]] = getelementptr i64, ptr [[TMP6]], i32 0 ; CHECK-NEXT: [[TMP8:%.*]] = getelementptr i64, ptr [[TMP7]], i32 -3 From dc92c04155c9f6d1462c8ebd950d30ab627cb018 Mon Sep 17 00:00:00 2001 From: Florian Hahn Date: Mon, 18 Mar 2024 15:11:35 +0000 Subject: [PATCH 3/3] !fixup fix formatting --- llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp index 609e1538e7cd9..c6ec99fbbf0a0 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp @@ -1223,8 +1223,7 @@ void VPlanTransforms::dropPoisonGeneratingRecipes( // for dependence analysis). Instead, replace it with an equivalent Add. // This is possible as all users of the disjoint OR only access lanes // where the operands are disjoint or poison otherwise. 
- if (match(RecWithFlags, - m_Or(m_VPValue(A), m_VPValue(B))) && + if (match(RecWithFlags, m_Or(m_VPValue(A), m_VPValue(B))) && RecWithFlags->isDisjoint()) { VPBuilder Builder(RecWithFlags); VPInstruction *New = Builder.createOverflowingOp(