From 2f7fc4bcc667aeea40357dfe1ac5d12a2162f61b Mon Sep 17 00:00:00 2001 From: Hugo Date: Fri, 17 May 2024 18:49:31 +0800 Subject: [PATCH 1/8] [mlir][vector] Support MaskableOpRewritePattern of Op without a result. --- mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h index 030be328e97fd..bf9694556f901 100644 --- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h +++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h @@ -157,7 +157,10 @@ struct MaskableOpRewritePattern : OpRewritePattern { if (failed(newOp)) return failure(); - rewriter.replaceOp(rootOp, *newOp); + if (rootOp->getNumResults() == 0 || *newOp == Value()) + rewriter.eraseOp(rootOp); + else + rewriter.replaceOp(rootOp, *newOp); return success(); } From b6a41bc31d9779f3b0744c3f168dec171207d0f6 Mon Sep 17 00:00:00 2001 From: Hugo Date: Tue, 14 May 2024 22:21:23 +0800 Subject: [PATCH 2/8] [MLIR][Vector] Implement transferXXPermutationLowering as MaskableOpRewritePattern --- .../Vector/Transforms/LowerVectorTransfer.cpp | 64 +++++++++------ .../vector-transfer-permutation-lowering.mlir | 82 +++++++++++++++++++ 2 files changed, 122 insertions(+), 24 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp index b30b43d70bf0f..7f5703b635068 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp @@ -90,14 +90,19 @@ namespace { /// Note that an alternative is to transform it to linalg.transpose + /// vector.transfer_read to do the transpose in memory instead. struct TransferReadPermutationLowering - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + : public MaskableOpRewritePattern { + using MaskableOpRewritePattern::MaskableOpRewritePattern; - LogicalResult matchAndRewrite(vector::TransferReadOp op, - PatternRewriter &rewriter) const override { + FailureOr + matchAndRewriteMaskableOp(vector::TransferReadOp op, + MaskingOpInterface maskOp, + PatternRewriter &rewriter) const override { // TODO: support 0-d corner case. if (op.getTransferRank() == 0) return rewriter.notifyMatchFailure(op, "0-d corner case not supported"); + // TODO: Support transfer_read inside MaskOp case. + if (maskOp) + return rewriter.notifyMatchFailure(op, "Masked case not supported"); SmallVector permutation; AffineMap map = op.getPermutationMap(); @@ -142,9 +147,9 @@ struct TransferReadPermutationLowering // Transpose result of transfer_read. SmallVector transposePerm(permutation.begin(), permutation.end()); - rewriter.replaceOpWithNewOp(op, newRead, - transposePerm); - return success(); + return rewriter + .create(op.getLoc(), newRead, transposePerm) + .getResult(); } }; @@ -165,14 +170,19 @@ struct TransferReadPermutationLowering /// %v = vector.transfer_write %tmp ... /// permutation_map: (d0, d1, d2, d3) -> (d2, d3) struct TransferWritePermutationLowering - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + : public MaskableOpRewritePattern { + using MaskableOpRewritePattern::MaskableOpRewritePattern; - LogicalResult matchAndRewrite(vector::TransferWriteOp op, - PatternRewriter &rewriter) const override { + FailureOr + matchAndRewriteMaskableOp(vector::TransferWriteOp op, + MaskingOpInterface maskOp, + PatternRewriter &rewriter) const override { // TODO: support 0-d corner case. if (op.getTransferRank() == 0) return rewriter.notifyMatchFailure(op, "0-d corner case not supported"); + // TODO: Support transfer_write inside MaskOp case. + if (maskOp) + return rewriter.notifyMatchFailure(op, "Masked case not supported"); SmallVector permutation; AffineMap map = op.getPermutationMap(); @@ -207,11 +217,11 @@ struct TransferWritePermutationLowering op.getLoc(), op.getVector(), indices); auto newMap = AffineMap::getMinorIdentityMap( map.getNumDims(), map.getNumResults(), rewriter.getContext()); - rewriter.replaceOpWithNewOp( - op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap), - op.getMask(), newInBoundsAttr); - - return success(); + return rewriter + .create( + op.getLoc(), newVec, op.getSource(), op.getIndices(), + AffineMapAttr::get(newMap), op.getMask(), newInBoundsAttr) + .getResult(); } }; @@ -231,14 +241,19 @@ struct TransferWritePermutationLowering /// vector<1x8x16xf32> /// ``` struct TransferWriteNonPermutationLowering - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + : public MaskableOpRewritePattern { + using MaskableOpRewritePattern::MaskableOpRewritePattern; - LogicalResult matchAndRewrite(vector::TransferWriteOp op, - PatternRewriter &rewriter) const override { + FailureOr + matchAndRewriteMaskableOp(vector::TransferWriteOp op, + MaskingOpInterface maskOp, + PatternRewriter &rewriter) const override { // TODO: support 0-d corner case. if (op.getTransferRank() == 0) return rewriter.notifyMatchFailure(op, "0-d corner case not supported"); + // TODO: Support transfer_write inside MaskOp case. + if (maskOp) + return rewriter.notifyMatchFailure(op, "Masked case not supported"); SmallVector permutation; AffineMap map = op.getPermutationMap(); @@ -285,10 +300,11 @@ struct TransferWriteNonPermutationLowering newInBoundsValues.push_back(op.isDimInBounds(i)); } ArrayAttr newInBoundsAttr = rewriter.getBoolArrayAttr(newInBoundsValues); - rewriter.replaceOpWithNewOp( - op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap), - newMask, newInBoundsAttr); - return success(); + return rewriter + .create( + op.getLoc(), newVec, op.getSource(), op.getIndices(), + AffineMapAttr::get(newMap), newMask, newInBoundsAttr) + .getResult(); } }; diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir index e48af3cd7aace..d63d47fe4481d 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir @@ -46,6 +46,55 @@ func.func @permutation_with_mask_xfer_write_scalable(%arg0: vector<4x[8]xi16>, % return } +// transfer_write in MaskOp case not supported. +// CHECK-LABEL: func @masked_permutation_xfer_write_fixed_width +// CHECK-SAME: %[[ARG_0:.*]]: tensor, +// CHECK-SAME: %[[ARG_1:.*]]: vector<16xf32>, +// CHECK-SAME: %[[IDX:.*]]: index, +// CHECK-SAME: %[[MASK:.*]]: vector<16xi1> +// CHECK-NOT: vector.transpose +// CHECK: %[[RES:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[ARG_1]], %[[ARG_0]][%[[IDX]], %[[IDX]]] {{.*}} vector<16xf32>, tensor } : vector<16xi1> -> tensor +// CHECK: return %[[RES]] +func.func @masked_permutation_xfer_write_fixed_width(%t: tensor, %val: vector<16xf32>, %idx: index, %mask: vector<16xi1>) -> tensor { + %r = vector.mask %mask { vector.transfer_write %val, %t[%idx, %idx] {permutation_map = affine_map<(d0, d1) -> (d0)>} : vector<16xf32>, tensor } : vector<16xi1> -> tensor + return %r : tensor +} + +// CHECK-LABEL: func.func @masked_permutation_xfer_write_scalable( +// CHECK-SAME: %[[ARG_0:.*]]: vector<4x[8]xi16>, +// CHECK-SAME: %[[ARG_1:.*]]: tensor, +// CHECK-SAME: %[[MASK:.*]]: vector<4x[8]xi1>) +// CHECK-SAME: -> tensor { +// CHECK-NOT: vector.transpose +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[R:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[ARG_0]], %[[ARG_1]][%c0, %c0, %c0, %c0] {in_bounds = [true, true], permutation_map = #[[MAP:.*]]} : vector<4x[8]xi16>, tensor } : vector<4x[8]xi1> -> tensor +// CHECK: return %[[R]] : tensor +func.func @masked_permutation_xfer_write_scalable(%arg0: vector<4x[8]xi16>, %t: tensor, %mask: vector<4x[8]xi1>) -> tensor { + %c0 = arith.constant 0 : index + %r = vector.mask %mask { vector.transfer_write %arg0, %t[%c0, %c0, %c0, %c0] {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)> +} : vector<4x[8]xi16>, tensor } : vector<4x[8]xi1> -> tensor + + return %r : tensor +} + +// transfer_write in MaskOp case not supported. +// CHECK-LABEL: func @masked_non_permutation_xfer_write_fixed_width +// CHECK-SAME: %[[ARG0:.*]]: tensor +// CHECK-SAME: %[[ARG1:.*]]: vector<14x8x16xf32> +// CHECK-SAME: %[[IDX:.*]]: index) -> tensor +// CHECK-NOT: vector.broadcast +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[masked1:.*]] = vector.mask %0 { vector.transfer_write %[[ARG1]], %[[ARG0]]{{.*}}permutation_map = #[[MAP:.*]]} : vector<14x8x16xf32>, tensor } : vector<14x8x16xi1> -> tensor +func.func @masked_non_permutation_xfer_write_fixed_width( + %arg0 : tensor, + %v1 : vector<14x8x16xf32>, %dim : index) -> tensor { + %c0 = arith.constant 0 : index + %mask = vector.create_mask %dim, %dim, %dim : vector<14x8x16xi1> + %0 = vector.mask %mask { vector.transfer_write %v1, %arg0[%c0, %c0, %c0, %c0] {in_bounds = [false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>} : vector<14x8x16xf32>, tensor } : vector<14x8x16xi1> -> tensor + + return %0 : tensor +} + ///---------------------------------------------------------------------------------------- /// vector.transfer_read ///---------------------------------------------------------------------------------------- @@ -101,6 +150,39 @@ func.func @permutation_with_mask_xfer_read_scalable(%mem: memref, %dim_ return %1 : vector<8x[4]x2xf32> } +// transfer_read in MaskOp case not supported. +// CHECK-LABEL: func @masked_permutation_xfer_read_fixed_width +// CHECK-SAME: %[[ARG_0:.*]]: tensor, +// CHECK-SAME: %[[ARG_1:.*]]: vector<4x1xi1> +// CHECK-NOT: vector.transpose +// CHECK: vector.mask %[[ARG_1]] { vector.transfer_read %[[ARG_0]]{{.*}}: tensor, vector<1x4x4xf32> } : vector<4x1xi1> -> vector<1x4x4xf32> +func.func @masked_permutation_xfer_read_fixed_width(%arg0: tensor, %mask : vector<4x1xi1>) { + %cst = arith.constant 0.000000e+00 : f32 + %c0 = arith.constant 0 : index + %3 = vector.mask %mask { vector.transfer_read %arg0[%c0, %c0], %cst {permutation_map = affine_map<(d0, d1) -> (d1, 0, d0)>} : tensor, vector<1x4x4xf32> } : vector<4x1xi1> -> vector<1x4x4xf32> + call @test.some_use(%3) : (vector<1x4x4xf32>) -> () + return +} +func.func private @test.some_use(vector<1x4x4xf32>) + +// CHECK-LABEL: func.func @masked_permutation_xfer_read_scalable( +// CHECK-SAME: %[[ARG_0:.*]]: tensor, +// CHECK-SAME: %[[MASK:.*]]: vector<2x[4]xi1>) -> vector<8x[4]x2xf32> { +// CHECK-NOT: vector.transpose +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[T_READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[ARG_0]][%[[C0]], %[[C0]]], %cst {in_bounds = [true, true, true], permutation_map = #[[MAP:.*]]} : tensor, vector<8x[4]x2xf32> } : vector<2x[4]xi1> -> vector<8x[4]x2xf32> +// CHECK: return %[[T_READ]] : vector<8x[4]x2xf32> +func.func @masked_permutation_xfer_read_scalable(%t: tensor, %mask : vector<2x[4]xi1>) -> vector<8x[4]x2xf32> { + + %c0 = arith.constant 0 : index + %cst_0 = arith.constant 0.000000e+00 : f32 + + %1 = vector.mask %mask { vector.transfer_read %t[%c0, %c0], %cst_0 + {in_bounds = [true, true, true], permutation_map = affine_map<(d0, d1) -> (0, d1, d0)>} + : tensor, vector<8x[4]x2xf32> } :vector<2x[4]xi1> -> vector<8x[4]x2xf32> + return %1 : vector<8x[4]x2xf32> +} + module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { %f = transform.structured.match ops{["func.func"]} in %module_op From 1788c0c250bbdf1a5b2ec8dece380af09c17a6f7 Mon Sep 17 00:00:00 2001 From: Hugo Date: Tue, 14 May 2024 22:25:02 +0800 Subject: [PATCH 3/8] Fixup: test less for negative tests. --- .../Vector/vector-transfer-permutation-lowering.mlir | 6 ------ 1 file changed, 6 deletions(-) diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir index d63d47fe4481d..a7acffbbbf397 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir @@ -54,7 +54,6 @@ func.func @permutation_with_mask_xfer_write_scalable(%arg0: vector<4x[8]xi16>, % // CHECK-SAME: %[[MASK:.*]]: vector<16xi1> // CHECK-NOT: vector.transpose // CHECK: %[[RES:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[ARG_1]], %[[ARG_0]][%[[IDX]], %[[IDX]]] {{.*}} vector<16xf32>, tensor } : vector<16xi1> -> tensor -// CHECK: return %[[RES]] func.func @masked_permutation_xfer_write_fixed_width(%t: tensor, %val: vector<16xf32>, %idx: index, %mask: vector<16xi1>) -> tensor { %r = vector.mask %mask { vector.transfer_write %val, %t[%idx, %idx] {permutation_map = affine_map<(d0, d1) -> (d0)>} : vector<16xf32>, tensor } : vector<16xi1> -> tensor return %r : tensor @@ -66,9 +65,7 @@ func.func @masked_permutation_xfer_write_fixed_width(%t: tensor, %val: // CHECK-SAME: %[[MASK:.*]]: vector<4x[8]xi1>) // CHECK-SAME: -> tensor { // CHECK-NOT: vector.transpose -// CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[R:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[ARG_0]], %[[ARG_1]][%c0, %c0, %c0, %c0] {in_bounds = [true, true], permutation_map = #[[MAP:.*]]} : vector<4x[8]xi16>, tensor } : vector<4x[8]xi1> -> tensor -// CHECK: return %[[R]] : tensor func.func @masked_permutation_xfer_write_scalable(%arg0: vector<4x[8]xi16>, %t: tensor, %mask: vector<4x[8]xi1>) -> tensor { %c0 = arith.constant 0 : index %r = vector.mask %mask { vector.transfer_write %arg0, %t[%c0, %c0, %c0, %c0] {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)> @@ -83,7 +80,6 @@ func.func @masked_permutation_xfer_write_scalable(%arg0: vector<4x[8]xi16>, %t: // CHECK-SAME: %[[ARG1:.*]]: vector<14x8x16xf32> // CHECK-SAME: %[[IDX:.*]]: index) -> tensor // CHECK-NOT: vector.broadcast -// CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[masked1:.*]] = vector.mask %0 { vector.transfer_write %[[ARG1]], %[[ARG0]]{{.*}}permutation_map = #[[MAP:.*]]} : vector<14x8x16xf32>, tensor } : vector<14x8x16xi1> -> tensor func.func @masked_non_permutation_xfer_write_fixed_width( %arg0 : tensor, @@ -169,9 +165,7 @@ func.func private @test.some_use(vector<1x4x4xf32>) // CHECK-SAME: %[[ARG_0:.*]]: tensor, // CHECK-SAME: %[[MASK:.*]]: vector<2x[4]xi1>) -> vector<8x[4]x2xf32> { // CHECK-NOT: vector.transpose -// CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[T_READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[ARG_0]][%[[C0]], %[[C0]]], %cst {in_bounds = [true, true, true], permutation_map = #[[MAP:.*]]} : tensor, vector<8x[4]x2xf32> } : vector<2x[4]xi1> -> vector<8x[4]x2xf32> -// CHECK: return %[[T_READ]] : vector<8x[4]x2xf32> func.func @masked_permutation_xfer_read_scalable(%t: tensor, %mask : vector<2x[4]xi1>) -> vector<8x[4]x2xf32> { %c0 = arith.constant 0 : index From c50ef197ca2e9a740bf274086f92087b83a49db5 Mon Sep 17 00:00:00 2001 From: Hugo Date: Wed, 15 May 2024 00:24:13 +0800 Subject: [PATCH 4/8] Fixup MaskableOpRewritePattern when transfer_write has no result --- .../Vector/Transforms/LowerVectorTransfer.cpp | 26 ++++++++++++------- .../vector-transfer-permutation-lowering.mlir | 8 +++--- 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp index 7f5703b635068..81f7591a7d86f 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp @@ -217,11 +217,14 @@ struct TransferWritePermutationLowering op.getLoc(), op.getVector(), indices); auto newMap = AffineMap::getMinorIdentityMap( map.getNumDims(), map.getNumResults(), rewriter.getContext()); - return rewriter - .create( - op.getLoc(), newVec, op.getSource(), op.getIndices(), - AffineMapAttr::get(newMap), op.getMask(), newInBoundsAttr) - .getResult(); + auto newWrite = rewriter.create( + op.getLoc(), newVec, op.getSource(), op.getIndices(), + AffineMapAttr::get(newMap), op.getMask(), newInBoundsAttr); + if (newWrite.hasPureTensorSemantics()) + return newWrite.getResult(); + // In memref case, MaskableOpRewritePattern cannot replaceOp with result. + rewriter.eraseOp(op); + return failure(); } }; @@ -300,11 +303,14 @@ struct TransferWriteNonPermutationLowering newInBoundsValues.push_back(op.isDimInBounds(i)); } ArrayAttr newInBoundsAttr = rewriter.getBoolArrayAttr(newInBoundsValues); - return rewriter - .create( - op.getLoc(), newVec, op.getSource(), op.getIndices(), - AffineMapAttr::get(newMap), newMask, newInBoundsAttr) - .getResult(); + auto newWrite = rewriter.create( + op.getLoc(), newVec, op.getSource(), op.getIndices(), + AffineMapAttr::get(newMap), newMask, newInBoundsAttr); + if (newWrite.hasPureTensorSemantics()) + return newWrite.getResult(); + // In memref case, MaskableOpRewritePattern cannot replaceOp with result. + rewriter.eraseOp(op); + return failure(); } }; diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir index a7acffbbbf397..349dc1ab31d4c 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir @@ -53,7 +53,7 @@ func.func @permutation_with_mask_xfer_write_scalable(%arg0: vector<4x[8]xi16>, % // CHECK-SAME: %[[IDX:.*]]: index, // CHECK-SAME: %[[MASK:.*]]: vector<16xi1> // CHECK-NOT: vector.transpose -// CHECK: %[[RES:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[ARG_1]], %[[ARG_0]][%[[IDX]], %[[IDX]]] {{.*}} vector<16xf32>, tensor } : vector<16xi1> -> tensor +// CHECK: %[[RES:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[ARG_1]], %[[ARG_0]]{{.*}} vector<16xf32>, tensor } : vector<16xi1> -> tensor func.func @masked_permutation_xfer_write_fixed_width(%t: tensor, %val: vector<16xf32>, %idx: index, %mask: vector<16xi1>) -> tensor { %r = vector.mask %mask { vector.transfer_write %val, %t[%idx, %idx] {permutation_map = affine_map<(d0, d1) -> (d0)>} : vector<16xf32>, tensor } : vector<16xi1> -> tensor return %r : tensor @@ -65,7 +65,7 @@ func.func @masked_permutation_xfer_write_fixed_width(%t: tensor, %val: // CHECK-SAME: %[[MASK:.*]]: vector<4x[8]xi1>) // CHECK-SAME: -> tensor { // CHECK-NOT: vector.transpose -// CHECK: %[[R:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[ARG_0]], %[[ARG_1]][%c0, %c0, %c0, %c0] {in_bounds = [true, true], permutation_map = #[[MAP:.*]]} : vector<4x[8]xi16>, tensor } : vector<4x[8]xi1> -> tensor +// CHECK: %[[R:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[ARG_0]], %[[ARG_1]]{{.*}} : vector<4x[8]xi16>, tensor } : vector<4x[8]xi1> -> tensor func.func @masked_permutation_xfer_write_scalable(%arg0: vector<4x[8]xi16>, %t: tensor, %mask: vector<4x[8]xi1>) -> tensor { %c0 = arith.constant 0 : index %r = vector.mask %mask { vector.transfer_write %arg0, %t[%c0, %c0, %c0, %c0] {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)> @@ -80,7 +80,7 @@ func.func @masked_permutation_xfer_write_scalable(%arg0: vector<4x[8]xi16>, %t: // CHECK-SAME: %[[ARG1:.*]]: vector<14x8x16xf32> // CHECK-SAME: %[[IDX:.*]]: index) -> tensor // CHECK-NOT: vector.broadcast -// CHECK: %[[masked1:.*]] = vector.mask %0 { vector.transfer_write %[[ARG1]], %[[ARG0]]{{.*}}permutation_map = #[[MAP:.*]]} : vector<14x8x16xf32>, tensor } : vector<14x8x16xi1> -> tensor +// CHECK: %[[masked1:.*]] = vector.mask %0 { vector.transfer_write %[[ARG1]], %[[ARG0]]{{.*}} : vector<14x8x16xf32>, tensor } : vector<14x8x16xi1> -> tensor func.func @masked_non_permutation_xfer_write_fixed_width( %arg0 : tensor, %v1 : vector<14x8x16xf32>, %dim : index) -> tensor { @@ -165,7 +165,7 @@ func.func private @test.some_use(vector<1x4x4xf32>) // CHECK-SAME: %[[ARG_0:.*]]: tensor, // CHECK-SAME: %[[MASK:.*]]: vector<2x[4]xi1>) -> vector<8x[4]x2xf32> { // CHECK-NOT: vector.transpose -// CHECK: %[[T_READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[ARG_0]][%[[C0]], %[[C0]]], %cst {in_bounds = [true, true, true], permutation_map = #[[MAP:.*]]} : tensor, vector<8x[4]x2xf32> } : vector<2x[4]xi1> -> vector<8x[4]x2xf32> +// CHECK: %[[T_READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[ARG_0]]{{.*}} : tensor, vector<8x[4]x2xf32> } : vector<2x[4]xi1> -> vector<8x[4]x2xf32> func.func @masked_permutation_xfer_read_scalable(%t: tensor, %mask : vector<2x[4]xi1>) -> vector<8x[4]x2xf32> { %c0 = arith.constant 0 : index From 127b95098ab8443fc8aa024f237d9bc4be160901 Mon Sep 17 00:00:00 2001 From: Hugo Date: Fri, 17 May 2024 19:13:53 +0800 Subject: [PATCH 5/8] FixUp introduce #92526 --- mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp index 81f7591a7d86f..751b78b586488 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp @@ -309,8 +309,7 @@ struct TransferWriteNonPermutationLowering if (newWrite.hasPureTensorSemantics()) return newWrite.getResult(); // In memref case, MaskableOpRewritePattern cannot replaceOp with result. - rewriter.eraseOp(op); - return failure(); + return Value(); } }; From c11da46c2c8ba7169d1062a35e092182df66379c Mon Sep 17 00:00:00 2001 From: Hugo Date: Fri, 17 May 2024 19:27:30 +0800 Subject: [PATCH 6/8] FixUp introduce #92526 --- mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp index 751b78b586488..300b56fbc1776 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp @@ -223,8 +223,7 @@ struct TransferWritePermutationLowering if (newWrite.hasPureTensorSemantics()) return newWrite.getResult(); // In memref case, MaskableOpRewritePattern cannot replaceOp with result. - rewriter.eraseOp(op); - return failure(); + return Value(); } }; From 9e1536d44ff91d379ac11e58c6202b2b26c13372 Mon Sep 17 00:00:00 2001 From: Hugo Date: Mon, 20 May 2024 21:53:19 +0800 Subject: [PATCH 7/8] Fixup : fix comments. --- mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp index 300b56fbc1776..c59012266ceb3 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp @@ -222,7 +222,8 @@ struct TransferWritePermutationLowering AffineMapAttr::get(newMap), op.getMask(), newInBoundsAttr); if (newWrite.hasPureTensorSemantics()) return newWrite.getResult(); - // In memref case, MaskableOpRewritePattern cannot replaceOp with result. + // In the memref case there's no return value. Use empty value to signal + // success. return Value(); } }; @@ -307,7 +308,8 @@ struct TransferWriteNonPermutationLowering AffineMapAttr::get(newMap), newMask, newInBoundsAttr); if (newWrite.hasPureTensorSemantics()) return newWrite.getResult(); - // In memref case, MaskableOpRewritePattern cannot replaceOp with result. + // In the memref case there's no return value. Use empty value to signal + // success. return Value(); } }; From 7c3cf8616209ace68b19969b5d6af7d3d6ca46e1 Mon Sep 17 00:00:00 2001 From: Hugo Date: Mon, 20 May 2024 22:16:28 +0800 Subject: [PATCH 8/8] fixup : MaskableOpRewritePattern and Value() --- mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h index bf9694556f901..9c83acc76e77a 100644 --- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h +++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h @@ -157,10 +157,14 @@ struct MaskableOpRewritePattern : OpRewritePattern { if (failed(newOp)) return failure(); - if (rootOp->getNumResults() == 0 || *newOp == Value()) + // Rewriting succeeded but there are no values to replace. + if (rootOp->getNumResults() == 0) { rewriter.eraseOp(rootOp); - else + } else { + assert(*newOp != Value() && + "Cannot replace an op's use with an empty value."); rewriter.replaceOp(rootOp, *newOp); + } return success(); }