From 515c26b058c6e88d003070171a4f1fd08aa6cf73 Mon Sep 17 00:00:00 2001 From: Hugo Date: Mon, 13 May 2024 21:54:46 +0800 Subject: [PATCH 1/4] [MLIR][Vector] Implement TransferOp To LoadStoreLowering as MaskableOpRewritePattern --- .../Vector/Transforms/LowerVectorTransfer.cpp | 59 +++++++++++-------- .../vector-transfer-to-vector-load-store.mlir | 35 +++++++++++ 2 files changed, 69 insertions(+), 25 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp index c59012266ceb3..9418a087c4367 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp @@ -423,20 +423,24 @@ namespace { /// result type. /// - The permutation map doesn't perform permutation (broadcasting is allowed). struct TransferReadToVectorLoadLowering - : public OpRewritePattern { + : public MaskableOpRewritePattern { TransferReadToVectorLoadLowering(MLIRContext *context, std::optional maxRank, PatternBenefit benefit = 1) - : OpRewritePattern(context, benefit), + : MaskableOpRewritePattern(context, benefit), maxTransferRank(maxRank) {} - LogicalResult matchAndRewrite(vector::TransferReadOp read, - PatternRewriter &rewriter) const override { + FailureOr + matchAndRewriteMaskableOp(vector::TransferReadOp read, + MaskingOpInterface maskOp, + PatternRewriter &rewriter) const override { if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank) { return rewriter.notifyMatchFailure( read, "vector type is greater than max transfer rank"); } + if (maskOp) + return rewriter.notifyMatchFailure(read, "Masked case not supported"); SmallVector broadcastedDims; // Permutations are handled by VectorToSCF or // populateVectorTransferPermutationMapLoweringPatterns. @@ -479,7 +483,7 @@ struct TransferReadToVectorLoadLowering return rewriter.notifyMatchFailure(read, "out-of-bounds needs mask"); // Create vector load op. - Operation *loadOp; + Operation *res; if (read.getMask()) { if (read.getVectorType().getRank() != 1) // vector.maskedload operates on 1-D vectors. @@ -489,24 +493,20 @@ struct TransferReadToVectorLoadLowering Value fill = rewriter.create( read.getLoc(), unbroadcastedVectorType, read.getPadding()); - loadOp = rewriter.create( + res = rewriter.create( read.getLoc(), unbroadcastedVectorType, read.getSource(), read.getIndices(), read.getMask(), fill); } else { - loadOp = rewriter.create( + res = rewriter.create( read.getLoc(), unbroadcastedVectorType, read.getSource(), read.getIndices()); } // Insert a broadcasting op if required. - if (!broadcastedDims.empty()) { - rewriter.replaceOpWithNewOp( - read, read.getVectorType(), loadOp->getResult(0)); - } else { - rewriter.replaceOp(read, loadOp->getResult(0)); - } - - return success(); + if (!broadcastedDims.empty()) + res = rewriter.create( + read.getLoc(), read.getVectorType(), res->getResult(0)); + return res->getResults()[0]; } std::optional maxTransferRank; @@ -575,19 +575,23 @@ struct VectorStoreToMemrefStoreLowering /// - The permutation map is the minor identity map (neither permutation nor /// broadcasting is allowed). struct TransferWriteToVectorStoreLowering - : public OpRewritePattern { + : public MaskableOpRewritePattern { TransferWriteToVectorStoreLowering(MLIRContext *context, std::optional maxRank, PatternBenefit benefit = 1) - : OpRewritePattern(context, benefit), + : MaskableOpRewritePattern(context, benefit), maxTransferRank(maxRank) {} - LogicalResult matchAndRewrite(vector::TransferWriteOp write, - PatternRewriter &rewriter) const override { + FailureOr + matchAndRewriteMaskableOp(vector::TransferWriteOp write, + MaskingOpInterface maskOp, + PatternRewriter &rewriter) const override { if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank) { return rewriter.notifyMatchFailure( write, "vector type is greater than max transfer rank"); } + if (maskOp) + return rewriter.notifyMatchFailure(write, "Masked case not supported"); // Permutations are handled by VectorToSCF or // populateVectorTransferPermutationMapLoweringPatterns. @@ -639,14 +643,19 @@ struct TransferWriteToVectorStoreLowering << write; }); - rewriter.replaceOpWithNewOp( - write, write.getSource(), write.getIndices(), write.getMask(), - write.getVector()); + rewriter + .create(write.getLoc(), write.getSource(), + write.getIndices(), write.getMask(), + write.getVector()) + .getBase(); + return Value(); } else { - rewriter.replaceOpWithNewOp( - write, write.getVector(), write.getSource(), write.getIndices()); + rewriter + .create(write.getLoc(), write.getVector(), + write.getSource(), write.getIndices()) + .getBase(); + return Value(); } - return success(); } std::optional maxTransferRank; diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir index 2f2bdcaab5b3e..a789aac717dab 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir @@ -392,6 +392,41 @@ func.func @transfer_2D_masked(%mem : memref, %mask : vector<2x4xi1>) -> return %res : vector<2x4xf32> } +// transfer_read/write are lowered to vector.load/store +// CHECK-LABEL: func @masked_transfer_to_load( +// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>, +// CHECK-SAME: %[[IDX:.*]]: index, +// CHECK-SAME: %[[MASK:.*]]: vector<4xi1>) -> memref<8x8xf32> +// CHECK-NOT: vector.load +// CHECK-NOT: vector.store +// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %arg0[%[[IDX]], %[[IDX]]]{{.*}} : memref<8x8xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32> +// CHECK: vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[MEM]][%[[IDX]], %[[IDX]]]{{.*}} : vector<4xf32>, memref<8x8xf32> } : vector<4xi1> + + +func.func @masked_transfer_to_load(%mem : memref<8x8xf32>, %i : index, %mask : vector<4xi1>) -> memref<8x8xf32> { + %cf0 = arith.constant 0.0 : f32 + %read = vector.mask %mask { vector.transfer_read %mem[%i, %i], %cf0 {in_bounds = [true]} : memref<8x8xf32>, vector<4xf32>} : vector<4xi1> -> vector<4xf32> + vector.mask %mask {vector.transfer_write %read, %mem[%i, %i] {in_bounds = [true]} : vector<4xf32>, memref<8x8xf32> } : vector<4xi1> + return %mem : memref<8x8xf32> +} + +// n-D results are also supported. +// CHECK-LABEL: func @masked_transfer_2D( +// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>, +// CHECK-SAME: %[[IDX:.*]]: index, +// CHECK-SAME: %[[MASK:.*]]: vector<2x4xi1>) -> memref<8x8xf32> +// CHECK-NOT: vector.load +// CHECK-NOT: vector.store +// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]]{{.*}} : memref<8x8xf32>, vector<2x4xf32> } : vector<2x4xi1> -> vector<2x4xf32> +// CHECK: vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[MEM]][%[[IDX]], %[[IDX]]]{{.*}}: vector<2x4xf32>, memref<8x8xf32> } : vector<2x4xi1> + +func.func @masked_transfer_2D(%mem : memref<8x8xf32>, %i : index, %mask : vector<2x4xi1>) -> memref<8x8xf32> { + %cf0 = arith.constant 0.0 : f32 + %read = vector.mask %mask { vector.transfer_read %mem[%i, %i], %cf0 {in_bounds = [true, true]} : memref<8x8xf32>, vector<2x4xf32> } : vector<2x4xi1> -> vector<2x4xf32> + vector.mask %mask {vector.transfer_write %read, %mem[%i, %i] {in_bounds = [true, true]} : vector<2x4xf32>, memref<8x8xf32> } : vector<2x4xi1> + return %mem : memref<8x8xf32> +} + module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) { %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func"> From 2199beeed22eb3c8796baedb8cdfc2427dba59af Mon Sep 17 00:00:00 2001 From: Hugo Date: Tue, 21 May 2024 22:07:39 +0800 Subject: [PATCH 2/4] fixup MR comments --- .../Vector/Transforms/LowerVectorTransfer.cpp | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp index 9418a087c4367..9ef460be1ec35 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp @@ -506,7 +506,7 @@ struct TransferReadToVectorLoadLowering if (!broadcastedDims.empty()) res = rewriter.create( read.getLoc(), read.getVectorType(), res->getResult(0)); - return res->getResults()[0]; + return res->getResult(0); } std::optional maxTransferRank; @@ -643,17 +643,13 @@ struct TransferWriteToVectorStoreLowering << write; }); - rewriter - .create(write.getLoc(), write.getSource(), - write.getIndices(), write.getMask(), - write.getVector()) - .getBase(); + rewriter.create( + write.getLoc(), write.getSource(), write.getIndices(), + write.getMask(), write.getVector()); return Value(); } else { - rewriter - .create(write.getLoc(), write.getVector(), - write.getSource(), write.getIndices()) - .getBase(); + rewriter.create(write.getLoc(), write.getVector(), + write.getSource(), write.getIndices()); return Value(); } } From 370515719559ac42195bca6d1ee1babb555eb5fb Mon Sep 17 00:00:00 2001 From: Hugo Trachino Date: Wed, 22 May 2024 09:35:22 +0100 Subject: [PATCH 3/4] Update tests Comments updated and second test removed as duplicate. --- .../vector-transfer-to-vector-load-store.mlir | 19 +------------------ 1 file changed, 1 insertion(+), 18 deletions(-) diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir index a789aac717dab..956dbeba5402c 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir @@ -392,7 +392,7 @@ func.func @transfer_2D_masked(%mem : memref, %mask : vector<2x4xi1>) -> return %res : vector<2x4xf32> } -// transfer_read/write are lowered to vector.load/store +// Masked transfer_read/write inside are NOT lowered to vector.load/store // CHECK-LABEL: func @masked_transfer_to_load( // CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>, // CHECK-SAME: %[[IDX:.*]]: index, @@ -410,23 +410,6 @@ func.func @masked_transfer_to_load(%mem : memref<8x8xf32>, %i : index, %mask : v return %mem : memref<8x8xf32> } -// n-D results are also supported. -// CHECK-LABEL: func @masked_transfer_2D( -// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>, -// CHECK-SAME: %[[IDX:.*]]: index, -// CHECK-SAME: %[[MASK:.*]]: vector<2x4xi1>) -> memref<8x8xf32> -// CHECK-NOT: vector.load -// CHECK-NOT: vector.store -// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]]{{.*}} : memref<8x8xf32>, vector<2x4xf32> } : vector<2x4xi1> -> vector<2x4xf32> -// CHECK: vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[MEM]][%[[IDX]], %[[IDX]]]{{.*}}: vector<2x4xf32>, memref<8x8xf32> } : vector<2x4xi1> - -func.func @masked_transfer_2D(%mem : memref<8x8xf32>, %i : index, %mask : vector<2x4xi1>) -> memref<8x8xf32> { - %cf0 = arith.constant 0.0 : f32 - %read = vector.mask %mask { vector.transfer_read %mem[%i, %i], %cf0 {in_bounds = [true, true]} : memref<8x8xf32>, vector<2x4xf32> } : vector<2x4xi1> -> vector<2x4xf32> - vector.mask %mask {vector.transfer_write %read, %mem[%i, %i] {in_bounds = [true, true]} : vector<2x4xf32>, memref<8x8xf32> } : vector<2x4xi1> - return %mem : memref<8x8xf32> -} - module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) { %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func"> From eb9f46ce5801f87ed35c09860bcfe1d1a7a31458 Mon Sep 17 00:00:00 2001 From: Hugo Date: Fri, 31 May 2024 16:39:24 +0800 Subject: [PATCH 4/4] Add comment and move test. --- .../Vector/Transforms/LowerVectorTransfer.cpp | 5 +-- .../vector-transfer-to-vector-load-store.mlir | 35 +++++++++---------- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp index 9ef460be1ec35..c71b0d5259fcb 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp @@ -646,12 +646,13 @@ struct TransferWriteToVectorStoreLowering rewriter.create( write.getLoc(), write.getSource(), write.getIndices(), write.getMask(), write.getVector()); - return Value(); } else { rewriter.create(write.getLoc(), write.getVector(), write.getSource(), write.getIndices()); - return Value(); } + // There's no return value for StoreOps. Use Value() to signal success to + // matchAndRewrite. + return Value(); } std::optional maxTransferRank; diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir index 956dbeba5402c..d169e6d5878e2 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir @@ -51,6 +51,23 @@ func.func @transfer_to_load(%mem : memref<8x8xf32>, %i : index) -> vector<4xf32> return %res : vector<4xf32> } +// Masked transfer_read/write inside are NOT lowered to vector.load/store +// CHECK-LABEL: func @masked_transfer_to_load( +// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>, +// CHECK-SAME: %[[IDX:.*]]: index, +// CHECK-SAME: %[[MASK:.*]]: vector<4xi1>) -> memref<8x8xf32> +// CHECK-NOT: vector.load +// CHECK-NOT: vector.store +// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %arg0[%[[IDX]], %[[IDX]]]{{.*}} : memref<8x8xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32> +// CHECK: vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[MEM]][%[[IDX]], %[[IDX]]]{{.*}} : vector<4xf32>, memref<8x8xf32> } : vector<4xi1> + +func.func @masked_transfer_to_load(%mem : memref<8x8xf32>, %i : index, %mask : vector<4xi1>) -> memref<8x8xf32> { + %cf0 = arith.constant 0.0 : f32 + %read = vector.mask %mask {vector.transfer_read %mem[%i, %i], %cf0 {in_bounds = [true]} : memref<8x8xf32>, vector<4xf32>} : vector<4xi1> -> vector<4xf32> + vector.mask %mask {vector.transfer_write %read, %mem[%i, %i] {in_bounds = [true]} : vector<4xf32>, memref<8x8xf32> } : vector<4xi1> + return %mem : memref<8x8xf32> +} + // n-D results are also supported. // CHECK-LABEL: func @transfer_2D( // CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>, @@ -392,24 +409,6 @@ func.func @transfer_2D_masked(%mem : memref, %mask : vector<2x4xi1>) -> return %res : vector<2x4xf32> } -// Masked transfer_read/write inside are NOT lowered to vector.load/store -// CHECK-LABEL: func @masked_transfer_to_load( -// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>, -// CHECK-SAME: %[[IDX:.*]]: index, -// CHECK-SAME: %[[MASK:.*]]: vector<4xi1>) -> memref<8x8xf32> -// CHECK-NOT: vector.load -// CHECK-NOT: vector.store -// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %arg0[%[[IDX]], %[[IDX]]]{{.*}} : memref<8x8xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32> -// CHECK: vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[MEM]][%[[IDX]], %[[IDX]]]{{.*}} : vector<4xf32>, memref<8x8xf32> } : vector<4xi1> - - -func.func @masked_transfer_to_load(%mem : memref<8x8xf32>, %i : index, %mask : vector<4xi1>) -> memref<8x8xf32> { - %cf0 = arith.constant 0.0 : f32 - %read = vector.mask %mask { vector.transfer_read %mem[%i, %i], %cf0 {in_bounds = [true]} : memref<8x8xf32>, vector<4xf32>} : vector<4xi1> -> vector<4xf32> - vector.mask %mask {vector.transfer_write %read, %mem[%i, %i] {in_bounds = [true]} : vector<4xf32>, memref<8x8xf32> } : vector<4xi1> - return %mem : memref<8x8xf32> -} - module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) { %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">