From 972bfc37494159694decee2fac075ad682ee9cc8 Mon Sep 17 00:00:00 2001
From: Florian Hahn
Date: Thu, 5 Jun 2025 15:06:23 +0100
Subject: [PATCH 1/4] [VectorCombine] Scalarize extracts of ZExt if profitable.

Add a new scalarization transform that tries to convert extracts of a
vector ZExt to a set of scalar shift and mask operations. This can be
profitable if the cost of extracting an element is the same or higher
than the cost of 2 scalar ops. This is the case on AArch64, for example.
For AArch64, this shows up in a number of workloads, including av1aom,
gmsh, minizinc and astc-encoder.
---
 .../Transforms/Vectorize/VectorCombine.cpp |  68 +++++++++++
 .../VectorCombine/AArch64/ext-extract.ll   | 111 ++++++++++++++----
 2 files changed, 155 insertions(+), 24 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index b2fced47b9527..a7689192fbcbf 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -123,6 +123,7 @@ class VectorCombine {
   bool foldBinopOfReductions(Instruction &I);
   bool foldSingleElementStore(Instruction &I);
   bool scalarizeLoadExtract(Instruction &I);
+  bool scalarizeExtExtract(Instruction &I);
   bool foldConcatOfBoolMasks(Instruction &I);
   bool foldPermuteOfBinops(Instruction &I);
   bool foldShuffleOfBinops(Instruction &I);
@@ -1774,6 +1775,72 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
   return true;
 }
 
+bool VectorCombine::scalarizeExtExtract(Instruction &I) {
+  if (!match(&I, m_ZExt(m_Value())))
+    return false;
+
+  // Try to convert a vector zext feeding only extracts to a set of scalar (Src
+  // << ExtIdx *Size) & (Size -1), if profitable.
+  auto *Ext = cast<ZExtInst>(&I);
+  auto *SrcTy = cast<FixedVectorType>(Ext->getOperand(0)->getType());
+  auto *DstTy = cast<FixedVectorType>(Ext->getType());
+
+  if (DL->getTypeSizeInBits(SrcTy) !=
+      DL->getTypeSizeInBits(DstTy->getElementType()))
+    return false;
+
+  InstructionCost VectorCost = TTI.getCastInstrCost(
+      Instruction::ZExt, DstTy, SrcTy, TTI::CastContextHint::None, CostKind);
+  unsigned ExtCnt = 0;
+  bool ExtLane0 = false;
+  for (User *U : Ext->users()) {
+    const APInt *Idx;
+    if (!match(U, m_ExtractElt(m_Value(), m_APInt(Idx))))
+      return false;
+    if (cast<Instruction>(U)->use_empty())
+      continue;
+    ExtCnt += 1;
+    ExtLane0 |= Idx->isZero();
+    VectorCost += TTI.getVectorInstrCost(Instruction::ExtractElement, DstTy,
+                                         CostKind, Idx->getZExtValue(), U);
+  }
+
+  Type *ScalarDstTy = DstTy->getElementType();
+  InstructionCost ScalarCost =
+      ExtCnt * TTI.getArithmeticInstrCost(
+                   Instruction::And, ScalarDstTy, CostKind,
+                   {TTI::OK_AnyValue, TTI::OP_None},
+                   {TTI::OK_NonUniformConstantValue, TTI::OP_None}) +
+      (ExtCnt - ExtLane0) *
+          TTI.getArithmeticInstrCost(
+
+              Instruction::LShr, ScalarDstTy, CostKind,
+              {TTI::OK_AnyValue, TTI::OP_None},
+              {TTI::OK_NonUniformConstantValue, TTI::OP_None});
+  if (ScalarCost > VectorCost)
+    return false;
+
+  Value *ScalarV = Ext->getOperand(0);
+  if (!isGuaranteedNotToBePoison(ScalarV, &AC))
+    ScalarV = Builder.CreateFreeze(ScalarV);
+  ScalarV = Builder.CreateBitCast(
+      ScalarV,
+      IntegerType::get(SrcTy->getContext(), DL->getTypeSizeInBits(SrcTy)));
+  unsigned SrcEltSizeInBits = DL->getTypeSizeInBits(SrcTy->getElementType());
+  Value *EltBitMask =
+      ConstantInt::get(ScalarV->getType(), (1ull << SrcEltSizeInBits) - 1);
+  for (auto *U : to_vector(Ext->users())) {
+    auto *Extract = cast<ExtractElementInst>(U);
+    unsigned Idx =
+        cast<ConstantInt>(Extract->getIndexOperand())->getZExtValue();
+    auto *S = Builder.CreateLShr(
+        ScalarV, 
ConstantInt::get(ScalarV->getType(), Idx * SrcEltSizeInBits)); + auto *A = Builder.CreateAnd(S, EltBitMask); + U->replaceAllUsesWith(A); + } + return true; +} + /// Try to fold "(or (zext (bitcast X)), (shl (zext (bitcast Y)), C))" /// to "(bitcast (concat X, Y))" /// where X/Y are bitcasted from i1 mask vectors. @@ -3662,6 +3729,7 @@ bool VectorCombine::run() { if (IsVectorType) { MadeChange |= scalarizeOpOrCmp(I); MadeChange |= scalarizeLoadExtract(I); + MadeChange |= scalarizeExtExtract(I); MadeChange |= scalarizeVPIntrinsic(I); MadeChange |= foldInterleaveIntrinsics(I); } diff --git a/llvm/test/Transforms/VectorCombine/AArch64/ext-extract.ll b/llvm/test/Transforms/VectorCombine/AArch64/ext-extract.ll index 09c03991ad7c3..23538589ae32c 100644 --- a/llvm/test/Transforms/VectorCombine/AArch64/ext-extract.ll +++ b/llvm/test/Transforms/VectorCombine/AArch64/ext-extract.ll @@ -9,15 +9,25 @@ define void @zext_v4i8_all_lanes_used(<4 x i8> %src) { ; CHECK-LABEL: define void @zext_v4i8_all_lanes_used( ; CHECK-SAME: <4 x i8> [[SRC:%.*]]) { ; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: [[TMP0:%.*]] = freeze <4 x i8> [[SRC]] +; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i8> [[TMP0]] to i32 +; CHECK-NEXT: [[TMP2:%.*]] = lshr i32 [[TMP1]], 24 +; CHECK-NEXT: [[TMP3:%.*]] = and i32 [[TMP2]], 255 +; CHECK-NEXT: [[TMP4:%.*]] = lshr i32 [[TMP1]], 16 +; CHECK-NEXT: [[TMP5:%.*]] = and i32 [[TMP4]], 255 +; CHECK-NEXT: [[TMP6:%.*]] = lshr i32 [[TMP1]], 8 +; CHECK-NEXT: [[TMP7:%.*]] = and i32 [[TMP6]], 255 +; CHECK-NEXT: [[TMP8:%.*]] = lshr i32 [[TMP1]], 0 +; CHECK-NEXT: [[TMP9:%.*]] = and i32 [[TMP8]], 255 ; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32> ; CHECK-NEXT: [[EXT_0:%.*]] = extractelement <4 x i32> [[EXT9]], i64 0 ; CHECK-NEXT: [[EXT_1:%.*]] = extractelement <4 x i32> [[EXT9]], i64 1 ; CHECK-NEXT: [[EXT_2:%.*]] = extractelement <4 x i32> [[EXT9]], i64 2 ; CHECK-NEXT: [[EXT_3:%.*]] = extractelement <4 x i32> [[EXT9]], i64 3 -; CHECK-NEXT: call void @use.i32(i32 [[EXT_0]]) -; CHECK-NEXT: call void @use.i32(i32 [[EXT_1]]) -; CHECK-NEXT: call void @use.i32(i32 [[EXT_2]]) -; CHECK-NEXT: call void @use.i32(i32 [[EXT_3]]) +; CHECK-NEXT: call void @use.i32(i32 [[TMP9]]) +; CHECK-NEXT: call void @use.i32(i32 [[TMP7]]) +; CHECK-NEXT: call void @use.i32(i32 [[TMP5]]) +; CHECK-NEXT: call void @use.i32(i32 [[TMP3]]) ; CHECK-NEXT: ret void ; entry: @@ -68,13 +78,21 @@ define void @zext_v4i8_3_lanes_used_1(<4 x i8> %src) { ; CHECK-LABEL: define void @zext_v4i8_3_lanes_used_1( ; CHECK-SAME: <4 x i8> [[SRC:%.*]]) { ; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: [[TMP0:%.*]] = freeze <4 x i8> [[SRC]] +; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i8> [[TMP0]] to i32 +; CHECK-NEXT: [[TMP2:%.*]] = lshr i32 [[TMP1]], 24 +; CHECK-NEXT: [[TMP3:%.*]] = and i32 [[TMP2]], 255 +; CHECK-NEXT: [[TMP4:%.*]] = lshr i32 [[TMP1]], 16 +; CHECK-NEXT: [[TMP5:%.*]] = and i32 [[TMP4]], 255 +; CHECK-NEXT: [[TMP6:%.*]] = lshr i32 [[TMP1]], 8 +; CHECK-NEXT: [[TMP7:%.*]] = and i32 [[TMP6]], 255 ; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32> ; CHECK-NEXT: [[EXT_1:%.*]] = extractelement <4 x i32> [[EXT9]], i64 1 ; CHECK-NEXT: [[EXT_2:%.*]] = extractelement <4 x i32> [[EXT9]], i64 2 ; CHECK-NEXT: [[EXT_3:%.*]] = extractelement <4 x i32> [[EXT9]], i64 3 -; CHECK-NEXT: call void @use.i32(i32 [[EXT_1]]) -; CHECK-NEXT: call void @use.i32(i32 [[EXT_2]]) -; CHECK-NEXT: call void @use.i32(i32 [[EXT_3]]) +; CHECK-NEXT: call void @use.i32(i32 [[TMP7]]) +; CHECK-NEXT: call void @use.i32(i32 [[TMP5]]) +; CHECK-NEXT: call void 
@use.i32(i32 [[TMP3]]) ; CHECK-NEXT: ret void ; entry: @@ -93,13 +111,21 @@ define void @zext_v4i8_3_lanes_used_2(<4 x i8> %src) { ; CHECK-LABEL: define void @zext_v4i8_3_lanes_used_2( ; CHECK-SAME: <4 x i8> [[SRC:%.*]]) { ; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: [[TMP0:%.*]] = freeze <4 x i8> [[SRC]] +; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i8> [[TMP0]] to i32 +; CHECK-NEXT: [[TMP2:%.*]] = lshr i32 [[TMP1]], 24 +; CHECK-NEXT: [[TMP3:%.*]] = and i32 [[TMP2]], 255 +; CHECK-NEXT: [[TMP4:%.*]] = lshr i32 [[TMP1]], 8 +; CHECK-NEXT: [[TMP5:%.*]] = and i32 [[TMP4]], 255 +; CHECK-NEXT: [[TMP6:%.*]] = lshr i32 [[TMP1]], 0 +; CHECK-NEXT: [[TMP7:%.*]] = and i32 [[TMP6]], 255 ; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32> ; CHECK-NEXT: [[EXT_0:%.*]] = extractelement <4 x i32> [[EXT9]], i64 0 ; CHECK-NEXT: [[EXT_1:%.*]] = extractelement <4 x i32> [[EXT9]], i64 1 ; CHECK-NEXT: [[EXT_3:%.*]] = extractelement <4 x i32> [[EXT9]], i64 3 -; CHECK-NEXT: call void @use.i32(i32 [[EXT_0]]) -; CHECK-NEXT: call void @use.i32(i32 [[EXT_1]]) -; CHECK-NEXT: call void @use.i32(i32 [[EXT_3]]) +; CHECK-NEXT: call void @use.i32(i32 [[TMP7]]) +; CHECK-NEXT: call void @use.i32(i32 [[TMP5]]) +; CHECK-NEXT: call void @use.i32(i32 [[TMP3]]) ; CHECK-NEXT: ret void ; entry: @@ -118,11 +144,17 @@ define void @zext_v4i8_2_lanes_used_1(<4 x i8> %src) { ; CHECK-LABEL: define void @zext_v4i8_2_lanes_used_1( ; CHECK-SAME: <4 x i8> [[SRC:%.*]]) { ; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: [[TMP0:%.*]] = freeze <4 x i8> [[SRC]] +; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i8> [[TMP0]] to i32 +; CHECK-NEXT: [[TMP2:%.*]] = lshr i32 [[TMP1]], 16 +; CHECK-NEXT: [[TMP3:%.*]] = and i32 [[TMP2]], 255 +; CHECK-NEXT: [[TMP4:%.*]] = lshr i32 [[TMP1]], 8 +; CHECK-NEXT: [[TMP5:%.*]] = and i32 [[TMP4]], 255 ; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32> ; CHECK-NEXT: [[EXT_1:%.*]] = extractelement <4 x i32> [[EXT9]], i64 1 ; CHECK-NEXT: [[EXT_2:%.*]] = extractelement <4 x i32> [[EXT9]], i64 2 -; CHECK-NEXT: call void @use.i32(i32 [[EXT_1]]) -; CHECK-NEXT: call void @use.i32(i32 [[EXT_2]]) +; CHECK-NEXT: call void @use.i32(i32 [[TMP5]]) +; CHECK-NEXT: call void @use.i32(i32 [[TMP3]]) ; CHECK-NEXT: ret void ; entry: @@ -139,11 +171,17 @@ define void @zext_v4i8_2_lanes_used_2(<4 x i8> %src) { ; CHECK-LABEL: define void @zext_v4i8_2_lanes_used_2( ; CHECK-SAME: <4 x i8> [[SRC:%.*]]) { ; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: [[TMP0:%.*]] = freeze <4 x i8> [[SRC]] +; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i8> [[TMP0]] to i32 +; CHECK-NEXT: [[TMP2:%.*]] = lshr i32 [[TMP1]], 16 +; CHECK-NEXT: [[TMP3:%.*]] = and i32 [[TMP2]], 255 +; CHECK-NEXT: [[TMP4:%.*]] = lshr i32 [[TMP1]], 0 +; CHECK-NEXT: [[TMP5:%.*]] = and i32 [[TMP4]], 255 ; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32> ; CHECK-NEXT: [[EXT_0:%.*]] = extractelement <4 x i32> [[EXT9]], i64 0 ; CHECK-NEXT: [[EXT_2:%.*]] = extractelement <4 x i32> [[EXT9]], i64 2 -; CHECK-NEXT: call void @use.i32(i32 [[EXT_0]]) -; CHECK-NEXT: call void @use.i32(i32 [[EXT_2]]) +; CHECK-NEXT: call void @use.i32(i32 [[TMP5]]) +; CHECK-NEXT: call void @use.i32(i32 [[TMP3]]) ; CHECK-NEXT: ret void ; entry: @@ -160,15 +198,24 @@ define void @zext_v4i8_all_lanes_used_noundef(<4 x i8> noundef %src) { ; CHECK-LABEL: define void @zext_v4i8_all_lanes_used_noundef( ; CHECK-SAME: <4 x i8> noundef [[SRC:%.*]]) { ; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: [[TMP0:%.*]] = bitcast <4 x i8> [[SRC]] to i32 +; CHECK-NEXT: [[TMP1:%.*]] = lshr i32 [[TMP0]], 24 +; 
CHECK-NEXT: [[TMP2:%.*]] = and i32 [[TMP1]], 255 +; CHECK-NEXT: [[TMP3:%.*]] = lshr i32 [[TMP0]], 16 +; CHECK-NEXT: [[TMP4:%.*]] = and i32 [[TMP3]], 255 +; CHECK-NEXT: [[TMP5:%.*]] = lshr i32 [[TMP0]], 8 +; CHECK-NEXT: [[TMP6:%.*]] = and i32 [[TMP5]], 255 +; CHECK-NEXT: [[TMP7:%.*]] = lshr i32 [[TMP0]], 0 +; CHECK-NEXT: [[TMP8:%.*]] = and i32 [[TMP7]], 255 ; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32> ; CHECK-NEXT: [[EXT_0:%.*]] = extractelement <4 x i32> [[EXT9]], i64 0 ; CHECK-NEXT: [[EXT_1:%.*]] = extractelement <4 x i32> [[EXT9]], i64 1 ; CHECK-NEXT: [[EXT_2:%.*]] = extractelement <4 x i32> [[EXT9]], i64 2 ; CHECK-NEXT: [[EXT_3:%.*]] = extractelement <4 x i32> [[EXT9]], i64 3 -; CHECK-NEXT: call void @use.i32(i32 [[EXT_0]]) -; CHECK-NEXT: call void @use.i32(i32 [[EXT_1]]) -; CHECK-NEXT: call void @use.i32(i32 [[EXT_2]]) -; CHECK-NEXT: call void @use.i32(i32 [[EXT_3]]) +; CHECK-NEXT: call void @use.i32(i32 [[TMP8]]) +; CHECK-NEXT: call void @use.i32(i32 [[TMP6]]) +; CHECK-NEXT: call void @use.i32(i32 [[TMP4]]) +; CHECK-NEXT: call void @use.i32(i32 [[TMP2]]) ; CHECK-NEXT: ret void ; entry: @@ -221,15 +268,25 @@ define void @zext_v4i16_all_lanes_used(<4 x i16> %src) { ; CHECK-LABEL: define void @zext_v4i16_all_lanes_used( ; CHECK-SAME: <4 x i16> [[SRC:%.*]]) { ; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: [[TMP0:%.*]] = freeze <4 x i16> [[SRC]] +; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i16> [[TMP0]] to i64 +; CHECK-NEXT: [[TMP2:%.*]] = lshr i64 [[TMP1]], 48 +; CHECK-NEXT: [[TMP3:%.*]] = and i64 [[TMP2]], 65535 +; CHECK-NEXT: [[TMP4:%.*]] = lshr i64 [[TMP1]], 32 +; CHECK-NEXT: [[TMP5:%.*]] = and i64 [[TMP4]], 65535 +; CHECK-NEXT: [[TMP6:%.*]] = lshr i64 [[TMP1]], 16 +; CHECK-NEXT: [[TMP7:%.*]] = and i64 [[TMP6]], 65535 +; CHECK-NEXT: [[TMP8:%.*]] = lshr i64 [[TMP1]], 0 +; CHECK-NEXT: [[TMP9:%.*]] = and i64 [[TMP8]], 65535 ; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <4 x i16> [[SRC]] to <4 x i64> ; CHECK-NEXT: [[EXT_0:%.*]] = extractelement <4 x i64> [[EXT9]], i64 0 ; CHECK-NEXT: [[EXT_1:%.*]] = extractelement <4 x i64> [[EXT9]], i64 1 ; CHECK-NEXT: [[EXT_2:%.*]] = extractelement <4 x i64> [[EXT9]], i64 2 ; CHECK-NEXT: [[EXT_3:%.*]] = extractelement <4 x i64> [[EXT9]], i64 3 -; CHECK-NEXT: call void @use.i64(i64 [[EXT_0]]) -; CHECK-NEXT: call void @use.i64(i64 [[EXT_1]]) -; CHECK-NEXT: call void @use.i64(i64 [[EXT_2]]) -; CHECK-NEXT: call void @use.i64(i64 [[EXT_3]]) +; CHECK-NEXT: call void @use.i64(i64 [[TMP9]]) +; CHECK-NEXT: call void @use.i64(i64 [[TMP7]]) +; CHECK-NEXT: call void @use.i64(i64 [[TMP5]]) +; CHECK-NEXT: call void @use.i64(i64 [[TMP3]]) ; CHECK-NEXT: ret void ; entry: @@ -250,11 +307,17 @@ define void @zext_v2i32_all_lanes_used(<2 x i32> %src) { ; CHECK-LABEL: define void @zext_v2i32_all_lanes_used( ; CHECK-SAME: <2 x i32> [[SRC:%.*]]) { ; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: [[TMP0:%.*]] = freeze <2 x i32> [[SRC]] +; CHECK-NEXT: [[TMP1:%.*]] = bitcast <2 x i32> [[TMP0]] to i64 +; CHECK-NEXT: [[TMP2:%.*]] = lshr i64 [[TMP1]], 32 +; CHECK-NEXT: [[TMP3:%.*]] = and i64 [[TMP2]], 4294967295 +; CHECK-NEXT: [[TMP4:%.*]] = lshr i64 [[TMP1]], 0 +; CHECK-NEXT: [[TMP5:%.*]] = and i64 [[TMP4]], 4294967295 ; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <2 x i32> [[SRC]] to <2 x i64> ; CHECK-NEXT: [[EXT_0:%.*]] = extractelement <2 x i64> [[EXT9]], i64 0 ; CHECK-NEXT: [[EXT_1:%.*]] = extractelement <2 x i64> [[EXT9]], i64 1 -; CHECK-NEXT: call void @use.i64(i64 [[EXT_0]]) -; CHECK-NEXT: call void @use.i64(i64 [[EXT_1]]) +; CHECK-NEXT: call void @use.i64(i64 [[TMP5]]) +; 
CHECK-NEXT: call void @use.i64(i64 [[TMP3]]) ; CHECK-NEXT: ret void ; entry: From a50d573346fab906689be715d60de186391cc043 Mon Sep 17 00:00:00 2001 From: Florian Hahn Date: Thu, 5 Jun 2025 19:02:26 +0100 Subject: [PATCH 2/4] !fixup address comments, thanks! --- .../Transforms/Vectorize/VectorCombine.cpp | 15 +++++----- .../VectorCombine/AArch64/ext-extract.ll | 29 +++++++++++++++++++ 2 files changed, 37 insertions(+), 7 deletions(-) diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp index a7689192fbcbf..5162993bbd867 100644 --- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp +++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp @@ -1782,15 +1782,18 @@ bool VectorCombine::scalarizeExtExtract(Instruction &I) { // Try to convert a vector zext feeding only extracts to a set of scalar (Src // << ExtIdx *Size) & (Size -1), if profitable. auto *Ext = cast(&I); - auto *SrcTy = cast(Ext->getOperand(0)->getType()); + auto *SrcTy = dyn_cast(Ext->getOperand(0)->getType()); + if (!SrcTy) + return false; auto *DstTy = cast(Ext->getType()); - if (DL->getTypeSizeInBits(SrcTy) != - DL->getTypeSizeInBits(DstTy->getElementType())) + Type *ScalarDstTy = DstTy->getElementType(); + if (DL->getTypeSizeInBits(SrcTy) != DL->getTypeSizeInBits(ScalarDstTy)) return false; - InstructionCost VectorCost = TTI.getCastInstrCost( - Instruction::ZExt, DstTy, SrcTy, TTI::CastContextHint::None, CostKind); + InstructionCost VectorCost = + TTI.getCastInstrCost(Instruction::ZExt, DstTy, SrcTy, + TTI::CastContextHint::None, CostKind, Ext); unsigned ExtCnt = 0; bool ExtLane0 = false; for (User *U : Ext->users()) { @@ -1805,7 +1808,6 @@ bool VectorCombine::scalarizeExtExtract(Instruction &I) { CostKind, Idx->getZExtValue(), U); } - Type *ScalarDstTy = DstTy->getElementType(); InstructionCost ScalarCost = ExtCnt * TTI.getArithmeticInstrCost( Instruction::And, ScalarDstTy, CostKind, @@ -1813,7 +1815,6 @@ bool VectorCombine::scalarizeExtExtract(Instruction &I) { {TTI::OK_NonUniformConstantValue, TTI::OP_None}) + (ExtCnt - ExtLane0) * TTI.getArithmeticInstrCost( - Instruction::LShr, ScalarDstTy, CostKind, {TTI::OK_AnyValue, TTI::OP_None}, {TTI::OK_NonUniformConstantValue, TTI::OP_None}); diff --git a/llvm/test/Transforms/VectorCombine/AArch64/ext-extract.ll b/llvm/test/Transforms/VectorCombine/AArch64/ext-extract.ll index 23538589ae32c..97e250542f96d 100644 --- a/llvm/test/Transforms/VectorCombine/AArch64/ext-extract.ll +++ b/llvm/test/Transforms/VectorCombine/AArch64/ext-extract.ll @@ -329,3 +329,32 @@ entry: call void @use.i64(i64 %ext.1) ret void } + +define void @zext_nv4i8_all_lanes_used( %src) { +; CHECK-LABEL: define void @zext_nv4i8_all_lanes_used( +; CHECK-SAME: [[SRC:%.*]]) { +; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: [[EXT9:%.*]] = zext nneg [[SRC]] to +; CHECK-NEXT: [[EXT_0:%.*]] = extractelement [[EXT9]], i64 0 +; CHECK-NEXT: [[EXT_1:%.*]] = extractelement [[EXT9]], i64 1 +; CHECK-NEXT: [[EXT_2:%.*]] = extractelement [[EXT9]], i64 2 +; CHECK-NEXT: [[EXT_3:%.*]] = extractelement [[EXT9]], i64 3 +; CHECK-NEXT: call void @use.i32(i32 [[EXT_0]]) +; CHECK-NEXT: call void @use.i32(i32 [[EXT_1]]) +; CHECK-NEXT: call void @use.i32(i32 [[EXT_2]]) +; CHECK-NEXT: call void @use.i32(i32 [[EXT_3]]) +; CHECK-NEXT: ret void +; +entry: + %ext9 = zext nneg %src to + %ext.0 = extractelement %ext9, i64 0 + %ext.1 = extractelement %ext9, i64 1 + %ext.2 = extractelement %ext9, i64 2 + %ext.3 = extractelement %ext9, i64 3 + + call void @use.i32(i32 %ext.0) + call void @use.i32(i32 
%ext.1) + call void @use.i32(i32 %ext.2) + call void @use.i32(i32 %ext.3) + ret void +} From eac3396018833c7f04c83064af7c5c779ab6a290 Mon Sep 17 00:00:00 2001 From: Florian Hahn Date: Mon, 30 Jun 2025 14:30:19 +0100 Subject: [PATCH 3/4] !fixup address comments, thanks --- .../Transforms/Vectorize/VectorCombine.cpp | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp index 5162993bbd867..1f009e905a7ce 100644 --- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp +++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp @@ -1776,12 +1776,13 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) { } bool VectorCombine::scalarizeExtExtract(Instruction &I) { - if (!match(&I, m_ZExt(m_Value()))) + auto *Ext = dyn_cast(&I); + if (!Ext) return false; - // Try to convert a vector zext feeding only extracts to a set of scalar (Src - // << ExtIdx *Size) & (Size -1), if profitable. - auto *Ext = cast(&I); + // Try to convert a vector zext feeding only extracts to a set of scalar + // (Src << ExtIdx *Size) & (Size -1) + // if profitable . auto *SrcTy = dyn_cast(Ext->getOperand(0)->getType()); if (!SrcTy) return false; @@ -1822,21 +1823,20 @@ bool VectorCombine::scalarizeExtExtract(Instruction &I) { return false; Value *ScalarV = Ext->getOperand(0); - if (!isGuaranteedNotToBePoison(ScalarV, &AC)) + if (!isGuaranteedNotToBePoison(ScalarV, &AC, dyn_cast(ScalarV), + &DT)) ScalarV = Builder.CreateFreeze(ScalarV); ScalarV = Builder.CreateBitCast( ScalarV, IntegerType::get(SrcTy->getContext(), DL->getTypeSizeInBits(SrcTy))); unsigned SrcEltSizeInBits = DL->getTypeSizeInBits(SrcTy->getElementType()); - Value *EltBitMask = - ConstantInt::get(ScalarV->getType(), (1ull << SrcEltSizeInBits) - 1); - for (auto *U : to_vector(Ext->users())) { + unsigned EltBitMask = (1ull << SrcEltSizeInBits) - 1; + for (User *U : Ext->users()) { auto *Extract = cast(U); - unsigned Idx = + uint64_t Idx = cast(Extract->getIndexOperand())->getZExtValue(); - auto *S = Builder.CreateLShr( - ScalarV, ConstantInt::get(ScalarV->getType(), Idx * SrcEltSizeInBits)); - auto *A = Builder.CreateAnd(S, EltBitMask); + Value *S = Builder.CreateLShr(ScalarV, Idx * SrcEltSizeInBits); + Value *A = Builder.CreateAnd(S, EltBitMask); U->replaceAllUsesWith(A); } return true; From f4c82b916b24515cb4cdff7d21626bd9b81e9166 Mon Sep 17 00:00:00 2001 From: Florian Hahn Date: Wed, 2 Jul 2025 19:14:24 +0100 Subject: [PATCH 4/4] !fixup adjust types and naming, thanks --- .../Transforms/Vectorize/VectorCombine.cpp | 10 ++--- .../VectorCombine/AArch64/ext-extract.ll | 40 +++++++------------ 2 files changed, 19 insertions(+), 31 deletions(-) diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp index 1f009e905a7ce..55ee33a178532 100644 --- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp +++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp @@ -1829,15 +1829,15 @@ bool VectorCombine::scalarizeExtExtract(Instruction &I) { ScalarV = Builder.CreateBitCast( ScalarV, IntegerType::get(SrcTy->getContext(), DL->getTypeSizeInBits(SrcTy))); - unsigned SrcEltSizeInBits = DL->getTypeSizeInBits(SrcTy->getElementType()); - unsigned EltBitMask = (1ull << SrcEltSizeInBits) - 1; + uint64_t SrcEltSizeInBits = DL->getTypeSizeInBits(SrcTy->getElementType()); + uint64_t EltBitMask = (1ull << SrcEltSizeInBits) - 1; for (User *U : Ext->users()) { auto *Extract = cast(U); uint64_t Idx = 
cast(Extract->getIndexOperand())->getZExtValue(); - Value *S = Builder.CreateLShr(ScalarV, Idx * SrcEltSizeInBits); - Value *A = Builder.CreateAnd(S, EltBitMask); - U->replaceAllUsesWith(A); + Value *LShr = Builder.CreateLShr(ScalarV, Idx * SrcEltSizeInBits); + Value *And = Builder.CreateAnd(LShr, EltBitMask); + U->replaceAllUsesWith(And); } return true; } diff --git a/llvm/test/Transforms/VectorCombine/AArch64/ext-extract.ll b/llvm/test/Transforms/VectorCombine/AArch64/ext-extract.ll index 97e250542f96d..60700412686ea 100644 --- a/llvm/test/Transforms/VectorCombine/AArch64/ext-extract.ll +++ b/llvm/test/Transforms/VectorCombine/AArch64/ext-extract.ll @@ -12,13 +12,11 @@ define void @zext_v4i8_all_lanes_used(<4 x i8> %src) { ; CHECK-NEXT: [[TMP0:%.*]] = freeze <4 x i8> [[SRC]] ; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i8> [[TMP0]] to i32 ; CHECK-NEXT: [[TMP2:%.*]] = lshr i32 [[TMP1]], 24 -; CHECK-NEXT: [[TMP3:%.*]] = and i32 [[TMP2]], 255 ; CHECK-NEXT: [[TMP4:%.*]] = lshr i32 [[TMP1]], 16 ; CHECK-NEXT: [[TMP5:%.*]] = and i32 [[TMP4]], 255 ; CHECK-NEXT: [[TMP6:%.*]] = lshr i32 [[TMP1]], 8 ; CHECK-NEXT: [[TMP7:%.*]] = and i32 [[TMP6]], 255 -; CHECK-NEXT: [[TMP8:%.*]] = lshr i32 [[TMP1]], 0 -; CHECK-NEXT: [[TMP9:%.*]] = and i32 [[TMP8]], 255 +; CHECK-NEXT: [[TMP9:%.*]] = and i32 [[TMP1]], 255 ; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32> ; CHECK-NEXT: [[EXT_0:%.*]] = extractelement <4 x i32> [[EXT9]], i64 0 ; CHECK-NEXT: [[EXT_1:%.*]] = extractelement <4 x i32> [[EXT9]], i64 1 @@ -27,7 +25,7 @@ define void @zext_v4i8_all_lanes_used(<4 x i8> %src) { ; CHECK-NEXT: call void @use.i32(i32 [[TMP9]]) ; CHECK-NEXT: call void @use.i32(i32 [[TMP7]]) ; CHECK-NEXT: call void @use.i32(i32 [[TMP5]]) -; CHECK-NEXT: call void @use.i32(i32 [[TMP3]]) +; CHECK-NEXT: call void @use.i32(i32 [[TMP2]]) ; CHECK-NEXT: ret void ; entry: @@ -81,7 +79,6 @@ define void @zext_v4i8_3_lanes_used_1(<4 x i8> %src) { ; CHECK-NEXT: [[TMP0:%.*]] = freeze <4 x i8> [[SRC]] ; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i8> [[TMP0]] to i32 ; CHECK-NEXT: [[TMP2:%.*]] = lshr i32 [[TMP1]], 24 -; CHECK-NEXT: [[TMP3:%.*]] = and i32 [[TMP2]], 255 ; CHECK-NEXT: [[TMP4:%.*]] = lshr i32 [[TMP1]], 16 ; CHECK-NEXT: [[TMP5:%.*]] = and i32 [[TMP4]], 255 ; CHECK-NEXT: [[TMP6:%.*]] = lshr i32 [[TMP1]], 8 @@ -92,7 +89,7 @@ define void @zext_v4i8_3_lanes_used_1(<4 x i8> %src) { ; CHECK-NEXT: [[EXT_3:%.*]] = extractelement <4 x i32> [[EXT9]], i64 3 ; CHECK-NEXT: call void @use.i32(i32 [[TMP7]]) ; CHECK-NEXT: call void @use.i32(i32 [[TMP5]]) -; CHECK-NEXT: call void @use.i32(i32 [[TMP3]]) +; CHECK-NEXT: call void @use.i32(i32 [[TMP2]]) ; CHECK-NEXT: ret void ; entry: @@ -114,18 +111,16 @@ define void @zext_v4i8_3_lanes_used_2(<4 x i8> %src) { ; CHECK-NEXT: [[TMP0:%.*]] = freeze <4 x i8> [[SRC]] ; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i8> [[TMP0]] to i32 ; CHECK-NEXT: [[TMP2:%.*]] = lshr i32 [[TMP1]], 24 -; CHECK-NEXT: [[TMP3:%.*]] = and i32 [[TMP2]], 255 ; CHECK-NEXT: [[TMP4:%.*]] = lshr i32 [[TMP1]], 8 ; CHECK-NEXT: [[TMP5:%.*]] = and i32 [[TMP4]], 255 -; CHECK-NEXT: [[TMP6:%.*]] = lshr i32 [[TMP1]], 0 -; CHECK-NEXT: [[TMP7:%.*]] = and i32 [[TMP6]], 255 +; CHECK-NEXT: [[TMP7:%.*]] = and i32 [[TMP1]], 255 ; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32> ; CHECK-NEXT: [[EXT_0:%.*]] = extractelement <4 x i32> [[EXT9]], i64 0 ; CHECK-NEXT: [[EXT_1:%.*]] = extractelement <4 x i32> [[EXT9]], i64 1 ; CHECK-NEXT: [[EXT_3:%.*]] = extractelement <4 x i32> [[EXT9]], i64 3 ; CHECK-NEXT: call void @use.i32(i32 [[TMP7]]) ; 
CHECK-NEXT: call void @use.i32(i32 [[TMP5]]) -; CHECK-NEXT: call void @use.i32(i32 [[TMP3]]) +; CHECK-NEXT: call void @use.i32(i32 [[TMP2]]) ; CHECK-NEXT: ret void ; entry: @@ -175,8 +170,7 @@ define void @zext_v4i8_2_lanes_used_2(<4 x i8> %src) { ; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i8> [[TMP0]] to i32 ; CHECK-NEXT: [[TMP2:%.*]] = lshr i32 [[TMP1]], 16 ; CHECK-NEXT: [[TMP3:%.*]] = and i32 [[TMP2]], 255 -; CHECK-NEXT: [[TMP4:%.*]] = lshr i32 [[TMP1]], 0 -; CHECK-NEXT: [[TMP5:%.*]] = and i32 [[TMP4]], 255 +; CHECK-NEXT: [[TMP5:%.*]] = and i32 [[TMP1]], 255 ; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32> ; CHECK-NEXT: [[EXT_0:%.*]] = extractelement <4 x i32> [[EXT9]], i64 0 ; CHECK-NEXT: [[EXT_2:%.*]] = extractelement <4 x i32> [[EXT9]], i64 2 @@ -200,13 +194,11 @@ define void @zext_v4i8_all_lanes_used_noundef(<4 x i8> noundef %src) { ; CHECK-NEXT: [[ENTRY:.*:]] ; CHECK-NEXT: [[TMP0:%.*]] = bitcast <4 x i8> [[SRC]] to i32 ; CHECK-NEXT: [[TMP1:%.*]] = lshr i32 [[TMP0]], 24 -; CHECK-NEXT: [[TMP2:%.*]] = and i32 [[TMP1]], 255 ; CHECK-NEXT: [[TMP3:%.*]] = lshr i32 [[TMP0]], 16 ; CHECK-NEXT: [[TMP4:%.*]] = and i32 [[TMP3]], 255 ; CHECK-NEXT: [[TMP5:%.*]] = lshr i32 [[TMP0]], 8 ; CHECK-NEXT: [[TMP6:%.*]] = and i32 [[TMP5]], 255 -; CHECK-NEXT: [[TMP7:%.*]] = lshr i32 [[TMP0]], 0 -; CHECK-NEXT: [[TMP8:%.*]] = and i32 [[TMP7]], 255 +; CHECK-NEXT: [[TMP8:%.*]] = and i32 [[TMP0]], 255 ; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32> ; CHECK-NEXT: [[EXT_0:%.*]] = extractelement <4 x i32> [[EXT9]], i64 0 ; CHECK-NEXT: [[EXT_1:%.*]] = extractelement <4 x i32> [[EXT9]], i64 1 @@ -215,7 +207,7 @@ define void @zext_v4i8_all_lanes_used_noundef(<4 x i8> noundef %src) { ; CHECK-NEXT: call void @use.i32(i32 [[TMP8]]) ; CHECK-NEXT: call void @use.i32(i32 [[TMP6]]) ; CHECK-NEXT: call void @use.i32(i32 [[TMP4]]) -; CHECK-NEXT: call void @use.i32(i32 [[TMP2]]) +; CHECK-NEXT: call void @use.i32(i32 [[TMP1]]) ; CHECK-NEXT: ret void ; entry: @@ -271,13 +263,11 @@ define void @zext_v4i16_all_lanes_used(<4 x i16> %src) { ; CHECK-NEXT: [[TMP0:%.*]] = freeze <4 x i16> [[SRC]] ; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i16> [[TMP0]] to i64 ; CHECK-NEXT: [[TMP2:%.*]] = lshr i64 [[TMP1]], 48 -; CHECK-NEXT: [[TMP3:%.*]] = and i64 [[TMP2]], 65535 ; CHECK-NEXT: [[TMP4:%.*]] = lshr i64 [[TMP1]], 32 ; CHECK-NEXT: [[TMP5:%.*]] = and i64 [[TMP4]], 65535 ; CHECK-NEXT: [[TMP6:%.*]] = lshr i64 [[TMP1]], 16 ; CHECK-NEXT: [[TMP7:%.*]] = and i64 [[TMP6]], 65535 -; CHECK-NEXT: [[TMP8:%.*]] = lshr i64 [[TMP1]], 0 -; CHECK-NEXT: [[TMP9:%.*]] = and i64 [[TMP8]], 65535 +; CHECK-NEXT: [[TMP9:%.*]] = and i64 [[TMP1]], 65535 ; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <4 x i16> [[SRC]] to <4 x i64> ; CHECK-NEXT: [[EXT_0:%.*]] = extractelement <4 x i64> [[EXT9]], i64 0 ; CHECK-NEXT: [[EXT_1:%.*]] = extractelement <4 x i64> [[EXT9]], i64 1 @@ -286,7 +276,7 @@ define void @zext_v4i16_all_lanes_used(<4 x i16> %src) { ; CHECK-NEXT: call void @use.i64(i64 [[TMP9]]) ; CHECK-NEXT: call void @use.i64(i64 [[TMP7]]) ; CHECK-NEXT: call void @use.i64(i64 [[TMP5]]) -; CHECK-NEXT: call void @use.i64(i64 [[TMP3]]) +; CHECK-NEXT: call void @use.i64(i64 [[TMP2]]) ; CHECK-NEXT: ret void ; entry: @@ -310,14 +300,12 @@ define void @zext_v2i32_all_lanes_used(<2 x i32> %src) { ; CHECK-NEXT: [[TMP0:%.*]] = freeze <2 x i32> [[SRC]] ; CHECK-NEXT: [[TMP1:%.*]] = bitcast <2 x i32> [[TMP0]] to i64 ; CHECK-NEXT: [[TMP2:%.*]] = lshr i64 [[TMP1]], 32 -; CHECK-NEXT: [[TMP3:%.*]] = and i64 [[TMP2]], 4294967295 -; CHECK-NEXT: [[TMP4:%.*]] = 
lshr i64 [[TMP1]], 0 -; CHECK-NEXT: [[TMP5:%.*]] = and i64 [[TMP4]], 4294967295 +; CHECK-NEXT: [[TMP5:%.*]] = and i64 [[TMP1]], 4294967295 ; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <2 x i32> [[SRC]] to <2 x i64> ; CHECK-NEXT: [[EXT_0:%.*]] = extractelement <2 x i64> [[EXT9]], i64 0 ; CHECK-NEXT: [[EXT_1:%.*]] = extractelement <2 x i64> [[EXT9]], i64 1 ; CHECK-NEXT: call void @use.i64(i64 [[TMP5]]) -; CHECK-NEXT: call void @use.i64(i64 [[TMP3]]) +; CHECK-NEXT: call void @use.i64(i64 [[TMP2]]) ; CHECK-NEXT: ret void ; entry: @@ -330,8 +318,8 @@ entry: ret void } -define void @zext_nv4i8_all_lanes_used( %src) { -; CHECK-LABEL: define void @zext_nv4i8_all_lanes_used( +define void @zext_nxv4i8_all_lanes_used( %src) { +; CHECK-LABEL: define void @zext_nxv4i8_all_lanes_used( ; CHECK-SAME: [[SRC:%.*]]) { ; CHECK-NEXT: [[ENTRY:.*:]] ; CHECK-NEXT: [[EXT9:%.*]] = zext nneg [[SRC]] to
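
As a quick illustration of the rewrite scalarizeExtExtract performs (a minimal sketch mirroring the <4 x i8> -> <4 x i32> cases in the tests above; the value names are illustrative and not taken from the patch), extracting lane 1 of a vector zext becomes a shift and mask of the bitcast source:

  ; before: lane 1 is extracted from a vector zext
  %ext = zext <4 x i8> %src to <4 x i32>
  %ext.1 = extractelement <4 x i32> %ext, i64 1

  ; after: scalar shift + mask on the bitcast source
  ; (%src is frozen first unless it is known not to be poison)
  %cast = bitcast <4 x i8> %src to i32
  %shr = lshr i32 %cast, 8
  %ext.1 = and i32 %shr, 255

The cost model accepts the rewrite when the extracts it removes cost at least as much as the lshr/and pairs that replace them; lane 0 needs no shift, which is why ExtLane0 is tracked separately.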