diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 4d6c5965c4fcc..9d12ebd307725 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -24,6 +24,7 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Support/LLVM.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
@@ -1982,6 +1983,86 @@ struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> {
     return success();
   }
 };
+
+/// Fold/sink a producer `tensor.cast` with a consumer `tensor.expand_shape` by
+/// matching constant output_shape operands of the expand. This makes the
+/// `tensor.expand_shape` more static and creates a consumer cast that can be
+/// propagated further.
+struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
+  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
+                                PatternRewriter &rewriter) const override {
+    auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
+    if (!canFoldIntoConsumerOp(castOp))
+      return failure();
+
+    ArrayRef<int64_t> castSrcShape = castOp.getSource().getType().getShape();
+    SmallVector<ReassociationIndices, 4> reassoc =
+        expandOp.getReassociationIndices();
+
+    SmallVector<int64_t> newOutputShape(expandOp.getResultType().getShape());
+    SmallVector<Value> dynamicOutputShape;
+    auto outputIt = expandOp.getOutputShape().begin();
+
+    for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
+      for (uint64_t outDim : innerReassoc) {
+        if (!ShapedType::isDynamic(newOutputShape[outDim]))
+          continue;
+
+        // If the cast's src type is dynamic, don't infer any of the
+        // corresponding expanded dimensions. `tensor.expand_shape` requires at
+        // least one of the expanded dimensions to be dynamic if the input is
+        // dynamic.
+        Value val = *outputIt;
+        ++outputIt;
+        if (ShapedType::isDynamic(castSrcShape[inputDim])) {
+          dynamicOutputShape.push_back(val);
+          continue;
+        }
+
+        APInt cst;
+        if (matchPattern(val, m_ConstantInt(&cst))) {
+          newOutputShape[outDim] = cst.getSExtValue();
+        } else {
+          dynamicOutputShape.push_back(val);
+        }
+      }
+    }
+
+    // Couldn't match any values, nothing to change
+    if (expandOp.getOutputShape().size() == dynamicOutputShape.size())
+      return failure();
+
+    // Calculate the input shape from the output
+    SmallVector<int64_t> newInputShape(expandOp.getSrcType().getRank(), 1l);
+    for (auto inDim : llvm::seq<int>(0, newInputShape.size())) {
+      for (auto outDim : reassoc[inDim]) {
+        auto ofr = newOutputShape[outDim];
+        if (ShapedType::isDynamic(ofr)) {
+          newInputShape[inDim] = ShapedType::kDynamic;
+          break;
+        }
+        newInputShape[inDim] *= ofr;
+      }
+    }
+
+    SmallVector<OpFoldResult> outputOfr =
+        getMixedValues(newOutputShape, dynamicOutputShape, rewriter);
+    auto inputType = RankedTensorType::get(
+        newInputShape, expandOp.getSrcType().getElementType());
+    auto outputType = RankedTensorType::get(
+        newOutputShape, expandOp.getSrcType().getElementType());
+    auto inputCast = rewriter.create<CastOp>(expandOp.getLoc(), inputType,
+                                             expandOp.getSrc());
+    auto newExpand = rewriter.create<ExpandShapeOp>(
+        expandOp.getLoc(), outputType, inputCast.getResult(),
+        expandOp.getReassociationIndices(), outputOfr);
+    rewriter.replaceOpWithNewOp<CastOp>(expandOp, expandOp.getType(),
+                                        newExpand.getResult());
+    return success();
+  }
+};
 } // namespace
 
 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -1989,7 +2070,7 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<
       ComposeReassociativeReshapeOps<ExpandShapeOp, ReassociationKind::Expand>,
       ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
-      FoldReshapeWithConstant<ExpandShapeOp>,
+      ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
       FoldReshapeWithSplat<ExpandShapeOp>,
       FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
       FoldDimOfCollapseShape>(context);
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 0aa2d33ef17ed..63f394a14d389 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2718,3 +2718,57 @@ func.func @pack_dont_drop_attributes(%arg0: tensor<?x?x?xf16>, %arg1: tensor<128x?x100x16x1xf16>) -> tensor<128x?x100x16x1xf16> {
   %pack = tensor.pack %arg0 padding_value(%cst : f16) outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [16, 1] into %arg1 {test_attr} : tensor<?x?x?xf16> -> tensor<128x?x100x16x1xf16>
   return %pack : tensor<128x?x100x16x1xf16>
 }
+
+// -----
+
+func.func @fold_expand_of_cast(%arg0 : tensor<10x10xf32>)
+    -> tensor<10x1x10xf32> {
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  %0 = tensor.cast %arg0 : tensor<10x10xf32> to tensor<?x?xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%c10, %c1, %c10]
+      : tensor<?x?xf32> into tensor<?x?x?xf32>
+  %2 = tensor.cast %1 : tensor<?x?x?xf32> to tensor<10x1x10xf32>
+  return %2 : tensor<10x1x10xf32>
+}
+// CHECK-LABEL: func.func @fold_expand_of_cast
+// CHECK: %[[RES:.+]] = tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]] output_shape [10, 1, 10]
+// CHECK: return %[[RES]]
+
+// -----
+
+func.func @sink_expand_of_cast(%arg0 : tensor<?x10xf32>)
+    -> tensor<?x?x?xf32> {
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  %0 = tensor.cast %arg0 : tensor<?x10xf32> to tensor<?x?xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%c10, %c1, %c10]
+      : tensor<?x?xf32> into tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: func.func @sink_expand_of_cast
+// CHECK-DAG: %[[C10:.*]] = arith.constant 10
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]]
+// CHECK-SAME: output_shape [%[[C10]], %[[C1]], 10]
+// CHECK: %[[RES:.+]] = tensor.cast %[[EXPAND]]
+// CHECK: return %[[RES]]
+
+// -----
+
+func.func @partial_sink_expand_of_cast(%arg0 : tensor<10x10xf32>, %arg1 : index, %arg2 : index)
+    -> tensor<?x?x?xf32> {
+  %c10 = arith.constant 10 : index
+  %0 = tensor.cast %arg0 : tensor<10x10xf32> to tensor<?x?xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, %arg2, %c10]
+      : tensor<?x?xf32> into tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: func.func @partial_sink_expand_of_cast
+// CHECK: %[[CAST:.+]] = tensor.cast
+// CHECK-SAME: tensor<10x10xf32> to tensor<?x10xf32>
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]]
+// CHECK-SAME: output_shape [%{{.*}}, %{{.*}}, 10]
+// CHECK: %[[RES:.+]] = tensor.cast %[[EXPAND]]
+// CHECK-SAME: tensor<?x?x10xf32> to tensor<?x?x?xf32>
+// CHECK: return %[[RES]]
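
For reviewers, a minimal before/after sketch of the new canonicalization, derived from the @fold_expand_of_cast test above (SSA names are illustrative, not taken from pass output):

  // Before: a dynamic expand_shape fed by a producer cast, with constant
  // output_shape operands %c10 and %c1.
  %0 = tensor.cast %arg0 : tensor<10x10xf32> to tensor<?x?xf32>
  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%c10, %c1, %c10]
      : tensor<?x?xf32> into tensor<?x?x?xf32>
  %2 = tensor.cast %1 : tensor<?x?x?xf32> to tensor<10x1x10xf32>

  // After canonicalization (per the CHECK lines): the constant output_shape
  // operands make the expand fully static, and the recreated consumer cast
  // cancels with the existing one.
  %0 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [10, 1, 10]
      : tensor<10x10xf32> into tensor<10x1x10xf32>

When only some output dimensions can be made static (as in @partial_sink_expand_of_cast), the pattern instead produces the most static expand_shape it can and sinks a tensor.cast below it for further propagation.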