[mlir][Vector] Remove trivial uses of vector.extractelement/vector.insertelement (1/N) #116053

Merged: 2 commits, Nov 13, 2024
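
The change pattern repeated throughout this PR: a vector.extractelement / vector.insertelement whose index is a known constant (or whose source is a 0-D vector) becomes a vector.extract / vector.insert with a static position, so the arith.constant feeding the index disappears. A minimal sketch with illustrative shapes, not taken from any single test:

  // Constant-index extract: before and after.
  %c0 = arith.constant 0 : index
  %e0 = vector.extractelement %v[%c0 : index] : vector<4xf16>
  %e1 = vector.extract %v[0] : f16 from vector<4xf16>

  // Constant-index insert: before and after.
  %c1 = arith.constant 1 : index
  %r0 = vector.insertelement %s, %acc[%c1 : index] : vector<4xf32>
  %r1 = vector.insert %s, %acc[1] : f32 into vector<4xf32>

  // 0-D vectors use an empty position list.
  %x0 = vector.extractelement %w[] : vector<f32>
  %x1 = vector.extract %w[] : f32 from vector<f32>
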
9 changes: 3 additions & 6 deletions mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -313,8 +313,7 @@ void TruncfToFloat16RewritePattern::rewrite(arith::TruncFOp op,
auto sourceB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
Value asF16s =
rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, in, sourceB);
Value result = rewriter.create<vector::ExtractElementOp>(
loc, asF16s, rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0));
Value result = rewriter.create<vector::ExtractOp>(loc, asF16s, 0);
return rewriter.replaceOp(op, result);
}
VectorType outType = cast<VectorType>(op.getOut().getType());
@@ -334,13 +333,11 @@ void TruncfToFloat16RewritePattern::rewrite(arith::TruncFOp op,
for (int64_t i = 0; i < numElements; i += 2) {
int64_t elemsThisOp = std::min(numElements, i + 2) - i;
Value thisResult = nullptr;
Value elemA = rewriter.create<vector::ExtractElementOp>(
loc, in, rewriter.create<arith::ConstantIndexOp>(loc, i));
Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i);
Value elemB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());

if (elemsThisOp == 2) {
elemB = rewriter.create<vector::ExtractElementOp>(
loc, in, rewriter.createOrFold<arith::ConstantIndexOp>(loc, i + 1));
elemB = rewriter.create<vector::ExtractOp>(loc, in, i + 1);
}

thisResult =
10 changes: 5 additions & 5 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1134,8 +1134,6 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
// * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
// (0th) element and use that.
SmallVector<Value> transferReadIdxs;
auto zero = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI32Type(), rewriter.getZeroAttr(rewriter.getI32Type()));
for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
Value idx = bvm.lookup(extractOp.getIndices()[i]);
if (idx.getType().isIndex()) {
@@ -1149,7 +1147,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
resultType.getScalableDims().back()),
idx);
transferReadIdxs.push_back(
rewriter.create<vector::ExtractElementOp>(loc, indexAs1dVector, zero));
rewriter.create<vector::ExtractOp>(loc, indexAs1dVector, 0));
}

// `tensor.extract_element` is always in-bounds, hence the following holds.
@@ -1415,7 +1413,8 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
// 3.c. Not all ops support 0-d vectors, extract the scalar for now.
// TODO: remove this.
if (readType.getRank() == 0)
readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
readValue = rewriter.create<vector::ExtractOp>(loc, readValue,
ArrayRef<int64_t>());

LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue
<< "\n");
@@ -2273,7 +2272,8 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
loc, readType, copyOp.getSource(), indices,
rewriter.getMultiDimIdentityMap(srcType.getRank()));
if (cast<VectorType>(readValue.getType()).getRank() == 0) {
readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
readValue =
rewriter.create<vector::ExtractOp>(loc, readValue, ArrayRef<int64_t>());
readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
}
Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
@@ -391,9 +391,8 @@ struct TwoDimMultiReductionToReduction
reductionOp = mlir::vector::maskOperation(rewriter, reductionOp, mask);
}

result = rewriter.create<vector::InsertElementOp>(
loc, reductionOp->getResult(0), result,
rewriter.create<arith::ConstantIndexOp>(loc, i));
result = rewriter.create<vector::InsertOp>(loc, reductionOp->getResult(0),
result, i);
}

rewriter.replaceOp(rootOp, result);
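
In the two-dimensional multi-reduction lowering above, each row's reduction result now lands in the accumulator via vector.insert at a static row index, where it previously went through vector.insertelement with a materialized arith.constant. A rough sketch of the emitted IR for a vector<2x4xf32> reduction over the inner dimension, with invented names and the masking and accumulator handling omitted:

  %row0 = vector.extract %src[0] : vector<4xf32> from vector<2x4xf32>
  %red0 = vector.reduction <add>, %row0 : vector<4xf32> into f32
  %res0 = vector.insert %red0, %init[0] : f32 into vector<2xf32>
  %row1 = vector.extract %src[1] : vector<4xf32> from vector<2x4xf32>
  %red1 = vector.reduction <add>, %row1 : vector<4xf32> into f32
  %res  = vector.insert %red1, %res0[1] : f32 into vector<2xf32>
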
@@ -18,28 +18,6 @@
using namespace mlir;
using namespace mlir::vector;

// Helper that picks the proper sequence for inserting.
static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
Value into, int64_t offset) {
auto vectorType = cast<VectorType>(into.getType());
if (vectorType.getRank() > 1)
return rewriter.create<InsertOp>(loc, from, into, offset);
return rewriter.create<vector::InsertElementOp>(
loc, vectorType, from, into,
rewriter.create<arith::ConstantIndexOp>(loc, offset));
}

// Helper that picks the proper sequence for extracting.
static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
int64_t offset) {
auto vectorType = cast<VectorType>(vector.getType());
if (vectorType.getRank() > 1)
return rewriter.create<ExtractOp>(loc, vector, offset);
return rewriter.create<vector::ExtractElementOp>(
loc, vectorType.getElementType(), vector,
rewriter.create<arith::ConstantIndexOp>(loc, offset));
}

/// RewritePattern for InsertStridedSliceOp where source and destination vectors
/// have different ranks.
///
@@ -173,11 +151,13 @@ class ConvertSameRankInsertStridedSliceIntoShuffle
for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
off += stride, ++idx) {
// 1. extract the proper subvector (or element) from source
Value extractedSource = extractOne(rewriter, loc, op.getSource(), idx);
Value extractedSource =
rewriter.create<ExtractOp>(loc, op.getSource(), idx);
if (isa<VectorType>(extractedSource.getType())) {
// 2. If we have a vector, extract the proper subvector from destination
// Otherwise we are at the element level and no need to recurse.
Value extractedDest = extractOne(rewriter, loc, op.getDest(), off);
Value extractedDest =
rewriter.create<ExtractOp>(loc, op.getDest(), off);
// 3. Reduce the problem to lowering a new InsertStridedSlice op with
// smaller rank.
extractedSource = rewriter.create<InsertStridedSliceOp>(
@@ -186,7 +166,7 @@ class ConvertSameRankInsertStridedSliceIntoShuffle
getI64SubArray(op.getStrides(), /* dropFront=*/1));
}
// 4. Insert the extractedSource into the res vector.
res = insertOne(rewriter, loc, extractedSource, res, off);
res = rewriter.create<InsertOp>(loc, extractedSource, res, off);
}

rewriter.replaceOp(op, res);
@@ -277,8 +257,8 @@ class Convert1DExtractStridedSliceIntoExtractInsertChain final
};

/// RewritePattern for ExtractStridedSliceOp where the source vector is n-D.
/// For such cases, we can rewrite it to ExtractOp/ExtractElementOp + lower
/// rank ExtractStridedSliceOp + InsertOp/InsertElementOp for the n-D case.
/// For such cases, we can rewrite it to ExtractOp + lower rank
/// ExtractStridedSliceOp + InsertOp for the n-D case.
class DecomposeNDExtractStridedSlice
: public OpRewritePattern<ExtractStridedSliceOp> {
public:
@@ -317,12 +297,12 @@ class DecomposeNDExtractStridedSlice
Value res = rewriter.create<SplatOp>(loc, dstType, zero);
for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
off += stride, ++idx) {
Value one = extractOne(rewriter, loc, op.getVector(), off);
Value one = rewriter.create<ExtractOp>(loc, op.getVector(), off);
Value extracted = rewriter.create<ExtractStridedSliceOp>(
loc, one, getI64SubArray(op.getOffsets(), /* dropFront=*/1),
getI64SubArray(op.getSizes(), /* dropFront=*/1),
getI64SubArray(op.getStrides(), /* dropFront=*/1));
res = insertOne(rewriter, loc, extracted, res, idx);
res = rewriter.create<InsertOp>(loc, extracted, res, idx);
}
rewriter.replaceOp(op, res);
return success();
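
With the insertOne/extractOne helpers removed, DecomposeNDExtractStridedSlice emits vector.extract and vector.insert directly. A hand-written sketch of one decomposition step, using made-up shapes and attributes rather than anything from the tests below: extracting offsets [1, 0], sizes [2, 2], strides [1, 1] from a vector<4x4xf32> becomes, after one level of the rewrite,

  %cst  = arith.constant 0.000000e+00 : f32
  %init = vector.splat %cst : vector<2x2xf32>
  // Row at offset 1.
  %r0 = vector.extract %v[1] : vector<4xf32> from vector<4x4xf32>
  %s0 = vector.extract_strided_slice %r0 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
  %a0 = vector.insert %s0, %init[0] : vector<2xf32> into vector<2x2xf32>
  // Row at offset 2.
  %r1 = vector.extract %v[2] : vector<4xf32> from vector<4x4xf32>
  %s1 = vector.extract_strided_slice %r1 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
  %res = vector.insert %s1, %a0[1] : vector<2xf32> into vector<2x2xf32>

The remaining 1-D extract_strided_slice ops are then picked up by the 1-D patterns earlier in this file.
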
24 changes: 12 additions & 12 deletions mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir
@@ -5,7 +5,7 @@
func.func @scalar_trunc(%v: f32) -> f16{
// CHECK: %[[poison:.*]] = llvm.mlir.poison : f32
// CHECK: %[[trunc:.*]] = rocdl.cvt.pkrtz %[[value]], %[[poison]] : vector<2xf16>
// CHECK: %[[extract:.*]] = vector.extractelement %[[trunc]][%c0 : index] : vector<2xf16>
// CHECK: %[[extract:.*]] = vector.extract %[[trunc]][0] : f16 from vector<2xf16>
// CHECK: return %[[extract]] : f16
%w = arith.truncf %v : f32 to f16
return %w : f16
@@ -14,8 +14,8 @@ func.func @scalar_trunc(%v: f32) -> f16{
// CHECK-LABEL: @vector_trunc
// CHECK-SAME: (%[[value:.*]]: vector<2xf32>)
func.func @vector_trunc_short(%v: vector<2xf32>) -> vector<2xf16> {
// CHECK: %[[elem0:.*]] = vector.extractelement %[[value]]
// CHECK: %[[elem1:.*]] = vector.extractelement %[[value]]
// CHECK: %[[elem0:.*]] = vector.extract %[[value]]
// CHECK: %[[elem1:.*]] = vector.extract %[[value]]
// CHECK: %[[ret:.*]] = rocdl.cvt.pkrtz %[[elem0]], %[[elem1]] : vector<2xf16>
// CHECK: return %[[ret]]
%w = arith.truncf %v : vector<2xf32> to vector<2xf16>
@@ -25,23 +25,23 @@ func.func @vector_trunc_short(%v: vector<2xf32>) -> vector<2xf16> {
// CHECK-LABEL: @vector_trunc_long
// CHECK-SAME: (%[[value:.*]]: vector<9xf32>)
func.func @vector_trunc_long(%v: vector<9xf32>) -> vector<9xf16> {
// CHECK: %[[elem0:.*]] = vector.extractelement %[[value]][%c0 : index]
// CHECK: %[[elem1:.*]] = vector.extractelement %[[value]][%c1 : index]
// CHECK: %[[elem0:.*]] = vector.extract %[[value]][0]
// CHECK: %[[elem1:.*]] = vector.extract %[[value]][1]
// CHECK: %[[packed0:.*]] = rocdl.cvt.pkrtz %[[elem0]], %[[elem1]] : vector<2xf16>
// CHECK: %[[out0:.*]] = vector.insert_strided_slice %[[packed0]], {{.*}} {offsets = [0], strides = [1]} : vector<2xf16> into vector<9xf16>
// CHECK: %[[elem2:.*]] = vector.extractelement %[[value]][%c2 : index]
// CHECK: %[[elem3:.*]] = vector.extractelement %[[value]][%c3 : index]
// CHECK: %[[elem2:.*]] = vector.extract %[[value]][2]
// CHECK: %[[elem3:.*]] = vector.extract %[[value]][3]
// CHECK: %[[packed1:.*]] = rocdl.cvt.pkrtz %[[elem2]], %[[elem3]] : vector<2xf16>
// CHECK: %[[out1:.*]] = vector.insert_strided_slice %[[packed1]], %[[out0]] {offsets = [2], strides = [1]} : vector<2xf16> into vector<9xf16>
// CHECK: %[[elem4:.*]] = vector.extractelement %[[value]][%c4 : index]
// CHECK: %[[elem5:.*]] = vector.extractelement %[[value]][%c5 : index]
// CHECK: %[[elem4:.*]] = vector.extract %[[value]][4]
// CHECK: %[[elem5:.*]] = vector.extract %[[value]][5]
// CHECK: %[[packed2:.*]] = rocdl.cvt.pkrtz %[[elem4]], %[[elem5]] : vector<2xf16>
// CHECK: %[[out2:.*]] = vector.insert_strided_slice %[[packed2]], %[[out1]] {offsets = [4], strides = [1]} : vector<2xf16> into vector<9xf16>
// CHECK: %[[elem6:.*]] = vector.extractelement %[[value]]
// CHECK: %[[elem7:.*]] = vector.extractelement %[[value]]
// CHECK: %[[elem6:.*]] = vector.extract %[[value]]
// CHECK: %[[elem7:.*]] = vector.extract %[[value]]
// CHECK: %[[packed3:.*]] = rocdl.cvt.pkrtz %[[elem6]], %[[elem7]] : vector<2xf16>
// CHECK: %[[out3:.*]] = vector.insert_strided_slice %[[packed3]], %[[out2]] {offsets = [6], strides = [1]} : vector<2xf16> into vector<9xf16>
// CHECK: %[[elem8:.*]] = vector.extractelement %[[value]]
// CHECK: %[[elem8:.*]] = vector.extract %[[value]]
// CHECK: %[[packed4:.*]] = rocdl.cvt.pkrtz %[[elem8:.*]] : vector<2xf16>
// CHECK: %[[slice:.*]] = vector.extract_strided_slice %[[packed4]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf16> to vector<1xf16>
// CHECK: %[[out4:.*]] = vector.insert_strided_slice %[[slice]], %[[out3]] {offsets = [8], strides = [1]} : vector<1xf16> into vector<9xf16>
4 changes: 2 additions & 2 deletions mlir/test/Dialect/Linalg/vectorization-scalable.mlir
@@ -164,7 +164,7 @@ func.func @vectorize_linalg_index(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) ->
// CHECK: %[[DST_DIM0:.*]] = tensor.dim %[[DST]], %[[C0]] : tensor<?xf32>
// CHECK: %[[MASK:.*]] = vector.create_mask %[[DST_DIM0]] : vector<[4]xi1>
// CHECK-DAG: %[[STEP:.+]] = vector.step : vector<[4]xindex>
// CHECK-DAG: %[[STEP_ELEMENT:.+]] = vector.extractelement %[[STEP]][%c0_i32 : i32] : vector<[4]xindex>
// CHECK-DAG: %[[STEP_ELEMENT:.+]] = vector.extract %[[STEP]][0] : index from vector<[4]xindex>

// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]][%[[STEP_ELEMENT]]], %cst {in_bounds = [true]} : tensor<?xf32>, vector<[4]xf32> } : vector<[4]xi1> -> vector<[4]xf32>
// CHECK: %[[OUT:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[DST]]{{\[}}%[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, tensor<?xf32> } : vector<[4]xi1> -> tensor<?xf32>
@@ -207,7 +207,7 @@ func.func @vectorize_dynamic_reduction_scalable_1d(%arg0: tensor<?xf32>,
// CHECK: %[[VEC_RD_0:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[ARG_0]][%[[C0_idx]]], %[[C0_f32]] {in_bounds = [true]} : tensor<?xf32>, vector<[4]xf32> } : vector<[4]xi1> -> vector<[4]xf32>
// CHECK: %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[VEC_RD_1:.*]] = vector.transfer_read %[[ARG_1]][], %[[C0_F32]] : tensor<f32>, vector<f32>
// CHECK: %[[ACC_f32:.*]] = vector.extractelement %[[VEC_RD_1]][] : vector<f32>
// CHECK: %[[ACC_f32:.*]] = vector.extract %[[VEC_RD_1]][] : f32 from vector<f32>
// CHECK: %[[REDUCE:.*]] = vector.mask %[[MASK]] { vector.multi_reduction <add>, %[[VEC_RD_0]], %[[ACC_f32]] [0] : vector<[4]xf32> to f32 } : vector<[4]xi1> -> f32
// CHECK: %[[VEC_f32:.*]] = vector.broadcast %[[REDUCE]] : f32 to vector<f32>
// CHECK: %{{.*}} = vector.transfer_write %[[VEC_f32]], %[[ARG_1]][] : vector<f32>, tensor<f32>
10 changes: 4 additions & 6 deletions mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
@@ -414,7 +414,7 @@ module attributes {transform.with_named_sequence} {
func.func @test_vectorize_copy_scalar(%A : memref<f32>, %B : memref<f32>) {
// CHECK-SAME: (%[[A:.*]]: memref<f32>, %[[B:.*]]: memref<f32>)
// CHECK: %[[V:.*]] = vector.transfer_read %[[A]][]{{.*}} : memref<f32>, vector<f32>
// CHECK: %[[val:.*]] = vector.extractelement %[[V]][] : vector<f32>
// CHECK: %[[val:.*]] = vector.extract %[[V]][] : f32 from vector<f32>
// CHECK: %[[VV:.*]] = vector.broadcast %[[val]] : f32 to vector<f32>
// CHECK: vector.transfer_write %[[VV]], %[[B]][] : vector<f32>, memref<f32>
memref.copy %A, %B : memref<f32> to memref<f32>
@@ -1440,7 +1440,6 @@ module attributes {transform.with_named_sequence} {
// CHECK-LABEL: func @reduce_1d(
// CHECK-SAME: %[[A:.*]]: tensor<32xf32>
func.func @reduce_1d(%arg0: tensor<32xf32>) -> tensor<f32> {
// CHECK-DAG: %[[vF0:.*]] = arith.constant dense<0.000000e+00> : vector<f32>
// CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
%f0 = arith.constant 0.000000e+00 : f32
@@ -1451,8 +1450,7 @@ func.func @reduce_1d(%arg0: tensor<32xf32>) -> tensor<f32> {
%1 = linalg.fill ins(%f0 : f32) outs(%0 : tensor<f32>) -> tensor<f32>
// CHECK: %[[r:.*]] = vector.transfer_read %[[A]][%[[C0]]]
// CHECK-SAME: : tensor<32xf32>, vector<32xf32>
// CHECK: %[[f0:.*]] = vector.extractelement %[[vF0]][] : vector<f32>
// CHECK: %[[red:.*]] = vector.multi_reduction <add>, %[[r]], %[[f0]] [0]
// CHECK: %[[red:.*]] = vector.multi_reduction <add>, %[[r]], %[[F0]] [0]
// CHECK-SAME: : vector<32xf32> to f32
// CHECK: %[[red_v1:.*]] = vector.broadcast %[[red]] : f32 to vector<f32>
// CHECK: %[[res:.*]] = vector.transfer_write %[[red_v1]], %[[init]][]
@@ -1779,9 +1777,9 @@ module attributes {transform.with_named_sequence} {

// CHECK-LABEL: func @zero_dim_tensor
// CHECK: vector.transfer_read {{.*}} : tensor<f32>, vector<f32>
// CHECK: vector.extractelement
// CHECK: vector.extract
// CHECK: vector.transfer_read {{.*}} : tensor<f32>, vector<f32>
// CHECK: vector.extractelement
// CHECK: vector.extract
// CHECK: arith.addf {{.*}} : f32
// CHECK: vector.broadcast %{{.*}} : f32 to vector<f32>
// CHECK: vector.transfer_write {{.*}} : vector<f32>, tensor<f32>
12 changes: 4 additions & 8 deletions mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
@@ -37,11 +37,10 @@ func.func @masked_static_vectorize_nd_tensor_extract_with_affine_apply_contiguou
// CHECK: %[[STEP:.*]] = vector.step : vector<4xindex>
// CHECK: %[[IDX_BC:.*]] = vector.broadcast %[[IDX_IN]] : index to vector<4xindex>
// CHECK: %[[IDX_VEC:.*]] = arith.addi %[[STEP]], %[[IDX_BC]] : vector<4xindex>
// CHECK: %[[C0:.*]] = arith.constant 0 : i32
// CHECK: %[[SC:.*]] = vector.shape_cast %[[IDX_VEC]] : vector<4xindex> to vector<4xindex>

/// Extract the starting point from the index vector
// CHECK: %[[IDX_START:.*]] = vector.extractelement %[[SC]]{{\[}}%[[C0]] : i32] : vector<4xindex>
// CHECK: %[[IDX_START:.*]] = vector.extract %[[SC]][0] : index from vector<4xindex>

// Final read and write
// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]]{{\[}}%[[C79]], %[[IDX_START]]], {{.*}} {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
@@ -98,11 +97,10 @@ func.func @masked_static_vectorize_nd_tensor_extract_with_affine_apply_contiguou
// CHECK: %[[STEP:.*]] = vector.step : vector<[4]xindex>
// CHECK: %[[IDX_BC:.*]] = vector.broadcast %[[IDX_IN]] : index to vector<[4]xindex>
// CHECK: %[[IDX_VEC:.*]] = arith.addi %[[STEP]], %[[IDX_BC]] : vector<[4]xindex>
// CHECK: %[[C0:.*]] = arith.constant 0 : i32
// CHECK: %[[SC:.*]] = vector.shape_cast %[[IDX_VEC]] : vector<[4]xindex> to vector<[4]xindex>

/// Extract the starting point from the index vector
// CHECK: %[[IDX_START:.*]] = vector.extractelement %[[SC]]{{\[}}%[[C0]] : i32] : vector<[4]xindex>
// CHECK: %[[IDX_START:.*]] = vector.extract %[[SC]][0] : index from vector<[4]xindex>

// Final read and write
// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]]{{\[}}%[[C79]], %[[IDX_START]]], {{.*}} {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x[4]xf32> } : vector<1x[4]xi1> -> vector<1x[4]xf32>
@@ -159,11 +157,10 @@ func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguo
// CHECK: %[[STEP:.*]] = vector.step : vector<4xindex>
// CHECK: %[[IDX_BC:.*]] = vector.broadcast %[[IDX]] : index to vector<4xindex>
// CHECK: %[[IDX_VEC:.*]] = arith.addi %[[STEP]], %[[IDX_BC]] : vector<4xindex>
// CHECK: %[[C0:.*]] = arith.constant 0 : i32
// CHECK: %[[SC:.*]] = vector.shape_cast %[[IDX_VEC]] : vector<4xindex> to vector<4xindex>

/// Extract the starting point from the index vector
// CHECK: %[[IDX_START:.*]] = vector.extractelement %[[SC]]{{\[}}%[[C0]] : i32] : vector<4xindex>
// CHECK: %[[IDX_START:.*]] = vector.extract %[[SC]][0] : index from vector<4xindex>

// Final read and write
// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]]{{\[}}%[[C79]], %[[IDX_START]]], {{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
@@ -218,11 +215,10 @@ func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguo
// CHECK: %[[STEP:.*]] = vector.step : vector<[4]xindex>
// CHECK: %[[IDX_BC:.*]] = vector.broadcast %[[IDX]] : index to vector<[4]xindex>
// CHECK: %[[IDX_VEC:.*]] = arith.addi %[[STEP]], %[[IDX_BC]] : vector<[4]xindex>
// CHECK: %[[C0:.*]] = arith.constant 0 : i32
// CHECK: %[[SC:.*]] = vector.shape_cast %[[IDX_VEC]] : vector<[4]xindex> to vector<[4]xindex>

/// Extract the starting point from the index vector
// CHECK: %[[IDX_START:.*]] = vector.extractelement %[[SC]]{{\[}}%[[C0]] : i32] : vector<[4]xindex>
// CHECK: %[[IDX_START:.*]] = vector.extract %[[SC]][0] : index from vector<[4]xindex>

// Final read and write
// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]]{{\[}}%[[C79]], %[[IDX_START]]], {{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x[4]xf32> } : vector<1x[4]xi1> -> vector<1x[4]xf32>