diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 42d89cd5a7620..25ddf2fc48d6f 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -579,11 +579,32 @@ struct DimOfCastOp : public OpRewritePattern<DimOp> {
     return success();
   }
 };
+
+/// Fold dim of a destination passing style op into the dim of the corresponding
+/// init.
+struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dimOp,
+                                PatternRewriter &rewriter) const override {
+    auto source = dimOp.getSource();
+    auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();
+    if (!destOp)
+      return failure();
+
+    auto resultIndex = source.cast<OpResult>().getResultNumber();
+    auto initOperand = destOp.getDpsInitOperand(resultIndex);
+
+    rewriter.updateRootInPlace(
+        dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });
+    return success();
+  }
+};
 } // namespace

 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
-  results.add<DimOfCastOp>(context);
+  results.add<DimOfCastOp, DimOfDestStyleOp>(context);
 }

 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 783660727ce16..297b5c4e332c8 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -397,9 +397,8 @@ func.func @fold_static_pad_fill() -> tensor<412x276xf32> {

 // CHECK-DAG: %[[I1:.+]] = arith.constant 1 : index
 // CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[OF:.+]] = linalg.fill ins(%[[F0]] : f32) outs(%[[SRC]] : tensor<8x?x16x32xf32>)
 // CHECK: %[[S0:.+]] = affine.apply #[[MAP0]]()[%[[LOW0]]]
-// CHECK: %[[DIM1:.+]] = tensor.dim %[[OF]], %[[I1]] : tensor<8x?x16x32xf32>
+// CHECK: %[[DIM1:.+]] = tensor.dim %[[SRC]], %[[I1]] : tensor<8x?x16x32xf32>
 // CHECK: %[[S1:.+]] = affine.apply #[[MAP1]]()[%[[DIM1]]]
 // CHECK: %[[S2:.+]] = affine.apply #[[MAP2]]()[%[[HIGH2]]]
 // CHECK: %[[S3:.+]] = affine.apply #[[MAP3]]()[%[[LOW3]], %[[HIGH3]]]
@@ -908,3 +907,24 @@ func.func @dead_softmax(%arg0: tensor<16x64x256xf32>) -> tensor<16x64x256xf32> {
     ins(%arg0 : tensor<16x64x256xf32>) outs(%0 : tensor<16x64x256xf32>) -> tensor<16x64x256xf32>
   return %arg0 : tensor<16x64x256xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @canonicalize_dim_of_dest_style_op
+// CHECK: tensor.dim
+// CHECK: tensor.dim
+// CHECK-NOT: tensor.dim
+// CHECK: return
+func.func @canonicalize_dim_of_dest_style_op(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %dim0_0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %dim1_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %0 = tensor.empty(%dim0_0, %dim1_0) : tensor<?x?xf32>
+  %1 = linalg.copy ins(%arg0 : tensor<?x?xf32>) outs(%0 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %dim0_1 = tensor.dim %1, %c0 : tensor<?x?xf32>
+  %dim1_1 = tensor.dim %1, %c1 : tensor<?x?xf32>
+  %2 = tensor.empty(%dim0_1, %dim1_1) : tensor<?x?xf32>
+  %3 = linalg.copy ins(%1 : tensor<?x?xf32>) outs(%2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %3: tensor<?x?xf32>
+}
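
For illustration (not part of the patch), a minimal example of what the new
DimOfDestStyleOp pattern does; the function and value names here are invented,
and the example mirrors the canonicalize_dim_of_dest_style_op test above. A
tensor.dim that queries the result of a destination-passing-style op such as
linalg.copy is redirected to the op's init operand, where it can fold against
the sizes that built the init:

  func.func @fold_dim_of_copy(%arg0: tensor<?x?xf32>) -> (index, index) {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
    %d1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
    %init = tensor.empty(%d0, %d1) : tensor<?x?xf32>
    %copy = linalg.copy ins(%arg0 : tensor<?x?xf32>)
                        outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
    // The pattern rewrites the next line to "tensor.dim %init, %c0", which
    // can then fold to %d0 through the tensor.empty, leaving no tensor.dim
    // on the copy result.
    %e0 = tensor.dim %copy, %c0 : tensor<?x?xf32>
    %e1 = tensor.dim %copy, %c1 : tensor<?x?xf32>
    return %e0, %e1 : index, index
  }

The rewrite is shape-safe by construction: a destination-passing-style op
returns tensors with the same shapes as its init operands, so querying the
init is equivalent to querying the result, regardless of what the op computes.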
diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
index 6f21e1e20c3d4..0f27a92c119cf 100644
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
@@ -197,10 +197,8 @@ func.func @conv_tensors_dynamic(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x
 // CHECK: #[[BOUND16_MAP:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 16)>
 // CHECK: #[[X2_MAP:.+]] = affine_map<(d0) -> (d0 * 2)>
 // CHECK: #[[INPUT_BOUND:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * -2 + s0 * 2 + s1 - 2, d1 * 2 + s1 - 2)>
-// CHECK: #[[BOUND16_MAP_2:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s1, -d0 + s0, 16)>
 // CHECK: #[[BOUND4_MAP:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 4)>
 // CHECK: #[[BOUND2_MAP:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 2)>
-// CHECK: #[[BOUND4_MAP_2:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s1, -d0 + s0, 4)>
 // CHECK: #[[BOUND2_MAP_2:.+]] = affine_map<(d0, d1)[s0, s1] -> (-d0 + s0, -d1 + s1, 2)>

 // CHECK: func @conv_tensors_dynamic
@@ -225,8 +223,6 @@ func.func @conv_tensors_dynamic(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x

 // CHECK-DAG: %[[INPUT_N:.+]] = tensor.dim %[[INPUT]], %[[C0]] : tensor<?x?x?x?xf32>
 // CHECK-DAG: %[[INPUT_C:.+]] = tensor.dim %[[INPUT]], %[[C3]] : tensor<?x?x?x?xf32>
-// CHECK-DAG: %[[FILL_H:.+]] = tensor.dim %[[FILL]], %[[C1]] : tensor<?x?x?x?xf32>
-// CHECK-DAG: %[[FILL_W:.+]] = tensor.dim %[[FILL]], %[[C2]] : tensor<?x?x?x?xf32>

 // CHECK: scf.for %[[IV0:.+]] = %{{.+}} to %[[ELEM_N]] step %{{.+}} iter_args(%{{.+}} = %[[FILL]])
 // CHECK-NEXT: %[[SIZE_ELEM_N:.+]] = affine.min #[[BOUND8_MAP]](%[[IV0]])[%[[ELEM_N]]]
@@ -234,14 +230,12 @@ func.func @conv_tensors_dynamic(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x
 // CHECK-SAME: ins(%[[ST_INPUT]], %[[ST_FILTER]] : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
 // CHECK-SAME: outs(%[[ST_FILL]] : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
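
The churn in this test follows directly from the new fold: the sizes that were
queried from the linalg.fill result (%[[FILL_H]], %[[FILL_W]]) are now queried
from the tensor that initialized the fill, so they CSE with size values that
are already in scope and the dedicated tensor.dim checks disappear, along with
the affine maps that combined them. A reduced sketch of the chain (names
invented, shapes simplified):

  func.func @dim_of_fill(%size: index) -> index {
    %c1 = arith.constant 1 : index
    %f0 = arith.constant 0.0 : f32
    %empty = tensor.empty(%size) : tensor<8x?xf32>
    %fill = linalg.fill ins(%f0 : f32)
                        outs(%empty : tensor<8x?xf32>) -> tensor<8x?xf32>
    // DimOfDestStyleOp turns this into "tensor.dim %empty, %c1", which can
    // then fold to %size, so canonicalization leaves no tensor.dim at all.
    %h = tensor.dim %fill, %c1 : tensor<8x?xf32>
    return %h : index
  }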
diff --git a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
index 70e535b74f055..934be889cecb2 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
@@ -43,9 +43,7 @@ transform.sequence failures(propagate) {
 // CHECK: arith.addf
 // CHECK: linalg.yield
 // CHECK: } -> tensor<?x?xf32>
-// CHECK: %[[D3:.*]] = tensor.dim %[[PR]], %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[D4:.*]] = tensor.dim %[[PR]], %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[INS:.*]] = tensor.insert_slice %[[PR]] into %[[ARG3]][0, 0] [%[[D3]], %[[D4]]] [1, 1] : tensor<?x?xf32> into tensor<?x5xf32>
+// CHECK: %[[INS:.*]] = tensor.insert_slice %[[PR]] into %[[ARG3]][0, 0] [%[[D0]], %[[PS]]] [1, 1] : tensor<?x?xf32> into tensor<?x5xf32>
 // CHECK: scf.yield %[[INS]] : tensor<?x5xf32>
 // CHECK: }
 // CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP1]]], iterator_types = ["parallel", "reduction"]} ins(%[[L]] : tensor<?x5xf32>) outs(%[[ARG1]] : tensor<?xf32>) {
@@ -76,14 +74,16 @@ transform.sequence failures(propagate) {
   by tile_sizes = [5, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
 }

+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 5)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d1)>
 // CHECK: func @reduction_tile_transpose
 // CHECK: tensor.empty(%{{.*}}) : tensor<5x?xf32>
 // CHECK: linalg.fill {{.*}} : tensor<5x?xf32>) -> tensor<5x?xf32>
 // CHECK: scf.for
-// CHECK: linalg.generic
-// CHECK: %[[D3:.*]] = tensor.dim %{{.*}}, %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[D4:.*]] = tensor.dim %{{.*}}, %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[INS:.*]] = tensor.insert_slice %[[PR]] into %[[ARG3]][0, 0] [%[[D3]], %[[D4]]] [1, 1] : tensor<?x?xf32> into tensor<5x?xf32>
+// CHECK: %[[EXT:.*]] = tensor.extract_slice %[[ARG3:.*]][0, 0] [%[[D0:.*]], %[[D1:.*]]] [1, 1] : tensor<5x?xf32> to tensor<?x?xf32>
+// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[L:.*]] : tensor<?x?xf32>) outs(%[[EXT]] : tensor<?x?xf32>)
+// CHECK: %[[INS:.*]] = tensor.insert_slice %[[R]] into %[[ARG3]][0, 0] [%[[D0]], %[[D1]]] [1, 1] : tensor<?x?xf32> into tensor<5x?xf32>
 // CHECK: scf.yield {{.*}} : tensor<5x?xf32>
 // CHECK: }
 // CHECK: linalg.generic
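
The tiled-reduction updates are the same effect inside a loop: the partial
result of the inner linalg.generic is a destination-passing-style value whose
init is the extract_slice of the accumulator, so its dims fold to the slice
sizes the tiling already computed, and the insert_slice can reuse those SSA
values (%[[D0]], %[[PS]]) instead of fresh tensor.dim ops. A hand-reduced
sketch of one loop body under that assumption (all names invented):

  func.func @partial_result_sizes(%acc: tensor<?x5xf32>, %slice: tensor<?x?xf32>,
                                  %d0: index, %ps: index) -> tensor<?x5xf32> {
    %ext = tensor.extract_slice %acc[0, 0] [%d0, %ps] [1, 1]
        : tensor<?x5xf32> to tensor<?x?xf32>
    %partial = linalg.generic
        {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                          affine_map<(d0, d1) -> (d0, d1)>],
         iterator_types = ["parallel", "parallel"]}
        ins(%slice : tensor<?x?xf32>) outs(%ext : tensor<?x?xf32>) {
    ^bb0(%a: f32, %b: f32):
      %s = arith.addf %a, %b : f32
      linalg.yield %s : f32
    } -> tensor<?x?xf32>
    // Previously the sizes for the insert were re-queried here:
    //   %d3 = tensor.dim %partial, %c0 : tensor<?x?xf32>
    // With the fold, tensor.dim %partial becomes tensor.dim %ext, and the
    // slice sizes %d0/%ps are used directly, as the updated CHECKs show.
    %ins = tensor.insert_slice %partial into %acc[0, 0] [%d0, %ps] [1, 1]
        : tensor<?x?xf32> into tensor<?x5xf32>
    return %ins : tensor<?x5xf32>
  }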