diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp index 11bd886c36e53..e48188fe516d3 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -144,27 +144,17 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer, b, loc, producer, getTiledOperands(producer), ivs, tileSizes, sizeBounds, /**omitPartialTileCheck=*/false)); - // Iterate over the results in order. - // Extract the subtensor type from the linearized range. - // Since we do not enforce any canonicalizations on the fly, this is always - // fully dynamic at construction time. + // Take result types from the tiled init operands. + MutableOperandRange producerDpsInits = producer.getDpsInitsMutable(); SmallVector resultTypes; resultTypes.reserve(producer->getNumResults()); - for (Value operand : producer.getDpsInits()) { - auto tensorType = dyn_cast(operand.getType()); - if (!tensorType) - continue; - unsigned rank = tensorType.getRank(); - SmallVector staticOffsetsVector( - rank, ShapedType::kDynamic); - SmallVector staticSizesVector(rank, ShapedType::kDynamic); - SmallVector staticStridesVector( - rank, ShapedType::kDynamic); - resultTypes.push_back(tensor::ExtractSliceOp::inferResultType( - tensorType, staticOffsetsVector, staticSizesVector, - staticStridesVector)); + int64_t firstInitOperandIdx = + static_cast(producerDpsInits).getBeginOperandIndex(); + for (int64_t i = 0, e = producer->getNumResults(); i < e; ++i) { + resultTypes.push_back(clonedShapes[firstInitOperandIdx + i].getType()); } + // Clone the producer with new operands and result types. LinalgOp clonedOp = clone(b, producer, resultTypes, clonedShapes); // Shift all IndexOp results by the tile offset.