diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp index c59012266ceb3..c71b0d5259fcb 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp @@ -423,20 +423,24 @@ namespace { /// result type. /// - The permutation map doesn't perform permutation (broadcasting is allowed). struct TransferReadToVectorLoadLowering - : public OpRewritePattern { + : public MaskableOpRewritePattern { TransferReadToVectorLoadLowering(MLIRContext *context, std::optional maxRank, PatternBenefit benefit = 1) - : OpRewritePattern(context, benefit), + : MaskableOpRewritePattern(context, benefit), maxTransferRank(maxRank) {} - LogicalResult matchAndRewrite(vector::TransferReadOp read, - PatternRewriter &rewriter) const override { + FailureOr + matchAndRewriteMaskableOp(vector::TransferReadOp read, + MaskingOpInterface maskOp, + PatternRewriter &rewriter) const override { if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank) { return rewriter.notifyMatchFailure( read, "vector type is greater than max transfer rank"); } + if (maskOp) + return rewriter.notifyMatchFailure(read, "Masked case not supported"); SmallVector broadcastedDims; // Permutations are handled by VectorToSCF or // populateVectorTransferPermutationMapLoweringPatterns. @@ -479,7 +483,7 @@ struct TransferReadToVectorLoadLowering return rewriter.notifyMatchFailure(read, "out-of-bounds needs mask"); // Create vector load op. - Operation *loadOp; + Operation *res; if (read.getMask()) { if (read.getVectorType().getRank() != 1) // vector.maskedload operates on 1-D vectors. 
@@ -489,24 +493,20 @@ struct TransferReadToVectorLoadLowering Value fill = rewriter.create( read.getLoc(), unbroadcastedVectorType, read.getPadding()); - loadOp = rewriter.create( + res = rewriter.create( read.getLoc(), unbroadcastedVectorType, read.getSource(), read.getIndices(), read.getMask(), fill); } else { - loadOp = rewriter.create( + res = rewriter.create( read.getLoc(), unbroadcastedVectorType, read.getSource(), read.getIndices()); } // Insert a broadcasting op if required. - if (!broadcastedDims.empty()) { - rewriter.replaceOpWithNewOp( - read, read.getVectorType(), loadOp->getResult(0)); - } else { - rewriter.replaceOp(read, loadOp->getResult(0)); - } - - return success(); + if (!broadcastedDims.empty()) + res = rewriter.create( + read.getLoc(), read.getVectorType(), res->getResult(0)); + return res->getResult(0); } std::optional maxTransferRank; @@ -575,19 +575,23 @@ struct VectorStoreToMemrefStoreLowering /// - The permutation map is the minor identity map (neither permutation nor /// broadcasting is allowed). struct TransferWriteToVectorStoreLowering - : public OpRewritePattern { + : public MaskableOpRewritePattern { TransferWriteToVectorStoreLowering(MLIRContext *context, std::optional maxRank, PatternBenefit benefit = 1) - : OpRewritePattern(context, benefit), + : MaskableOpRewritePattern(context, benefit), maxTransferRank(maxRank) {} - LogicalResult matchAndRewrite(vector::TransferWriteOp write, - PatternRewriter &rewriter) const override { + FailureOr + matchAndRewriteMaskableOp(vector::TransferWriteOp write, + MaskingOpInterface maskOp, + PatternRewriter &rewriter) const override { if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank) { return rewriter.notifyMatchFailure( write, "vector type is greater than max transfer rank"); } + if (maskOp) + return rewriter.notifyMatchFailure(write, "Masked case not supported"); // Permutations are handled by VectorToSCF or // populateVectorTransferPermutationMapLoweringPatterns. 
@@ -639,14 +643,16 @@ struct TransferWriteToVectorStoreLowering << write; }); - rewriter.replaceOpWithNewOp( - write, write.getSource(), write.getIndices(), write.getMask(), - write.getVector()); + rewriter.create( + write.getLoc(), write.getSource(), write.getIndices(), + write.getMask(), write.getVector()); } else { - rewriter.replaceOpWithNewOp( - write, write.getVector(), write.getSource(), write.getIndices()); + rewriter.create(write.getLoc(), write.getVector(), + write.getSource(), write.getIndices()); } - return success(); + // There's no return value for StoreOps. Use Value() to signal success to + // matchAndRewrite. + return Value(); } std::optional maxTransferRank; diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir index 2f2bdcaab5b3e..d169e6d5878e2 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir @@ -51,6 +51,23 @@ func.func @transfer_to_load(%mem : memref<8x8xf32>, %i : index) -> vector<4xf32> return %res : vector<4xf32> } +// Masked transfer_read/write ops (i.e. nested inside vector.mask) are NOT lowered to vector.load/store +// CHECK-LABEL: func @masked_transfer_to_load( +// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>, +// CHECK-SAME: %[[IDX:.*]]: index, +// CHECK-SAME: %[[MASK:.*]]: vector<4xi1>) -> memref<8x8xf32> +// CHECK-NOT: vector.load +// CHECK-NOT: vector.store +// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %arg0[%[[IDX]], %[[IDX]]]{{.*}} : memref<8x8xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32> +// CHECK: vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[MEM]][%[[IDX]], %[[IDX]]]{{.*}} : vector<4xf32>, memref<8x8xf32> } : vector<4xi1> + +func.func @masked_transfer_to_load(%mem : memref<8x8xf32>, %i : index, %mask : vector<4xi1>) -> memref<8x8xf32> { + %cf0 = arith.constant 0.0 : f32 + %read = vector.mask %mask {vector.transfer_read 
%mem[%i, %i], %cf0 {in_bounds = [true]} : memref<8x8xf32>, vector<4xf32>} : vector<4xi1> -> vector<4xf32> + vector.mask %mask {vector.transfer_write %read, %mem[%i, %i] {in_bounds = [true]} : vector<4xf32>, memref<8x8xf32> } : vector<4xi1> + return %mem : memref<8x8xf32> +} + // n-D results are also supported. // CHECK-LABEL: func @transfer_2D( // CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,