diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp index c59012266ceb3..6bfb2eb689418 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp @@ -322,14 +322,20 @@ struct TransferWriteNonPermutationLowering /// %v = vector.transfer_read ... /// permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3) /// vector.broadcast %v -struct TransferOpReduceRank : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct TransferOpReduceRank + : public MaskableOpRewritePattern { + using MaskableOpRewritePattern::MaskableOpRewritePattern; - LogicalResult matchAndRewrite(vector::TransferReadOp op, - PatternRewriter &rewriter) const override { + FailureOr + matchAndRewriteMaskableOp(vector::TransferReadOp op, + MaskingOpInterface maskOp, + PatternRewriter &rewriter) const override { // TODO: support 0-d corner case. if (op.getTransferRank() == 0) return rewriter.notifyMatchFailure(op, "0-d corner case not supported"); + // TODO: support masked case. + if (maskOp) + return rewriter.notifyMatchFailure(op, "Masked case not supported"); AffineMap map = op.getPermutationMap(); unsigned numLeadingBroadcast = 0; @@ -369,9 +375,9 @@ struct TransferOpReduceRank : public OpRewritePattern { op.getLoc(), originalVecType.getElementType(), op.getSource(), op.getIndices()); } - rewriter.replaceOpWithNewOp(op, originalVecType, - newRead); - return success(); + return rewriter + .create(op.getLoc(), originalVecType, newRead) + .getVector(); } SmallVector newShape( @@ -393,9 +399,9 @@ struct TransferOpReduceRank : public OpRewritePattern { op.getLoc(), newReadType, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap), op.getPadding(), op.getMask(), newInBoundsAttr); - rewriter.replaceOpWithNewOp(op, originalVecType, - newRead); - return success(); + return rewriter + .create(op.getLoc(), originalVecType, newRead) + .getVector(); } }; diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir index 349dc1ab31d4c..0cd134717b1a0 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir @@ -187,3 +187,49 @@ module attributes {transform.with_named_sequence} { transform.yield } } + +// ----- + + +// CHECK: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, 0, d3)> +// CHECK: func.func @transfer_read_reduce_rank_scalable( +// CHECK-SAME: %[[ARG_0:.*]]: memref) -> vector<8x[4]x2x3xf32> { +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[TFR:.*]] = vector.transfer_read %arg0[%[[C0]], %[[C0]], %[[C0]], %[[C0]]]{{.*}} permutation_map = #[[MAP]]} : memref, vector<[4]x2x3xf32> +// CHECK: %[[BC:.*]] = vector.broadcast %[[TFR]] : vector<[4]x2x3xf32> to vector<8x[4]x2x3xf32> +// CHECK: return %[[BC]] : vector<8x[4]x2x3xf32> +func.func @transfer_read_reduce_rank_scalable(%mem: memref) -> vector<8x[4]x2x3xf32> { + %c0 = arith.constant 0 : index + %cst_0 = arith.constant 0.000000e+00 : f32 + %1 = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst_0 + {in_bounds = [true, true, true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>} + : memref, vector<8x[4]x2x3xf32> + return %1 : vector<8x[4]x2x3xf32> +} + +// Masked case not supported. +// CHECK-LABEL: func.func @masked_transfer_read_reduce_rank( +// CHECK-SAME: %[[ARG_0:.*]]: memref, +// CHECK-SAME: %[[DIM:.*]]: index) -> vector<8x[4]x2x3xf32> { +// CHECK-NOT: vector.broadcast +// CHECK: %[[MASK:.*]] = vector.mask %0 { vector.transfer_read %arg0{{.*}} : memref, vector<8x[4]x2x3xf32> } : vector<[4]x3xi1> -> vector<8x[4]x2x3xf32> +func.func @masked_transfer_read_reduce_rank(%mem: memref, %dim: index) -> vector<8x[4]x2x3xf32> { + %c0 = arith.constant 0 : index + %cst_0 = arith.constant 0.000000e+00 : f32 + %mask = vector.create_mask %dim, %dim: vector<[4]x3xi1> + %res = vector.mask %mask { vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst_0 + {in_bounds = [true, true, true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>} + : memref, vector<8x[4]x2x3xf32> } : vector<[4]x3xi1> -> vector<8x[4]x2x3xf32> + return %res : vector<8x[4]x2x3xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { + %f = transform.structured.match ops{["func.func"]} in %module_op + : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %f { + transform.apply_patterns.vector.transfer_permutation_patterns + } : !transform.any_op + transform.yield + } +}