
[SLP] Support ordered FAdd reductions in SLPVectorizer #146570


Status: Open. Wants to merge 1 commit into main.

Conversation

sc-clulzze

This patch adds initial support for ordered floating-point reductions in the SLPVectorizer; currently only FAdd is supported.

Fixes #50590.
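For context, the kind of IR this patch targets is a strict fadd chain: without the reassoc fast-math flag, the association ((x0 + x1) + x2) + x3 is fixed by the source and may not be changed. A minimal sketch (names are illustrative, not taken from the patch's tests):

    define float @sum4(ptr %p) {
      %p1 = getelementptr inbounds float, ptr %p, i64 1
      %p2 = getelementptr inbounds float, ptr %p, i64 2
      %p3 = getelementptr inbounds float, ptr %p, i64 3
      %x0 = load float, ptr %p, align 4
      %x1 = load float, ptr %p1, align 4
      %x2 = load float, ptr %p2, align 4
      %x3 = load float, ptr %p3, align 4
      ; no reassoc flag, so the evaluation order must be preserved
      %a1 = fadd float %x0, %x1
      %a2 = fadd float %a1, %x2
      %a3 = fadd float %a2, %x3
      ret float %a3
    }

Per the LangRef, a call to llvm.vector.reduce.fadd whose call site lacks the reassoc flag performs a sequential (in-order) reduction, so chains like this can be vectorized without changing the result.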


github-actions bot commented Jul 1, 2025

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write permissions for the repository; in that case, you can instead tag reviewers by name in a comment, using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review by pinging the PR with a comment saying "Ping". The common courtesy ping rate is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.


llvmbot commented Jul 1, 2025

@llvm/pr-subscribers-llvm-transforms

@llvm/pr-subscribers-vectorizers

Author: None (sc-clulzze)

Changes

This patch adds initial support for ordered floating-point reductions in the SLPVectorizer; currently only FAdd is supported.

Fixes #50590.


Patch is 44.20 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/146570.diff

5 Files Affected:

  • (modified) llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp (+160-29)
  • (modified) llvm/test/Transforms/SLPVectorizer/X86/dot-product.ll (+8-30)
  • (modified) llvm/test/Transforms/SLPVectorizer/X86/phi.ll (+21-32)
  • (added) llvm/test/Transforms/SLPVectorizer/fadd-scalar-remainder.ll (+93)
  • (added) llvm/test/Transforms/SLPVectorizer/fadd-vectorize.ll (+323)
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 0941bf61953f1..2c7929d91121f 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -110,11 +110,16 @@ using namespace std::placeholders;
 #define SV_NAME "slp-vectorizer"
 #define DEBUG_TYPE "SLP"
 
+STATISTIC(NumFaddVectorized, "Number of vectorized fadd reductions");
 STATISTIC(NumVectorInstructions, "Number of vector instructions generated");
 
 DEBUG_COUNTER(VectorizedGraphs, "slp-vectorized",
               "Controls which SLP graphs should be vectorized.");
 
+static cl::opt<bool> SLPEnableOrderedFPReductions(
+    "slp-ordered-fp-reds", cl::init(true), cl::Hidden,
+    cl::desc("Enable vectorization of ordered floating point reductions"));
+
 static cl::opt<bool>
     RunSLPVectorization("vectorize-slp", cl::init(true), cl::Hidden,
                         cl::desc("Run the SLP vectorization passes"));
@@ -1850,6 +1855,11 @@ class BoUpSLP {
     return VectorizableTree.front()->Scalars;
   }
 
+  bool areAllEntriesIdentityOrdered() const {
+    return all_of(VectorizableTree,
+                  [&](auto &Entry) { return Entry->ReorderIndices.empty(); });
+  }
+
   /// Returns the type/is-signed info for the root node in the graph without
   /// casting.
   std::optional<std::pair<Type *, bool>> getRootNodeTypeWithNoCast() const {
@@ -21774,6 +21784,8 @@ class HorizontalReduction {
   /// signedness.
   SmallVector<std::tuple<Value *, unsigned, bool>> VectorValuesAndScales;
 
+  SmallVector<Value *, 2> InitialFAddValues;
+
   static bool isCmpSelMinMax(Instruction *I) {
     return match(I, m_Select(m_Cmp(), m_Value(), m_Value())) &&
            RecurrenceDescriptor::isMinMaxRecurrenceKind(getRdxKind(I));
@@ -21787,6 +21799,14 @@ class HorizontalReduction {
            (match(I, m_LogicalAnd()) || match(I, m_LogicalOr()));
   }
 
+  bool isOrderedFaddReduction() const {
+    if (!isa<Instruction>(ReductionRoot))
+      return false;
+    auto *I = cast<Instruction>(ReductionRoot);
+    return (RdxKind == RecurKind::FAdd) &&
+           !I->getFastMathFlags().allowReassoc();
+  }
+
   /// Checks if instruction is associative and can be vectorized.
   static bool isVectorizable(RecurKind Kind, Instruction *I) {
     if (Kind == RecurKind::None)
@@ -21807,6 +21827,9 @@ class HorizontalReduction {
     if (Kind == RecurKind::FMaximum || Kind == RecurKind::FMinimum)
       return true;
 
+    if (Kind == RecurKind::FAdd && SLPEnableOrderedFPReductions)
+      return true;
+
     return I->isAssociative();
   }
 
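A reduction rooted at an FAdd is classified as ordered exactly when the root lacks the reassoc fast-math flag (see isOrderedFaddReduction above), and vectorization of such roots is gated by the new hidden slp-ordered-fp-reds option, which defaults to true. A minimal IR sketch of the distinction, with illustrative names:

    ; ordered: no reassoc flag, handled by the new path
    %s1 = fadd float %s0, %x
    ; reassociable: handled by the existing reduction logic
    %t1 = fadd reassoc float %t0, %y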
@@ -22066,6 +22089,37 @@ class HorizontalReduction {
            (I && !isa<LoadInst>(I) && isValidForAlternation(I->getOpcode()));
   }
 
+  bool checkOperandsOrder() const {
+    auto OpsVec = reverse(ReductionOps[0]);
+    if (!isOrderedFaddReduction() || empty(OpsVec))
+      return false;
+    Value *PrevOperand = *OpsVec.begin();
+    for (auto *I : drop_begin(OpsVec)) {
+      Value *Op1 = cast<BinaryOperator>(I)->getOperand(0);
+      if (Op1 != PrevOperand)
+        return false;
+      PrevOperand = I;
+    }
+    return true;
+  }
+
+  bool checkFastMathFlags() const {
+    for (const auto &OpsVec : ReductionOps) {
+      if (OpsVec.size() <= 1)
+        continue;
+      Value *V = *OpsVec.begin();
+      if (!isa<FPMathOperator>(V))
+        continue;
+      bool Flag = cast<Instruction>(V)->getFastMathFlags().allowReassoc();
+      auto It = find_if(drop_begin(OpsVec), [&](Value *I) {
+        auto CurFlag = cast<Instruction>(I)->getFastMathFlags().allowReassoc();
+        return (Flag != CurFlag);
+      });
+      if (It != OpsVec.end())
+        return false;
+    }
+    return true;
+  }
 public:
   HorizontalReduction() = default;
 
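checkOperandsOrder appears to accept only a single straight-line accumulation chain: walking the collected fadds, each one's first operand must be the previous fadd. A sketch of shapes it would accept and reject (hypothetical names):

    ; accepted: operand 0 of every fadd is the preceding fadd
    %a1 = fadd float %init, %x0
    %a2 = fadd float %a1, %x1
    %a3 = fadd float %a2, %x2

    ; rejected: %b2 accumulates through operand 1, so the chain check fails
    %b1 = fadd float %init, %x0
    %b2 = fadd float %x1, %b1

checkFastMathFlags additionally rejects reductions whose ops disagree on the allowReassoc flag, so a mixed strict/reassoc chain is not vectorized along this path.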
@@ -22180,9 +22234,10 @@ class HorizontalReduction {
       // Add reduction values. The values are sorted for better vectorization
       // results.
       for (Value *V : PossibleRedVals) {
-        size_t Key, Idx;
-        std::tie(Key, Idx) = generateKeySubkey(V, &TLI, GenerateLoadsSubkey,
-                                               /*AllowAlternate=*/false);
+        size_t Key = 0, Idx = 0;
+        if (!isOrderedFaddReduction())
+          std::tie(Key, Idx) = generateKeySubkey(V, &TLI, GenerateLoadsSubkey,
+                                                /*AllowAlternate=*/false);
         ++PossibleReducedVals[Key][Idx]
               .insert(std::make_pair(V, 0))
               .first->second;
@@ -22200,13 +22255,15 @@ class HorizontalReduction {
            It != E; ++It) {
         PossibleRedValsVect.emplace_back();
         auto RedValsVect = It->second.takeVector();
-        stable_sort(RedValsVect, llvm::less_second());
+        if (!isOrderedFaddReduction())
+          stable_sort(RedValsVect, llvm::less_second());
         for (const std::pair<Value *, unsigned> &Data : RedValsVect)
           PossibleRedValsVect.back().append(Data.second, Data.first);
       }
-      stable_sort(PossibleRedValsVect, [](const auto &P1, const auto &P2) {
-        return P1.size() > P2.size();
-      });
+      if (!isOrderedFaddReduction())
+        stable_sort(PossibleRedValsVect, [](const auto &P1, const auto &P2) {
+          return P1.size() > P2.size();
+        });
       int NewIdx = -1;
       for (ArrayRef<Value *> Data : PossibleRedValsVect) {
         if (NewIdx < 0 ||
@@ -22226,9 +22283,19 @@ class HorizontalReduction {
     }
     // Sort the reduced values by number of same/alternate opcode and/or pointer
     // operand.
-    stable_sort(ReducedVals, [](ArrayRef<Value *> P1, ArrayRef<Value *> P2) {
-      return P1.size() > P2.size();
-    });
+    if (!isOrderedFaddReduction())
+      stable_sort(ReducedVals, [](ArrayRef<Value *> P1, ArrayRef<Value *> P2) {
+        return P1.size() > P2.size();
+      });
+
+    if (isOrderedFaddReduction() &&
+        (ReducedVals.size() != 1 || ReducedVals[0].size() == 2 ||
+         !checkOperandsOrder()))
+      return false;
+
+    if (!checkFastMathFlags())
+      return false;
+
     return true;
   }
 
@@ -22423,7 +22490,7 @@ class HorizontalReduction {
       // original scalar identity operations on matched horizontal reductions).
       IsSupportedHorRdxIdentityOp = RdxKind != RecurKind::Mul &&
                                     RdxKind != RecurKind::FMul &&
-                                    RdxKind != RecurKind::FMulAdd;
+                                    RdxKind != RecurKind::FMulAdd &&
+                                    !isOrderedFaddReduction();
       // Gather same values.
       SmallMapVector<Value *, unsigned, 16> SameValuesCounter;
       if (IsSupportedHorRdxIdentityOp)
@@ -22524,6 +22591,8 @@ class HorizontalReduction {
         return IsAnyRedOpGathered;
       };
       bool AnyVectorized = false;
+      Instruction *RdxRootInst = cast<Instruction>(ReductionRoot);
+      Instruction *InsertPt = RdxRootInst;
       SmallDenseSet<std::pair<unsigned, unsigned>, 8> IgnoredCandidates;
       while (Pos < NumReducedVals - ReduxWidth + 1 &&
              ReduxWidth >= ReductionLimit) {
@@ -22684,8 +22753,6 @@ class HorizontalReduction {
 
         // Emit a reduction. If the root is a select (min/max idiom), the insert
         // point is the compare condition of that select.
-        Instruction *RdxRootInst = cast<Instruction>(ReductionRoot);
-        Instruction *InsertPt = RdxRootInst;
         if (IsCmpSelMinMax)
           InsertPt = GetCmpForMinMaxReduction(RdxRootInst);
 
@@ -22738,6 +22805,41 @@ class HorizontalReduction {
           if (!V.isVectorized(RdxVal))
             RequiredExtract.insert(RdxVal);
         }
+
+        auto FirstIt = find_if(ReducedVals[0], [&](Value *RdxVal) {
+          return VectorizedVals.lookup(RdxVal);
+        });
+        auto LastIt = find_if(reverse(ReducedVals[0]), [&](Value *RdxVal) {
+          return VectorizedVals.lookup(RdxVal);
+        });
+        if (isOrderedFaddReduction()) {
+          // [FirstIt, LastIt] is the range of vectorized Vals. We need it to
+          // find the last non-vectorized Val at the beginning and its
+          // ReductionOp, and the first non-vectorized Val at the end and its
+          // ReductionOp:
+          // fadd - initial value for reduction
+          // fadd - v
+          // fadd - v
+          // fadd - v
+          // fadd - v
+          // fadd - scalar remainder
+          if (LastIt != ReducedVals[0].rend())
+            ReductionRoot =
+                cast<Instruction>(ReducedValsToOps.find(*LastIt)->second[0]);
+
+          if (InitialFAddValues.empty()) {
+            auto *FAddBinOp = cast<BinaryOperator>(
+                ReducedValsToOps.find(*FirstIt)->second[0]);
+            Value *InitialFAddValue = ConstantExpr::getBinOpIdentity(
+                FAddBinOp->getOpcode(), FAddBinOp->getType());
+            if (FirstIt != ReducedVals[0].end()) {
+              auto *Op1 = FAddBinOp->getOperand(0);
+              if (!isa<PoisonValue>(Op1))
+                InitialFAddValue = Op1;
+            }
+            InitialFAddValues.push_back(InitialFAddValue);
+          }
+        }
+
         Pos += ReduxWidth;
         Start = Pos;
         ReduxWidth = NumReducedVals - Pos;
@@ -22755,10 +22857,27 @@ class HorizontalReduction {
         continue;
       }
     }
-    if (!VectorValuesAndScales.empty())
-      VectorizedTree = GetNewVectorizedTree(
-          VectorizedTree,
-          emitReduction(Builder, *TTI, ReductionRoot->getType()));
+    if (!VectorValuesAndScales.empty()) {
+      if (!isOrderedFaddReduction()) {
+        VectorizedTree = GetNewVectorizedTree(
+            VectorizedTree,
+            emitReduction(Builder, *TTI, ReductionRoot->getType()));
+      } else {
+        for (auto V : VectorValuesAndScales) {
+          Value *InitialFAddValue = InitialFAddValues.back();
+          VectorizedTree =
+              Builder.CreateFAddReduce(InitialFAddValue, std::get<0>(V));
+          InitialFAddValues.push_back(VectorizedTree);
+        }
+        auto LastIt = find_if(reverse(ReducedVals[0]), [&](Value *RdxVal) {
+          return VectorizedVals.lookup(RdxVal);
+        });
+        for_each(reverse(make_range(LastIt.base(), ReducedVals[0].end())),
+                   [&](Value *V) {
+                     ReducedValsToOps.find(V)->second[0]->moveAfter(
+                         cast<Instruction>(VectorizedTree));
+                   });
+      }
+    }
     if (VectorizedTree) {
       // Reorder operands of bool logical op in the natural order to avoid
       // possible problem with poison propagation. If not possible to reorder
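In the ordered path above, the partial reductions are chained through the intrinsic's start-value operand: each llvm.vector.reduce.fadd emitted via CreateFAddReduce starts from the previous result, preserving the left-to-right order. A sketch for eight elements reduced as two <4 x float> halves (values illustrative):

    %r0 = call float @llvm.vector.reduce.fadd.v4f32(float -0.000000e+00, <4 x float> %lo)
    %r1 = call float @llvm.vector.reduce.fadd.v4f32(float %r0, <4 x float> %hi)

Here -0.0 is the FAdd identity produced by ConstantExpr::getBinOpIdentity, used when no initial scalar value was found for the chain.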
@@ -22846,15 +22965,18 @@ class HorizontalReduction {
             ExtraReductions.emplace_back(RedOp, RdxVal);
         }
       }
-      // Iterate through all not-vectorized reduction values/extra arguments.
-      bool InitStep = true;
-      while (ExtraReductions.size() > 1) {
-        SmallVector<std::pair<Instruction *, Value *>> NewReds =
-            FinalGen(ExtraReductions, InitStep);
-        ExtraReductions.swap(NewReds);
-        InitStep = false;
+
+      if (!isOrderedFaddReduction()) {
+        // Iterate through all not-vectorized reduction values/extra arguments.
+        bool InitStep = true;
+        while (ExtraReductions.size() > 1) {
+          SmallVector<std::pair<Instruction *, Value *>> NewReds =
+              FinalGen(ExtraReductions, InitStep);
+          ExtraReductions.swap(NewReds);
+          InitStep = false;
+        }
+        VectorizedTree = ExtraReductions.front().second;
       }
-      VectorizedTree = ExtraReductions.front().second;
 
       ReductionRoot->replaceAllUsesWith(VectorizedTree);
 
@@ -22868,21 +22990,28 @@ class HorizontalReduction {
         IgnoreSet.insert_range(RdxOps);
 #endif
       for (ArrayRef<Value *> RdxOps : ReductionOps) {
+        SmallVector<Value *, 4> RdxOpsForDeletion;
         for (Value *Ignore : RdxOps) {
-          if (!Ignore)
+          if (!Ignore || (isOrderedFaddReduction() && !Ignore->use_empty() &&
+                          !any_of(cast<Instruction>(Ignore)->operands(),
+                                  [](const Value *Val) {
+                                    return isa<PoisonValue>(Val);
+                                  })))
             continue;
 #ifndef NDEBUG
           for (auto *U : Ignore->users()) {
-            assert(IgnoreSet.count(U) &&
-                   "All users must be either in the reduction ops list.");
+            assert((IgnoreSet.count(U) || isOrderedFaddReduction()) &&
+                   "All users must be in the reduction ops list.");
           }
 #endif
           if (!Ignore->use_empty()) {
             Value *P = PoisonValue::get(Ignore->getType());
             Ignore->replaceAllUsesWith(P);
           }
+          RdxOpsForDeletion.push_back(Ignore);
         }
-        V.removeInstructionsAndOperands(RdxOps, VectorValuesAndScales);
+        V.removeInstructionsAndOperands(ArrayRef(RdxOpsForDeletion),
+                                        VectorValuesAndScales);
       }
     } else if (!CheckForReusedReductionOps) {
       for (ReductionOpsType &RdxOps : ReductionOps)
@@ -22961,6 +23090,8 @@ class HorizontalReduction {
           continue;
         }
         InstructionCost ScalarCost = 0;
+        if (RdxVal->use_empty())
+          continue;
         for (User *U : RdxVal->users()) {
           auto *RdxOp = cast<Instruction>(U);
           if (hasRequiredNumberOfUses(IsCmpSelMinMax, RdxOp)) {
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/dot-product.ll b/llvm/test/Transforms/SLPVectorizer/X86/dot-product.ll
index f16c879c451c2..8f541a3dface3 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/dot-product.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/dot-product.ll
@@ -10,21 +10,10 @@
 
 define double @dot4f64(ptr dereferenceable(32) %ptrx, ptr dereferenceable(32) %ptry) {
 ; CHECK-LABEL: @dot4f64(
-; CHECK-NEXT:    [[PTRX2:%.*]] = getelementptr inbounds double, ptr [[PTRX:%.*]], i64 2
-; CHECK-NEXT:    [[PTRY2:%.*]] = getelementptr inbounds double, ptr [[PTRY:%.*]], i64 2
-; CHECK-NEXT:    [[TMP1:%.*]] = load <2 x double>, ptr [[PTRX]], align 4
-; CHECK-NEXT:    [[TMP2:%.*]] = load <2 x double>, ptr [[PTRY]], align 4
-; CHECK-NEXT:    [[TMP3:%.*]] = fmul <2 x double> [[TMP1]], [[TMP2]]
-; CHECK-NEXT:    [[TMP4:%.*]] = load <2 x double>, ptr [[PTRX2]], align 4
-; CHECK-NEXT:    [[TMP5:%.*]] = load <2 x double>, ptr [[PTRY2]], align 4
-; CHECK-NEXT:    [[TMP6:%.*]] = fmul <2 x double> [[TMP4]], [[TMP5]]
-; CHECK-NEXT:    [[TMP7:%.*]] = extractelement <2 x double> [[TMP3]], i32 0
-; CHECK-NEXT:    [[TMP8:%.*]] = extractelement <2 x double> [[TMP3]], i32 1
-; CHECK-NEXT:    [[DOT01:%.*]] = fadd double [[TMP7]], [[TMP8]]
-; CHECK-NEXT:    [[TMP9:%.*]] = extractelement <2 x double> [[TMP6]], i32 0
-; CHECK-NEXT:    [[DOT012:%.*]] = fadd double [[DOT01]], [[TMP9]]
-; CHECK-NEXT:    [[TMP10:%.*]] = extractelement <2 x double> [[TMP6]], i32 1
-; CHECK-NEXT:    [[DOT0123:%.*]] = fadd double [[DOT012]], [[TMP10]]
+; CHECK-NEXT:    [[TMP1:%.*]] = load <4 x double>, ptr [[PTRX:%.*]], align 4
+; CHECK-NEXT:    [[TMP2:%.*]] = load <4 x double>, ptr [[PTRY:%.*]], align 4
+; CHECK-NEXT:    [[TMP3:%.*]] = fmul <4 x double> [[TMP1]], [[TMP2]]
+; CHECK-NEXT:    [[DOT0123:%.*]] = call double @llvm.vector.reduce.fadd.v4f64(double -0.000000e+00, <4 x double> [[TMP3]])
 ; CHECK-NEXT:    ret double [[DOT0123]]
 ;
   %ptrx1 = getelementptr inbounds double, ptr %ptrx, i64 1
@@ -53,21 +42,10 @@ define double @dot4f64(ptr dereferenceable(32) %ptrx, ptr dereferenceable(32) %p
 
 define float @dot4f32(ptr dereferenceable(16) %ptrx, ptr dereferenceable(16) %ptry) {
 ; CHECK-LABEL: @dot4f32(
-; CHECK-NEXT:    [[PTRX2:%.*]] = getelementptr inbounds float, ptr [[PTRX:%.*]], i64 2
-; CHECK-NEXT:    [[PTRY2:%.*]] = getelementptr inbounds float, ptr [[PTRY:%.*]], i64 2
-; CHECK-NEXT:    [[TMP1:%.*]] = load <2 x float>, ptr [[PTRX]], align 4
-; CHECK-NEXT:    [[TMP2:%.*]] = load <2 x float>, ptr [[PTRY]], align 4
-; CHECK-NEXT:    [[TMP3:%.*]] = fmul <2 x float> [[TMP1]], [[TMP2]]
-; CHECK-NEXT:    [[TMP4:%.*]] = load <2 x float>, ptr [[PTRX2]], align 4
-; CHECK-NEXT:    [[TMP5:%.*]] = load <2 x float>, ptr [[PTRY2]], align 4
-; CHECK-NEXT:    [[TMP6:%.*]] = fmul <2 x float> [[TMP4]], [[TMP5]]
-; CHECK-NEXT:    [[TMP7:%.*]] = extractelement <2 x float> [[TMP3]], i32 0
-; CHECK-NEXT:    [[TMP8:%.*]] = extractelement <2 x float> [[TMP3]], i32 1
-; CHECK-NEXT:    [[DOT01:%.*]] = fadd float [[TMP7]], [[TMP8]]
-; CHECK-NEXT:    [[TMP9:%.*]] = extractelement <2 x float> [[TMP6]], i32 0
-; CHECK-NEXT:    [[DOT012:%.*]] = fadd float [[DOT01]], [[TMP9]]
-; CHECK-NEXT:    [[TMP10:%.*]] = extractelement <2 x float> [[TMP6]], i32 1
-; CHECK-NEXT:    [[DOT0123:%.*]] = fadd float [[DOT012]], [[TMP10]]
+; CHECK-NEXT:    [[TMP1:%.*]] = load <4 x float>, ptr [[PTRX:%.*]], align 4
+; CHECK-NEXT:    [[TMP2:%.*]] = load <4 x float>, ptr [[PTRY:%.*]], align 4
+; CHECK-NEXT:    [[TMP3:%.*]] = fmul <4 x float> [[TMP1]], [[TMP2]]
+; CHECK-NEXT:    [[DOT0123:%.*]] = call float @llvm.vector.reduce.fadd.v4f32(float -0.000000e+00, <4 x float> [[TMP3]])
 ; CHECK-NEXT:    ret float [[DOT0123]]
 ;
   %ptrx1 = getelementptr inbounds float, ptr %ptrx, i64 1
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/phi.ll b/llvm/test/Transforms/SLPVectorizer/X86/phi.ll
index 17ae33652b6d8..c1a0c293ef9b9 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/phi.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/phi.ll
@@ -136,44 +136,39 @@ for.end:                                          ; preds = %for.body
 define float @foo3(ptr nocapture readonly %A) #0 {
 ; CHECK-LABEL: @foo3(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[ARRAYIDX1:%.*]] = getelementptr inbounds float, ptr [[A:%.*]], i64 1
-; CHECK-NEXT:    [[TMP0:%.*]] = load <2 x float>, ptr [[A]], align 4
-; CHECK-NEXT:    [[TMP1:%.*]] = load <4 x float>, ptr [[ARRAYIDX1]], align 4
-; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <2 x float> [[TMP0]], i32 0
+; CHECK-NEXT:    [[TMP1:%.*]] = load <4 x float>, ptr [[ARRAYIDX1:%.*]], align 4
+; CHECK-NEXT:    [[ARRAYIDX4:%.*]] = getelementptr inbounds float, ptr [[ARRAYIDX1]], i64 4
+; CHECK-NEXT:    [[TMP2:%.*]] = load float, ptr [[ARRAYIDX4]], align 4
+; CHECK-NEXT:    [[TMP0:%.*]] = shufflevector <4 x float> [[TMP1]], <4 x float> poison, <2 x i32> <i32 0, i32 1>
 ; CHECK-NEXT:    br label [[FOR_BODY:%.*]]
 ; CHECK:       for.body:
 ; CHECK-NEXT:    [[INDVARS_IV:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[INDVARS_IV_NEXT:%.*]], [[FOR_BODY]] ]
 ; CHECK-NEXT:    [[R_052:%.*]] = phi float [ [[TMP2]], [[ENTRY]] ], [ [[ADD6:%.*]], [[FOR_BODY]] ]
 ; CHECK-NEXT:    [[TMP3:%.*]] = phi <4 x float> [ [[TMP1]], [[ENTRY]] ], [ [[TMP15:%.*]], [[FOR_BODY]] ]
 ; CHECK-NEXT:    [[TMP4:%.*]] = phi <2 x float> [ [[TMP0]], [[ENTRY]] ], [ [[TMP7:%.*]], [[FOR_BODY]] ]
-; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <2 x float> [[TMP4]], i32 0
-; CHECK-NEXT:    [[MUL:%.*]] = fmul float [[TMP5]], 7.000000e+00
-; CHECK-NEXT:    [[ADD6]] = fadd float [[R_052]], [[MUL]]
 ; CHECK-NEXT:    [[TMP6:%.*]] = add nsw i64 [[INDVARS_IV]], 2
-; CHECK-NEXT:    [[ARRAYIDX14:%.*]] = getelementptr inbounds float, ptr [[A]], i64 [[TMP6]]
+; CHECK-NEXT:    [[ARRAYIDX14:%.*]] = getelementptr inbounds float, ptr [[ARRAYIDX1]], i64 [[TMP6]]
+; CHECK-NEXT:    [[TMP9:%.*]] = load float, ptr [[ARRAYIDX14]], align 4
 ; CHECK-NEXT:    [[INDVARS_IV_NEXT]] = add nuw nsw i64 [[INDVARS_IV]], 3
-; CHECK-NEXT:    [[ARRAYIDX19:%.*]] = getelementptr inbounds float, ptr [[A]], i64 [[INDVARS_IV_NEXT]]
-; CHECK-NEXT:    [[TMP8:%.*]] = load <2 x float>, ptr [[ARRAYIDX14]], align 4
+; CHECK-NEXT:    [[ARRAYIDX19:%.*]] = getelementptr inbounds float, ptr [[ARRAYIDX1]], i64 [[INDVARS_IV_NEXT]]
+; CHECK-NEXT:    [[TMP11:%.*]] = add nsw i64 [[INDVARS_IV]], 4
+; CHECK-NEXT:    [[ARRAYIDX24:%.*]] = getelementptr inbounds float, ptr [[ARRAYIDX1]], i64 [[TMP11]]
+; CHECK-NEXT:    [[TMP8:%.*]] = load float, ptr [[ARRAYIDX24]], align 4
 ; CHECK-NEXT:    [[TMP7]] = load <2 x float>, ptr [[ARRAYIDX19]], align 4
-; CHECK-NEXT:    [[TMP9:%.*]] = shufflevector <2 x float> [[TMP8]], <2 x float> poison, <4 x i32> <i32 poison, i32 0, i32 1, i32 poison>
-; CHECK-NEXT:    [[TMP10:%.*]] = shufflevector <2 x float> [[TMP4]], <2 x float> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
-; CHECK-NEXT:    [[TMP11:%.*]] = shufflevector <4 x float> [[TMP9]], <4 x float> [[TMP10]], <4 x i32> <i32 5, i32 1, i32 2, i32 poison>
+; CHECK-N...
[truncated]
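The added tests cover chains with a scalar remainder; a minimal sketch of the intended end-to-end shape, assuming a five-element strict chain (all names illustrative, not copied from the tests):

    ; before: %acc = ((((%init + %x0) + %x1) + %x2) + %x3) + %x4, all strict fadds
    ; after: the first four additions become one ordered reduction, and the
    ; remainder stays scalar
    %v   = load <4 x float>, ptr %p, align 4
    %r   = call float @llvm.vector.reduce.fadd.v4f32(float %init, <4 x float> %v)
    %acc = fadd float %r, %x4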

@sc-clulzze
Author

@alexey-bataev @preames @RKSimon could you please take a look?

-        size_t Key, Idx;
-        std::tie(Key, Idx) = generateKeySubkey(V, &TLI, GenerateLoadsSubkey,
-                                               /*AllowAlternate=*/false);
+        size_t Key = 0, Idx = 0;
Member

I think it is better to separate this from the "associative" analysis and implement it in a separate function, used when the associative path does not apply. It will be much easier to read and maintain, and most of the logic for associative reductions can be dropped.

Successfully merging this pull request may close issue #50590: [SLP] Allow sequential fadd reductions.