diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index 6501a14d87789..6ce1077d81cf5 100644 --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -4843,25 +4843,43 @@ static bool clusterSortPtrAccesses(ArrayRef VL, Type *ElemTy, return false; // If we have a better order, also sort the base pointers by increasing - // (variable) values if possible, to try and keep the order more regular. - SmallVector> SortedBases; - for (auto &Base : Bases) - SortedBases.emplace_back(Base.first, - Base.first->stripInBoundsConstantOffsets()); - llvm::stable_sort(SortedBases, [](std::pair V1, - std::pair V2) { - const Value *V = V2.second; - while (auto *Gep = dyn_cast(V)) { - if (Gep->getOperand(0) == V1.second) - return true; - V = Gep->getOperand(0); - } - return false; - }); + // (variable) values if possible, to try and keep the order more regular. In + // order to create a valid strict-weak order we cluster by the Root of gep + // chains and sort within each. + SmallVector> SortedBases; + for (auto &Base : Bases) { + Value *Strip = Base.first->stripInBoundsConstantOffsets(); + Value *Root = Strip; + while (auto *Gep = dyn_cast(Root)) + Root = Gep->getOperand(0); + SortedBases.emplace_back(Base.first, Strip, Root); + } + auto *Begin = SortedBases.begin(); + auto *End = SortedBases.end(); + while (Begin != End) { + Value *Root = std::get<2>(*Begin); + auto *Mid = std::stable_partition( + Begin, End, [&Root](auto V) { return std::get<2>(V) == Root; }); + DenseMap> LessThan; + for (auto I = Begin; I < Mid; ++I) + LessThan.try_emplace(std::get<1>(*I)); + for (auto I = Begin; I < Mid; ++I) { + Value *V = std::get<1>(*I); + while (auto *Gep = dyn_cast(V)) { + V = Gep->getOperand(0); + if (LessThan.contains(V)) + LessThan[V][std::get<1>(*I)] = true; + } + } + std::stable_sort(Begin, Mid, [&LessThan](auto &V1, auto &V2) { + return LessThan[std::get<1>(V1)][std::get<1>(V2)]; + }); + Begin = Mid; + } // Collect the final order of sorted indices for (auto Base : SortedBases) - for (auto &T : Bases[Base.first]) + for (auto &T : Bases[std::get<0>(Base)]) SortedIndices.push_back(std::get<2>(T)); assert(SortedIndices.size() == VL.size() && diff --git a/llvm/test/Transforms/SLPVectorizer/AArch64/loadorder.ll b/llvm/test/Transforms/SLPVectorizer/AArch64/loadorder.ll index 6b5503f26fabf..d79aed89b0be7 100644 --- a/llvm/test/Transforms/SLPVectorizer/AArch64/loadorder.ll +++ b/llvm/test/Transforms/SLPVectorizer/AArch64/loadorder.ll @@ -428,14 +428,14 @@ define i32 @reduce_blockstrided4x4(ptr nocapture noundef readonly %p1, i32 nound ; CHECK-NEXT: [[TMP5:%.*]] = load <4 x i8>, ptr [[ADD_PTR64]], align 1 ; CHECK-NEXT: [[TMP6:%.*]] = load <4 x i8>, ptr [[ARRAYIDX3_1]], align 1 ; CHECK-NEXT: [[TMP7:%.*]] = load <4 x i8>, ptr [[ARRAYIDX5_1]], align 1 -; CHECK-NEXT: [[TMP8:%.*]] = shufflevector <4 x i8> [[TMP0]], <4 x i8> [[TMP1]], <16 x i32> -; CHECK-NEXT: [[TMP9:%.*]] = shufflevector <4 x i8> [[TMP4]], <4 x i8> poison, <16 x i32> +; CHECK-NEXT: [[TMP8:%.*]] = shufflevector <4 x i8> [[TMP0]], <4 x i8> [[TMP4]], <16 x i32> +; CHECK-NEXT: [[TMP9:%.*]] = shufflevector <4 x i8> [[TMP1]], <4 x i8> poison, <16 x i32> ; CHECK-NEXT: [[TMP10:%.*]] = shufflevector <16 x i8> [[TMP8]], <16 x i8> [[TMP9]], <16 x i32> ; CHECK-NEXT: [[TMP11:%.*]] = shufflevector <4 x i8> [[TMP5]], <4 x i8> poison, <16 x i32> ; CHECK-NEXT: [[TMP12:%.*]] = shufflevector <16 x i8> [[TMP10]], <16 x i8> [[TMP11]], <16 x i32> ; CHECK-NEXT: [[TMP13:%.*]] = zext <16 x i8> [[TMP12]] to <16 x i32> -; CHECK-NEXT: [[TMP14:%.*]] = shufflevector <4 x i8> [[TMP2]], <4 x i8> [[TMP3]], <16 x i32> -; CHECK-NEXT: [[TMP15:%.*]] = shufflevector <4 x i8> [[TMP6]], <4 x i8> poison, <16 x i32> +; CHECK-NEXT: [[TMP14:%.*]] = shufflevector <4 x i8> [[TMP2]], <4 x i8> [[TMP6]], <16 x i32> +; CHECK-NEXT: [[TMP15:%.*]] = shufflevector <4 x i8> [[TMP3]], <4 x i8> poison, <16 x i32> ; CHECK-NEXT: [[TMP16:%.*]] = shufflevector <16 x i8> [[TMP14]], <16 x i8> [[TMP15]], <16 x i32> ; CHECK-NEXT: [[TMP17:%.*]] = shufflevector <4 x i8> [[TMP7]], <4 x i8> poison, <16 x i32> ; CHECK-NEXT: [[TMP18:%.*]] = shufflevector <16 x i8> [[TMP16]], <16 x i8> [[TMP17]], <16 x i32>